import random import numpy as np from collections import defaultdict as dd, defaultdict from itertools import count import re import time import math import _dynet as dy
# --- DyNet runtime configuration -------------------------------------------
# Must run before any parameters are created: importing _dynet directly
# requires explicit initialisation (unlike `import dynet`).
dyparams = dy.DynetParams()
dyparams.from_args()                  # pick up --dynet-* command-line overrides
dyparams.set_requested_gpus(1)        # train on a single GPU
dyparams.set_mem(2048)                # memory pool size (presumably MB — per dynet convention)
dyparams.set_random_seed(792321264)   # fixed seed for reproducibility
dyparams.init()
DEBUG = True

# Data file locations.  The DEBUG flag selects the small development subsets
# so a full run finishes quickly while iterating on the code.
if DEBUG:
    train_string_file = "data/train_small.strings"
    train_tree_file = "data/train_small.trees.pre.unk"
    dev_string_file = "data/dev_small.strings"
    dev_tree_file = "data/dev_small.trees"
    dev_parse_file = "data/dev_small.parses"
else:
    train_string_file = "data/train.strings"
    train_tree_file = "data/train.trees.pre.unk"
    dev_string_file = "data/dev.strings"
    dev_tree_file = "data/dev.trees"
    dev_parse_file = "data/dev.parses"
# Read the training sentences (stored verbatim, trailing newline included)
# and collect the running word vocabulary by whitespace tokenisation.
train_string = []
train_tree = []
words = []
with open(train_string_file, "r") as fh:
    for line in fh:
        train_string.append(line)
        words.extend(line.split())
# Guarantee an <unk> entry so out-of-vocabulary words can be mapped to it.
words.append("<unk>")

# Gold training trees, one bracketed tree per line, parallel to train_string.
with open(train_tree_file, "r") as fh:
    train_tree = [line for line in fh]
# Word <-> id maps.  defaultdict(count(0).next) hands out the next fresh
# integer on first lookup, so merely touching w2i[token] both registers the
# word and fixes its index.  (Python 2: count iterators expose .next and
# dicts expose .iteritems.)
w2i = defaultdict(count(0).next)
for token in words:
    w2i[token]
i2w = dict((idx, wrd) for wrd, idx in w2i.iteritems())
nwords = len(w2i)
# Global grammar state, populated by read_grammar():
nonTerms = set()            # every left-hand-side nonterminal symbol
rules_set1 = set()          # lexical rules as (lhs, terminal) pairs
rules_set2 = set()          # binary rules as (lhs, "B C") pairs
rules = dict()              # lhs -> last rhs seen for that lhs
lexicons = []               # terminal words appearing in lexical rules
origText = []               # original words of the sentence being parsed
probs = defaultdict(float)  # (lhs, rhs) -> probability, 0.0 for unknown rules
node_pa = dict()            # rhs string -> set of lhs symbols producing it
def read_grammar(f):
    """Load a PCFG from file *f* into the module-level grammar tables.

    Each line has the shape ``LHS -> RHS @ PROB``.  The function works by
    side effect and returns None:
      * probs[(lhs, rhs)]        -- rule probability
      * nonTerms                 -- set of lhs symbols
      * rules_set1 / rules_set2  -- lexical (single-token rhs) vs binary rules
      * node_pa[rhs]             -- lhs symbols that can derive rhs
      * rules[lhs]               -- last rhs seen for lhs (overwritten)
      * lexicons                 -- terminal words, excluding '<unk>'
    """
    # `with` guarantees the grammar file is closed; the original leaked the
    # handle and carried an unused local `grammar`.
    with open(f, 'r') as fh:
        for rule in fh:
            tokens = re.split(r"\-\>|\@", rule.strip())
            lhs = tokens[0].strip()
            # Strip stray quoting (backslash and quote chars) off the rhs.
            rhs = tokens[1].strip().strip(r'\'').strip(r'\"')
            probs[(lhs, rhs)] = float(tokens[2].strip())
            nonTerms.add(lhs)
            if len(rhs.split()) == 1:
                rules_set1.add((lhs, rhs))
                # Terminal word (lexical rule); '<unk>' is not a real word.
                if rhs != '<unk>':
                    lexicons.append(rhs)
            else:
                rules_set2.add((lhs, rhs))
            # setdefault replaces the original if/else initialisation dance.
            node_pa.setdefault(rhs, set()).add(lhs)
            rules[lhs] = rhs
