|   1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
 | class Preprocessor(object):
    # initialize the dictionary {token: id}, and {id: token},  
    def __init__(self, dataset):
        root_labels = list([l for ex in dataset
                           for (h, l) in zip(ex['head'], ex['label']) if h == 0]) #h==0 is root
        counter = Counter(root_labels)
        if len(counter) > 1:
            logging.info('Warning: more than one root label')
            logging.info(counter)
        self.root_label = counter.most_common()[0][0]
        deprel = [self.root_label] + list(set([w for ex in dataset
                                               for w in ex['label']
                                               if w != self.root_label]))
        '''
        the deprel (dependent relations) are:
        
        '''
        
        tok2id = {L_PREFIX + l: i for (i, l) in enumerate(deprel)}
        tok2id[L_PREFIX + NULL] = self.L_NULL = len(tok2id)
        config = Config()
        self.unlabeled = config.unlabeled
        self.with_punct = config.with_punct
        self.use_pos = config.use_pos
        self.use_dep = config.use_dep
        self.language = config.language
        if self.unlabeled: #just S, LA, RA
            trans = ['L', 'R', 'S']
            self.n_deprel = 1
        else: #S_NN
            trans = ['L-' + l for l in deprel] + ['R-' + l for l in deprel] + ['S']
            self.n_deprel = len(deprel)
        self.n_trans = len(trans)
        self.tran2id = {t: i for (i, t) in enumerate(trans)}
        self.id2tran = {i: t for (i, t) in enumerate(trans)}
        
        '''
        example of id2tran is:
        {0: 'L', 1: 'R', 2: 'S'}
        '''
        # logging.info('Build dictionary for part-of-speech tags.')
        tok2id.update(build_dict([P_PREFIX + w for ex in dataset for w in ex['pos']],
                                  offset=len(tok2id)))
        tok2id[P_PREFIX + UNK] = self.P_UNK = len(tok2id)
        tok2id[P_PREFIX + NULL] = self.P_NULL = len(tok2id)
        tok2id[P_PREFIX + ROOT] = self.P_ROOT = len(tok2id)
        # logging.info('Build dictionary for words.')
        tok2id.update(build_dict([w for ex in dataset for w in ex['word']],
                                  offset=len(tok2id)))
        tok2id[UNK] = self.UNK = len(tok2id)
        tok2id[NULL] = self.NULL = len(tok2id)
        tok2id[ROOT] = self.ROOT = len(tok2id)
        
        '''
        example of id2tok is:
        {0: '<l>:root', 1: '<l>:nsubj', 2: '<l>:advmod', 3: '<l>:dobj', 4: '<l>:acl', 5: '<l>:nummod', 6: '<l>:nmod', 7: '<l>:compound', 8: '<l>:xcomp', 
        9: '<l>:auxpass', 10: '<l>:det', 11: '<l>:punct', 12: '<l>:mark', 13: '<l>:case', 14: '<l>:conj', 15: '<l>:appos', 16: '<l>:nmod:tmod', 
        17: '<l>:dep', 18: '<l>:amod', 19: '<l>:nmod:poss', 20: '<l>:nsubjpass', 21: '<l>:cc', 22: '<l>:ccomp', 23: '<l>:<NULL>', 24: '<p>:NNP', 
        25: '<p>:NN', 26: '<p>:IN', 27: '<p>:DT', 28: '<p>:,', 29: '<p>:NNS', 30: '<p>:JJ', 31: '<p>:CD', 32: '<p>:CC', 33: '<p>:VBD', 34: '<p>:.', 
        35: '<p>:VBN', 36: '<p>:VBZ', 37: '<p>:``', 38: "<p>:''", 39: '<p>:VB', 40: '<p>:TO', 41: '<p>:PRP', 42: '<p>:POS', 43: '<p>:-LRB-', 
        44: '<p>:-RRB-', 45: '<p>:RB', 46: '<p>:NNPS', 47: '<p>:PRP$', 48: '<p>:<UNK>', 49: '<p>:<NULL>', 50: '<p>:<ROOT>', 51: ',', 52: 'in', 53: 'the', 
        54: '.', 55: 'cars', 56: 'and', 57: 'of', 58: '``', 59: "''", 60: 'at', 61: 'to', 62: 'haag', 63: 'said', 64: 'u.s.', 65: 'luxury', 66: 'auto', 
        67: 'maker', 68: 'an', 69: 'oct.', 70: '19', 71: 'review', 72: 'misanthrope', 73: 'chicago', 74: "'s", 75: 'goodman', 76: 'theatre', 77: '-lrb-', 
        78: 'revitalized', 79: 'classics', 80: 'take', 81: 'stage', 82: 'windy', 83: 'city', 84: 'leisure', 85: '&', 86: 'arts', 87: '-rrb-', 88: 'role',
        89: 'celimene', 90: 'played', 91: 'by', 92: 'kim', 93: 'cattrall', 94: 'was', 95: 'mistakenly', 96: 'attributed', 97: 'christina', 98: 'ms.', 
        99: 'plays', 100: 'elianti', 101: 'rolls-royce', 102: 'motor', 103: 'inc.', 104: 'it', 105: 'expects', 106: 'its', 107: 'sales', 108: 'remain', 
        109: 'steady', 110: 'about', 111: '1,200', 112: '1990', 113: 'last', 114: 'year', 115: 'sold', 116: '1,214', 117: 'howard', 118: 'mosher', 
        119: 'president', 120: 'chief', 121: 'executive', 122: 'officer', 123: 'he', 124: 'anticipates', 125: 'growth', 126: 'for', 127: 'britain', 
        128: 'europe', 129: 'far', 130: 'eastern', 131: 'markets', 132: '<UNK>', 133: '<NULL>', 134: '<ROOT>'}
        '''
        self.tok2id = tok2id
        self.id2tok = {v: k for (k, v) in tok2id.items()}
        self.n_features = 18 + (18 if config.use_pos else 0) + (12 if config.use_dep else 0) #? why 18?
        self.n_tokens = len(tok2id) 
    
    '''
    arranging each set as this form: [{'word': [corresponding id of each token], 'pos': [corresponding id in pos from tok2id], 'head': [id...], 'label': [id...]},...,      #{last sentence in the set}]
        The resulting vec_examples are (this is vectorized training data in id form from tok2id):
        [{'word': [134, 52, 68, 69, 70, 71, 57, 58, 53, 72, 59, 60, 73, 74, 75, 76, 77, 58, 78, 79, 80, 53, 81, 52, 82, 83, 51, 59, 84, 85, 86, 87, 51, 53, 88, 57, 89, 51, 90, 91, 92, 93, 51, 94, 95, 96, 61, 97, 62, 54], 
        'pos': [50, 26, 27, 24, 31, 25, 26, 37, 27, 25, 38, 26, 24, 42, 24, 24, 43, 37, 35, 29, 39, 27, 25, 26, 24, 24, 28, 38, 25, 32, 29, 44, 28, 27, 25, 26, 24, 28, 35, 26, 24, 24, 28, 33, 45, 35, 40, 24, 24, 34], 
        'head': [-1, 5, 5, 5, 5, 45, 9, 9, 9, 5, 9, 15, 15, 12, 15, 9, 20, 20, 19, 20, 5, 22, 20, 25, 25, 20, 20, 20, 20, 28, 28, 20, 45, 34, 45, 36, 34, 34, 34, 41, 41, 38, 34, 45, 45, 0, 48, 48, 45, 45], 
        'label': [-1, 13, 10, 7, 5, 6, 13, 11, 10, 6, 11, 13, 19, 13, 7, 6, 11, 11, 18, 1, 17, 10, 3, 13, 7, 6, 11, 11, 17, 21, 14, 11, 11, 10, 20, 13, 6, 11, 4, 13, 7, 6, 11, 9, 2, 0, 13, 7, 6, 11]}, 
        {'word': [134, 98, 62, 99, 100, 54], 'pos': [50, 24, 24, 36, 24, 34], 'head': [-1, 2, 3, 0, 3, 3], 'label': [-1, 7, 1, 0, 3, 11]},{...}...{...}
    '''
    def vectorize(self, examples): 
        vec_examples = []
        for ex in examples:
            word = [self.ROOT] + [self.tok2id[w] if w in self.tok2id
                                  else self.UNK for w in ex['word']] # a list of word id for each sentence in dataset
            pos = [self.P_ROOT] + [self.tok2id[P_PREFIX + w] if P_PREFIX + w in self.tok2id # a list of POS id
                                   else self.P_UNK for w in ex['pos']]
            head = [-1] + ex['head']
            label = [-1] + [self.tok2id[L_PREFIX + w] if L_PREFIX + w in self.tok2id
                            else -1 for w in ex['label']]
            vec_examples.append({'word': word, 'pos': pos,
                                 'head': head, 'label': label})
        return vec_examples # each sentence is a dictionary with list of 'word','POS', 'head' and 'label
    
    '''
    This function is important, it extracts context words first, then extract the corresping word from ex['word'], pos from ex['pos']
    and label features from ex['label'], finally concatenate them into a long vector for each sentence, an example of one set of features for one sentence is:
    [133, 133, 134, 52, 68, 69, 133, 133, 133, 133, 133, 133, 133, 133, 133, 133, 133, 133, 49, 49, 50, 26, 27, 24, 49, 49, 49, 49, 49, 49, 49, 49, 49, 49, 49, 49] #36 features of context words with corresponding id from tok2id,  that is a long list containing [features, p_features, l_features] 
    '''
    def extract_features(self, stack, buf, arcs, ex):
        if stack[0] == "ROOT":
            stack[0] = 0
        def get_lc(k): # lc: left context, since arc = (stack[-1], stack[-2], gold_t) or (stack[-2], stack[-1], gold_t), if arc[1]'s id < arc[0]'s id, then extract lc
            return sorted([arc[1] for arc in arcs if arc[0] == k and arc[1] < k])
        def get_rc(k): #if arc[1]'s id > arc[0]'s id, rc
            return sorted([arc[1] for arc in arcs if arc[0] == k and arc[1] > k],
                          reverse=True)
        p_features = []
        l_features = []
        features = [self.NULL] * (3 - len(stack)) + [ex['word'][x] for x in stack[-3:]] #if len(stack)>=3, no NULL, features contain the last 3 words in the stack
        features += [ex['word'][x] for x in buf[:3]] + [self.NULL] * (3 - len(buf)) # extract the first 3 words in the buffer if len(buf) >=3
        if self.use_pos: # extracting the corresponding pos features, 3 from stack, 3 from buffer if stack and buffer have length>=3
            p_features = [self.P_NULL] * (3 - len(stack)) + [ex['pos'][x] for x in stack[-3:]]
            p_features += [ex['pos'][x] for x in buf[:3]] + [self.P_NULL] * (3 - len(buf))
        for i in range(2): #i = 0, 1
            if i < len(stack):
                k = stack[-i-1] #the kth word's context: i=0, the -1 word's context, i=1, the -2 word's context
                lc = get_lc(k) # the left context word of kth word in stack = [(s3,) s2, s1]
                rc = get_rc(k) # the right context word
                llc = get_lc(lc[0]) if len(lc) > 0 else [] #llc
                rrc = get_rc(rc[0]) if len(rc) > 0 else [] #rrc
                features.append(ex['word'][lc[0]] if len(lc) > 0 else self.NULL)
                features.append(ex['word'][rc[0]] if len(rc) > 0 else self.NULL)
                features.append(ex['word'][lc[1]] if len(lc) > 1 else self.NULL)
                features.append(ex['word'][rc[1]] if len(rc) > 1 else self.NULL)
                features.append(ex['word'][llc[0]] if len(llc) > 0 else self.NULL)
                features.append(ex['word'][rrc[0]] if len(rrc) > 0 else self.NULL)
                if self.use_pos:
                    p_features.append(ex['pos'][lc[0]] if len(lc) > 0 else self.P_NULL)
                    p_features.append(ex['pos'][rc[0]] if len(rc) > 0 else self.P_NULL)
                    p_features.append(ex['pos'][lc[1]] if len(lc) > 1 else self.P_NULL)
                    p_features.append(ex['pos'][rc[1]] if len(rc) > 1 else self.P_NULL)
                    p_features.append(ex['pos'][llc[0]] if len(llc) > 0 else self.P_NULL)
                    p_features.append(ex['pos'][rrc[0]] if len(rrc) > 0 else self.P_NULL)
                if self.use_dep:
                    l_features.append(ex['label'][lc[0]] if len(lc) > 0 else self.L_NULL)
                    l_features.append(ex['label'][rc[0]] if len(rc) > 0 else self.L_NULL)
                    l_features.append(ex['label'][lc[1]] if len(lc) > 1 else self.L_NULL)
                    l_features.append(ex['label'][rc[1]] if len(rc) > 1 else self.L_NULL)
                    l_features.append(ex['label'][llc[0]] if len(llc) > 0 else self.L_NULL)
                    l_features.append(ex['label'][rrc[0]] if len(rrc) > 0 else self.L_NULL)
            else:
                features += [self.NULL] * 6
                if self.use_pos:
                    p_features += [self.P_NULL] * 6
                if self.use_dep:
                    l_features += [self.L_NULL] * 6
        features += p_features + l_features
        assert len(features) == self.n_features
        return features 
    '''
    This is to get the correct transitions for training purposes. 
    1. if there is only ROOT in stack, the only correct transition is shift
    2. if stack = [i1, i0] or more, i2 is not ROOT, i1's head == i0, then left arc is the correct transition
    3. if stack = [i1, i0] or more, i0's head == i1, i1 can be ROOT or not ROOT, right arc is the correct transition
    4. ow shift
    '''
    def get_oracle(self, stack, buf, ex):
        if len(stack) < 2:
            return self.n_trans - 1 #如果stack只有root, 就执行shift, shift对应id是48 or 2: n_tran = 3-1 if unlabel or 49-1  if label1
        i0 = stack[-1]
        i1 = stack[-2]
        h0 = ex['head'][i0]
        h1 = ex['head'][i1]
        l0 = ex['label'][i0]
        l1 = ex['label'][i1]
        if self.unlabeled:
            if (i1 > 0) and (h1 == i0): #if i1 is not ROOT, stack[i1, i0], head is i0, left_arc i1
                return 0
            elif (i1 >= 0) and (h0 == i1) and (not any([x for x in buf if ex['head'][x] == i0])): 
                return 1 # if head is i1, no i0 head in buffer , right_arc
            else:
                return None if len(buf) == 0 else 2 #if len(buf) =0, no transition, ow, shift
        else:
            if (i1 > 0) and (h1 == i0):
                return l1 if (l1 >= 0) and (l1 < self.n_deprel) else None
            elif (i1 >= 0) and (h0 == i1) and \
                 (not any([x for x in buf if ex['head'][x] == i0])):
                return l0 + self.n_deprel if (l0 >= 0) and (l0 < self.n_deprel) else None
            else:
                return None if len(buf) == 0 else self.n_trans - 1
    
    '''
    for each sentence in training data, create stack with only ROOT in id form [0], buf with all tokens of the sentence, and empty arcs, get oracle for the 
    stack[ROOT] and full buffer (oracle should be shift for this case), get legal labels (only shift is legal for stack[ROOT] and full buffer), so in this
    round the legal_labels will return [0,0,1], then update stack, buffer and arcs, repeate this procedure for the updated stack, buffer and arcs for 
    2*#ofwords times. for each update, instances are accumulated as walking through the whole sentence as [], each sentence is an instances, all_instances  =
    [[instances1], [instances2], ...[]]
    '''
    def create_instances(self, examples):
        all_instances = []
        succ = 0
        for n_ex, ex in enumerate(examples):
            n_words = len(ex['word']) - 1
            # arcs = {(h, t, label)}
            stack = [0]
            buf = [i + 1 for i in range(n_words)]
            arcs = []
            instances = []
            for i in range(n_words * 2): # rolling over each word to get the features (legal label and gold_t) for each transition in (llc, lc, w, rc, rrc) for each sentence, the transitions have 2*n_word times 
                gold_t = self.get_oracle(stack, buf, ex)
                if gold_t is None:
                    break
                legal_labels = self.legal_labels(stack, buf)
                assert legal_labels[gold_t] == 1 # correct transition must belong to legal lables
                instances.append((self.extract_features(stack, buf, arcs, ex),
                                  legal_labels, gold_t)) 
                if gold_t == self.n_trans - 1: # gold_t is shift= 2
                    stack.append(buf[0])
                    buf = buf[1:]
                elif gold_t < self.n_deprel: #n_deprel = 1, when gold_t ==0,
                    arcs.append((stack[-1], stack[-2], gold_t)) # i.e. L_arc, next gold_t must be 2, i.e. shift, stack is all the words except the -2 word, then next loop will be adding the first word in buffer to stack (execute the first if).
                    stack = stack[:-2] + [stack[-1]]
                else:
                    arcs.append((stack[-2], stack[-1], gold_t - self.n_deprel))
                    stack = stack[:-1] #when gold_t ==1, R_arc, next gold_t must be 0, i.e. shift, stack will be all the words except the last word, then next loop will be adding the first word in buffer to stack (execute the first if).
            succ += 1
            all_instances += instances # one instance is one sentence with 2*n_word tuples, each tuple contains context features, legal_labels, gold_t
        return all_instances
    def legal_labels(self, stack, buf):
        labels = ([1] if len(stack) > 2 else [0]) * self.n_deprel #stack<=2 must not be L-arc, stack>2, L-arc is legal
        labels += ([1] if len(stack) >= 2 else [0]) * self.n_deprel
        labels += [1] if len(buf) > 0 else [0] #buf > 0,  shift is legal 
        return labels # labels' length == n_tran, 3 or 49
 |