luheng / deep_srl

Code and pre-trained model for: Deep Semantic Role Labeling: What Works and What's Next
Apache License 2.0

BUG in interactive.py #7

Closed: wangxinyu0922 closed this issue 6 years ago

wangxinyu0922 commented 6 years ago

Thank you for your great work. I found a problem in this part of interactive.py:

    # Predicate identification: word ids come from pid_data.word_dict.
    s0 = string_sequence_to_ids(tokenized_sent, pid_data.word_dict, True)
    l0 = [0 for _ in s0]
    x, _, _, weights = pid_data.get_test_data([(s0, l0)], batch_size=None)
    pid_pred, scores0 = pid_pred_function(x, weights)

    s1 = []
    predicates = []
    for i,p in enumerate(pid_pred[0]):
      if pid_data.label_dict.idx2str[p] == 'V':
        #print 'Predicate:', tokenized_sent[i]
        predicates.append(i)
        feats = [1 if j == i else 0 for j in range(num_tokens)]
        s1.append((s0, feats, l0))

    if len(s1) == 0:
      continue

    # Semantic role labeling.
    # NOTE: s1 holds word ids from pid_data.word_dict, not srl_data.word_dict.
    x, _, _, weights = srl_data.get_test_data(s1, batch_size=None)
    srl_pred, scores = srl_pred_function(x, weights)

I think it is wrong to pass s1 to srl_data.get_test_data(): the word ids in s1 were produced with pid_data.word_dict, and pid_data.word_dict and srl_data.word_dict are different dictionaries (compare predicate.py). The SRL input should be built with srl_data.word_dict instead, something like:

    s1 = string_sequence_to_ids(tokenized_sent, srl_data.word_dict, True)
    r1 = []
    ...
    r1.append((s1, feats, l0))
    x, _, _, weights = srl_data.get_test_data(r1, batch_size=None)
    srl_pred, scores = srl_pred_function(x, weights)
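To illustrate why the two dictionaries cannot be mixed, here is a minimal, self-contained sketch. It assumes each model builds its own word-to-id vocabulary independently, so the same word can get different ids. `Dictionary`, `to_ids`, and `build_srl_inputs` are hypothetical stand-ins written for this example, not the actual deep_srl classes or `string_sequence_to_ids`.

```python
# Hypothetical stand-ins for the vocabularies used in interactive.py.
# The real deep_srl Dictionary / string_sequence_to_ids may behave differently;
# this only shows why ids from one vocabulary mislead a model trained on another.

UNKNOWN = "*UNKNOWN*"


class Dictionary:
    """Hypothetical word <-> id mapping (not the deep_srl class)."""

    def __init__(self, words):
        self.idx2str = [UNKNOWN] + list(words)
        self.str2idx = {w: i for i, w in enumerate(self.idx2str)}

    def get_index(self, word):
        # Unseen words map to the unknown-word id.
        return self.str2idx.get(word, self.str2idx[UNKNOWN])


def to_ids(tokens, dictionary, lowercase=True):
    """Hypothetical stand-in for string_sequence_to_ids."""
    return [dictionary.get_index(t.lower() if lowercase else t) for t in tokens]


def build_srl_inputs(tokens, predicate_positions, srl_dict):
    """One (word_ids, predicate_indicator, dummy_labels) tuple per predicate,
    with word ids taken from the SRL vocabulary, as in the proposed fix."""
    word_ids = to_ids(tokens, srl_dict)
    dummy_labels = [0 for _ in word_ids]
    inputs = []
    for i in predicate_positions:
        # One-hot indicator marking the current predicate position.
        feats = [1 if j == i else 0 for j in range(len(tokens))]
        inputs.append((word_ids, feats, dummy_labels))
    return inputs


# Two independently built vocabularies: the same word gets different ids.
pid_dict = Dictionary(["the", "cat", "sat"])
srl_dict = Dictionary(["cat", "sat", "the", "mat"])

tokens = ["The", "cat", "sat"]
pid_ids = to_ids(tokens, pid_dict)
srl_ids = to_ids(tokens, srl_dict)
print(pid_ids)  # [1, 2, 3]
print(srl_ids)  # [3, 1, 2]

# Reading predicate-ID ids through the SRL vocabulary garbles the sentence.
print([srl_dict.idx2str[i] for i in pid_ids])  # ['cat', 'sat', 'the']

# Correct SRL input for the predicate "sat" (position 2).
print(build_srl_inputs(tokens, [2], srl_dict))  # [([3, 1, 2], [0, 0, 1], [0, 0, 0])]
```

The model would still run on the mismatched ids, since they are valid indices, which is why this kind of bug produces wrong labels silently instead of raising an error.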
luheng commented 6 years ago

Great catch, thanks! Just made an update. Closing the issue.