if DEBUG: grammar = read_grammar('data/pcfg_small') else: grammar = read_grammar('data/pcfg') print rules_set1.__len__(), rules_set2.__len__()
# --- Hyper-parameters and model parameters ----------------------------------
EPOCH = 40            # training passes over the data
EMBDDING_SIZE = 512   # (sic) embedding / LSTM hidden size
lamda = 100           # (sic) scale on the learned span score in score_calc
k = 0.1               # weight of the token-mismatch margin term in the loss
model = dy.ParameterCollection()
# 2-layer LSTM used to compose span representations bottom-up in the chart.
builder = dy.FastLSTMBuilder(2, EMBDDING_SIZE, EMBDDING_SIZE, model)
trainer = dy.AdamTrainer(model)
WORDS_LOOKUP = model.add_lookup_parameters((nwords, EMBDDING_SIZE))
pd = model.add_parameters((1, EMBDDING_SIZE))              # d: row vector -> scalar score
pW = model.add_parameters((EMBDDING_SIZE, EMBDDING_SIZE))  # W: affine transform
pb = model.add_parameters((EMBDDING_SIZE,))                # b: affine bias
class MTree(object):
    """Minimal binary parse-tree node.

    A node is either lexical (label + terminal word, no children) or
    internal (label + the subtrees in `subs`).
    """

    def __init__(self, lhs, wrd=None, subs=None):
        self.label = lhs   # nonterminal symbol
        self.word = wrd    # terminal word; None for internal nodes
        self.subs = subs   # list of child MTrees; None for lexical nodes
        self.str = None    # scratch slot for the rendered string

    def is_lexicon(self):
        """True when this node is a leaf carrying a terminal word."""
        return self.word is not None

    def dostr(self):
        """Render the subtree in bracketed form, e.g. "(NP (DT the) (NN dog))"."""
        if self.is_lexicon():
            return "(%s %s)" % (self.label, self.word)
        return "(%s %s)" % (self.label, " ".join(map(str, self.subs)))

    def __str__(self):
        # The original guard was `if True or self.str is None`, so the cached
        # value was never actually reused; keep the always-recompute behaviour
        # (trees may be built incrementally) but drop the dead condition.
        self.str = self.dostr()
        return self.str
def helper(span, text, backPointers, terminals, score):
    """Recursively rebuild the best tree for chart cell `span` = (begin, end, A).

    Follows backPointers down to terminal cells and returns (MTree, score)
    where the score is the chart score stored for this cell.  Leaf words are
    taken from the module-level `origText` (the words before <unk>
    replacement), not from `text`.

    Note: the parameter was renamed from `next` (which shadowed the builtin);
    all call sites in this file pass it positionally.
    """
    begin, end, A = span
    # A cell with no back-pointer but a terminal entry is a lexical leaf.
    if span not in backPointers and span in terminals:
        leaf = MTree(lhs=A, subs=None, wrd=origText[begin])
        return (leaf, score[(begin, end, A)])
    # Otherwise recurse on the two best children.  (As in the original, a
    # cell in neither table raises KeyError here.)
    (split, B, C) = backPointers[span]
    left, _ = helper((begin, split, B), text, backPointers, terminals, score)
    right, _ = helper((split, end, C), text, backPointers, terminals, score)
    return (MTree(lhs=A, subs=[left, right], wrd=None), score[(begin, end, A)])
def backtrack(text, backPointers, terminals, score):
    """Recover the best parse of the full sentence (root label 'S').

    Returns (tree, score), or (None, 0) when the chart holds no complete
    parse spanning the whole input.
    """
    root = (0, len(text), 'S')
    if root not in backPointers:
        return (None, 0)
    return helper(root, text, backPointers, terminals, score)
def math_log(x):
    """math.log with a floor: non-positive inputs yield -100 instead of
    raising, so zero-probability rules get a large but finite penalty."""
    return math.log(x) if x > 0 else -100
def score_calc(d, W, p, b, lamda, s_pcfg):
    """Span score: the learned part d*(W*p + b), scaled by lamda, plus the
    PCFG contribution s_pcfg.  Operates on dynet expressions (or plain
    numbers, via operator overloading)."""
    learned = d * (W * p + b)
    return learned * lamda + s_pcfg
def cal_loss(result, gold):
    """Count token positions where the predicted tree string differs from
    the gold tree string, returned as a dynet expression.

    When no parse was found (result is None) the penalty is len(gold).
    NOTE(review): that is the *character* length of the raw gold string,
    not its token count — inconsistent with the per-token count below but
    preserved as-is.
    """
    if result is None:  # idiomatic `is None` (was `== None`)
        return dy.inputTensor([len(gold)])
    cnt = dy.inputTensor([0])
    # zip stops at the shorter sequence, matching the original
    # range(len(result)) indexing while also avoiding an IndexError when
    # the prediction has more tokens than the gold.
    for r_tok, g_tok in zip(result.split(), gold.split()):
        if r_tok != g_tok:
            cnt += 1
    return cnt
def cal_gold(gold, d, W, b):
    # Recursively score the *gold* bracketed tree string with the same
    # machinery as the prediction chart, so gold and predicted scores are
    # comparable.  Returns (embedding, score_expr, pcfg_log_score,
    # lstm_state, root_label).
    words = gold.split()
    n = len(words)
    if n == 2:
        # Lexical case: gold looks like "(A word)".
        A = words[0][1:]      # drop the leading '('
        word = words[1][:-1]  # drop the trailing ')'
        LSTM = builder.initial_state()
        TMP = LSTM.add_input(WORDS_LOOKUP[w2i[word]])
        e = TMP.output()
        s_pcfg = math_log(probs[(A, word)])
        # NOTE(review): the raw probability (not its log) is passed to
        # score_calc while s_pcfg holds the log.  This matches the lexical
        # case of the training chart but differs from the binary case
        # below — confirm intent.
        s = score_calc(d, W, e, b, lamda, probs[(A, word)])
        return (e, s, s_pcfg, TMP, A)
    else:
        # Binary case: gold looks like "(A (B ...) (C ...))".
        sz = len(gold)
        # p: index of the first space, i.e. the end of the "(A" prefix.
        p = 0
        for i in xrange(0, sz):
            if gold[i] == ' ':
                p = i
                break
        # m: index of the ')' that closes the first child, found by
        # bracket counting starting right after the label.
        m = 0
        cnt = 0
        for i in xrange(p+1, sz):
            if gold[i] == '(':
                cnt += 1
            elif gold[i] == ')':
                cnt -= 1
                if cnt == 0:
                    m = i
                    break
        # Recurse on the two children; the second slice skips the
        # separating space (m+1) and the final ')' (sz-1).
        x1, s1, s1_pcfg, LSTM1, B = cal_gold(gold[p+1 : m+1], d, W, b)
        x2, s2, s2_pcfg, LSTM2, C = cal_gold(gold[m+2 : sz-1], d, W, b)
        A = gold[1:p]
        # Compose exactly as the chart does: feed the left child's
        # embedding into the right child's LSTM state.
        TMP = LSTM2.add_input(x1)
        e = TMP.output()
        s_pcfg = math_log(probs[(A, B+" "+C)]) + s1_pcfg + s2_pcfg
        ss1 = score_calc(d, W, e, b, lamda, s_pcfg)
        return (e, ss1, s_pcfg, TMP, A)
total_time = 0.0

# ---------------------------------------------------------------------------
# Training loop.  For every sentence, build a CKY-style chart whose cell
# scores mix the PCFG log-probability with a learned LSTM span score (see
# score_calc), then take a margin-style loss between the best predicted tree
# and the gold tree.  Losses are applied in batches of 50 sharing one
# computation graph.
# ---------------------------------------------------------------------------
ff = open("loss.txt", "w")
for epoch in xrange(0, EPOCH):
    print "epoch %d" % epoch
    sumloss = 0
    num = len(train_string)
    batch = []  # NOTE(review): a partial batch left at epoch end is discarded without an update
    start = time.time()
    for idx, line in enumerate(train_string):
        sstart = time.time()  # NOTE(review): never read afterwards
        gold = train_tree[idx].strip()
        sent = line.split()
        origText = list(sent)  # module-level: helper() reads it to emit leaf words
        n = len(sent)
        d = pd.expr()
        W = pW.expr()
        b = pb.expr()
        # Chart tables, all keyed by (begin, end, nonterminal):
        terminals = {}                   # word covered by a lexical cell
        embdding = {}                    # LSTM output vector of the cell
        score = defaultdict(float)       # best combined score (dynet expr)
        score_pcfg = defaultdict(float)  # PCFG log-prob part of that score
        backPointers = {}                # (split, B, C) of the best derivation
        LSTM = {}                        # LSTM state that produced the cell
        node_rules = {}                  # (begin, end) -> labels derivable there
        # Lexical cells (spans of length 1).
        for i in range(0, n):
            begin = i
            end = i + 1
            node_rules[(begin, end)] = set()
            word = sent[i]
            for A in nonTerms:
                if (A, word) in rules_set1:
                    LSTM[(begin, end, A)] = builder.initial_state()
                    LSTM[(begin, end, A)] = LSTM[(begin, end, A)].add_input(WORDS_LOOKUP[w2i[sent[i]]])
                    embdding[(begin, end, A)] = LSTM[(begin, end, A)].output()
                    score_pcfg[(begin, end, A)] = math_log(probs[(A, word)])
                    # NOTE(review): raw probability (not math_log of it) is
                    # passed as the s_pcfg term here, unlike the binary case
                    # below — confirm intent.
                    score[(begin, end, A)] = score_calc(d, W, embdding[(begin, end, A)], b, lamda, probs[(A, word)])
                    terminals[(begin, end, A)] = word
                    node_rules[(begin, end)].add(A)
        # Binary cells, bottom-up over increasing span length.
        for span in range(2, n + 1):
            for begin in range(0, n - span + 1):
                end = begin + span
                node_rules[(begin, end)] = set()
                for split in range(begin + 1, end):
                    for B in node_rules[(begin, split)]:
                        for C in node_rules[(split, end)]:
                            X = B+" "+C
                            if X in node_pa:
                                for A in node_pa[X]:
                                    node_rules[(begin, end)].add(A)
                                    # Compose: left child's embedding is fed
                                    # into the right child's LSTM state.
                                    TMP = LSTM[(split, end, C)].add_input(embdding[(begin, split, B)])
                                    p = TMP.output()
                                    s_pcfg = math_log(probs[(A, X)]) + score_pcfg[(begin, split, B)] + score_pcfg[(split, end, C)]
                                    s = score_calc(d, W, p, b, lamda, s_pcfg)
                                    # NOTE(review): .value() forces a forward
                                    # pass per candidate — a major slowdown.
                                    if (begin, end, A) not in score or s.value() > score[(begin, end, A)].value():
                                        LSTM[(begin, end, A)] = TMP
                                        score[(begin, end, A)] = s
                                        score_pcfg[(begin, end, A)] = s_pcfg
                                        embdding[(begin, end, A)] = p
                                        backPointers[(begin, end, A)] = (split, B, C)
        # Best predicted tree (t is None when no complete parse exists).
        t, s = backtrack(sent, backPointers, terminals, score)
        result = None
        if t != None:
            result = t.dostr()
        # Score the gold tree with the same machinery.
        golds_e, golds, golds_pcfg, lstm, S = cal_gold(gold, d, W, b)
        cnt = cal_loss(result, gold)
        # |predicted - gold| score gap + token-mismatch margin + L2 penalty.
        loss = dy.abs(s - golds) + cnt * k + 0.5 * (dy.l2_norm(W) + dy.l2_norm(b) + dy.l2_norm(d))
        sumloss += loss.value()
        batch.append(loss)
        if len(batch) == 50:
            # One update per 50 sentences, then start a fresh graph.
            loss = dy.esum(batch)
            loss.backward()
            trainer.update()
            dy.renew_cg()
            batch = []
        eend = time.time()
        if idx > 0 and idx % 500 == 0:
            # NOTE(review): Python 2 integer division in (idx / 500).
            print "time of 500 sent: ", (eend - start) / (idx / 500)
    end = time.time()
    total_time += end - start
    print "epoch time: ", end - start
    print "epoch loss: ", sumloss / num
    ff.write('%f\n'%(sumloss / num))

print "total time: ", total_time
# ---------------------------------------------------------------------------
# Dev-set parsing: run the same chart computation (no learning) and write one
# bracketed parse — or "None" — per input line.
# NOTE(review): fh and outfile are never closed explicitly; prefer `with`.
# ---------------------------------------------------------------------------
fh = open(dev_string_file, "r")
outfile = open(dev_parse_file, "w")
for line in fh:
    sent = line.split()
    origText = list(sent)  # keep originals: leaves are printed pre-<unk>
    # Map out-of-vocabulary words to <unk> so lexical rules can still fire.
    # NOTE(review): `lexicons` is a list, so this membership test is O(n)
    # per word; a set would be O(1).
    for i, word in enumerate(sent):
        if word not in lexicons:
            sent[i] = '<unk>'
    n = len(sent)
    dy.renew_cg()  # fresh computation graph per sentence at test time
    d = pd.expr()
    W = pW.expr()
    b = pb.expr()
    # Chart tables keyed by (begin, end, nonterminal) — same as training.
    terminals = {}
    embdding = {}
    score = defaultdict(float)
    score_pcfg = defaultdict(float)
    backPointers = {}
    LSTM = {}
    node_rules = {}
    # Lexical cells (spans of length 1).
    for i in range(0, n):
        begin = i
        end = i + 1
        node_rules[(begin, end)] = set()
        word = sent[i]
        for A in nonTerms:
            if (A, word) in rules_set1:
                LSTM[(begin, end, A)] = builder.initial_state()
                LSTM[(begin, end, A)] = LSTM[(begin, end, A)].add_input(WORDS_LOOKUP[w2i[sent[i]]])
                embdding[(begin, end, A)] = LSTM[(begin, end, A)].output()
                score_pcfg[(begin, end, A)] = math_log(probs[(A, word)])
                score[(begin, end, A)] = score_calc(d, W, embdding[(begin, end, A)], b, lamda, probs[(A, word)])
                terminals[(begin, end, A)] = word
                node_rules[(begin, end)].add(A)
    # Binary cells, bottom-up over increasing span length.
    for span in range(2, n + 1):
        for begin in range(0, n - span + 1):
            end = begin + span
            node_rules[(begin, end)] = set()
            for split in range(begin + 1, end):
                for B in node_rules[(begin, split)]:
                    for C in node_rules[(split, end)]:
                        X = B+" "+C
                        if X in node_pa:
                            for A in node_pa[X]:
                                node_rules[(begin, end)].add(A)
                                # Compose left child's embedding into the
                                # right child's LSTM state, as in training.
                                TMP = LSTM[(split, end, C)].add_input(embdding[(begin, split, B)])
                                p = TMP.output()
                                s_pcfg = math_log(probs[(A, X)]) + score_pcfg[(begin, split, B)] + score_pcfg[(split, end, C)]
                                s = score_calc(d, W, p, b, lamda, s_pcfg)
                                if (begin, end, A) not in score or s.value() > score[(begin, end, A)].value():
                                    LSTM[(begin, end, A)] = TMP
                                    score[(begin, end, A)] = s
                                    score_pcfg[(begin, end, A)] = s_pcfg
                                    embdding[(begin, end, A)] = p
                                    backPointers[(begin, end, A)] = (split, B, C)
    # Emit the best tree, or "None" when parsing failed.
    t, s = backtrack(sent, backPointers, terminals, score)
    if t == None:
        outfile.write("None\n")
    else:
        result = t.dostr()
        outfile.write(result+"\n")
|