lavis-nlp / spert

PyTorch code for SpERT: Span-based Entity and Relation Transformer
MIT License

Interpreting results #35

Closed: valtheval closed this issue 3 years ago

valtheval commented 3 years ago

Dear @markus-eberts,

Thanks for your work, it's a great job. I'm having trouble interpreting the model outputs and the predictions collected by the Evaluator.eval_batch method, namely:

result = model(...)
entity_clf, rel_clf, rels = result

and

print(evaluator._pred_entities)
[[(1, 2, <spert.entities.EntityType at 0x15b817bf6d0>, 0.5499340891838074),
  (2, 3, <spert.entities.EntityType at 0x15b817bf880>, 0.43182992935180664),
  (3, 4, <spert.entities.EntityType at 0x15b817bf6d0>, 0.6427331566810608),
  (4, 5, <spert.entities.EntityType at 0x15b817bf6d0>, 0.4554755389690399),..]

Could you explain what each dimension of the tensors entity_clf, rel_clf and rels corresponds to, and how to read the second output (in particular, whether the indices refer to raw tokens or to BERT tokens)?

Thanks in advance!

markus-eberts commented 3 years ago

Hi,

Regarding the output interpretation: entity_clf is a (B x S x E) tensor, where B is the batch size (in sentences), S the number of all token spans up to a specified maximum length (10 by default), and E the number of entity types (+1 for "None"). It contains the model's (softmax) confidences that a given span belongs to each of the E types. If a span is assigned to the "None" type (= no entity), it is disregarded in the relation extraction step.
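To make the shape of entity_clf concrete, here is a minimal sketch using NumPy arrays as a stand-in for the PyTorch tensors. The helpers enumerate_spans and decode_entities are hypothetical, the scores are random, the span ordering is an assumption, and it assumes the "None" type sits at index 0 of the E dimension (check this against your own entity-type mapping). It also ignores any padding of the span dimension that a batch of sentences with different lengths would need.

```python
import numpy as np

def enumerate_spans(num_tokens, max_span_size=10):
    """Hypothetical helper: all (start, end) spans, end exclusive, of length
    1..max_span_size over a sentence of num_tokens tokens."""
    spans = []
    for size in range(1, max_span_size + 1):
        for start in range(num_tokens - size + 1):
            spans.append((start, start + size))
    return spans

def decode_entities(entity_clf, none_index=0):
    """Hypothetical helper: per sentence, keep (span_index, type_index, confidence)
    for spans whose highest-scoring type is not "None" (assumed at none_index)."""
    results = []
    for sentence_scores in entity_clf:               # shape (S, E)
        types = sentence_scores.argmax(axis=-1)      # best type per span, shape (S,)
        confs = sentence_scores.max(axis=-1)         # its confidence, shape (S,)
        results.append([(s, int(t), float(c))
                        for s, (t, c) in enumerate(zip(types, confs))
                        if t != none_index])
    return results

# Usage with made-up scores: one sentence of 5 tokens -> S = 5+4+3+2+1 = 15 spans.
spans = enumerate_spans(num_tokens=5)
rng = np.random.default_rng(0)
logits = rng.normal(size=(1, len(spans), 3))         # B=1, S=15, E=3 ("None" + 2 types)
entity_clf = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)  # softmax over E
for span_idx, type_idx, conf in decode_entities(entity_clf)[0]:
    print(spans[span_idx], type_idx, round(conf, 3))
```

With the default maximum length of 10 and a 5-token sentence, every span fits, so S is simply 15; for longer sentences the length cap is what keeps S from growing quadratically.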

rel_clf is a (B x P x R) tensor, where B is again the batch size, P the number of pairs formed from the entity candidates (= spans not assigned to the "None" class), and R the number of relation types. For each entity pair, it contains the (sigmoid) score of each relation type.
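Because these scores come from a sigmoid rather than a softmax, each relation type is scored independently, so a single entity pair can receive several relation types, or none. A minimal sketch of turning rel_clf into predicted relations by thresholding follows; the decode_relations helper is hypothetical and the 0.4 threshold is an assumption for illustration (SpERT's configuration has a relation filter threshold, rel_filter_threshold, so check your config for the value actually in use). NumPy again stands in for PyTorch.

```python
import numpy as np

def decode_relations(rel_clf, threshold=0.4):
    """Hypothetical helper: per sentence, return (pair_index, relation_type_index, score)
    for every sigmoid score at or above the threshold (multi-label decoding)."""
    results = []
    for sentence_scores in rel_clf:                          # shape (P, R)
        pair_idx, rel_idx = np.nonzero(sentence_scores >= threshold)
        results.append([(int(p), int(r), float(sentence_scores[p, r]))
                        for p, r in zip(pair_idx, rel_idx)])
    return results

# Made-up scores for one sentence with P = 3 pairs and R = 2 relation types.
rel_clf = np.array([[[0.9, 0.1],    # pair 0: relation type 0 only
                     [0.5, 0.6],    # pair 1: both types (multi-label)
                     [0.2, 0.3]]])  # pair 2: no relation
print(decode_relations(rel_clf))
# → [[(0, 0, 0.9), (1, 0, 0.5), (1, 1, 0.6)]]
```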

rels is a (B x P x 2) tensor that contains, for each entity pair, the indices of its two entities (i.e., their positions along the span dimension of entity_clf, entity_masks, entity_sizes, etc.). With these indices you can, for example, look up the entity scores of each pair by indexing entity_clf with rels.
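To see how rels ties the other two outputs together, here is a small sketch with made-up numbers for a single sentence (B = 1). It assumes, as a simplification, that "None" is type index 0 and that pairs are formed as ordered pairs of distinct entity candidates (so both directions of a directed relation get scored). Indexing entity_clf with rels then yields the entity scores of the first and second entity of every pair.

```python
import itertools
import numpy as np

# Made-up softmax confidences for one sentence: S = 4 spans, E = 3 types
# (index 0 assumed to be "None").
entity_clf = np.array([[[0.8, 0.1, 0.1],     # span 0 -> None
                        [0.1, 0.7, 0.2],     # span 1 -> type 1
                        [0.9, 0.05, 0.05],   # span 2 -> None
                        [0.2, 0.1, 0.7]]])   # span 3 -> type 2

# Entity candidates: span indices whose most likely type is not "None".
candidates = np.nonzero(entity_clf[0].argmax(axis=-1) != 0)[0]

# Ordered pairs of distinct candidates -> rels with shape (B, P, 2).
pairs = list(itertools.permutations(candidates.tolist(), 2))
rels = np.array([pairs])

# Integer-array indexing: one (2 x E) block of entity scores per pair.
pair_entity_scores = entity_clf[0][rels[0]]               # shape (P, 2, E)
head_types = pair_entity_scores[:, 0].argmax(axis=-1)     # type of first entity per pair
tail_types = pair_entity_scores[:, 1].argmax(axis=-1)     # type of second entity per pair
print(rels.tolist(), head_types.tolist(), tail_types.tolist())
# → [[[1, 3], [3, 1]]] [1, 2] [2, 1]
```

The same expression, entity_clf[0][rels[0]], also works on PyTorch tensors when rels is a LongTensor, since PyTorch supports the same integer-array indexing.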

Regarding your second question: each tuple is (span start, span end, entity type, score), where 'span end' is exclusive. The indices refer to the subword tokens produced by the BERT tokenizer (WordPiece, a close relative of byte-pair encoding), not to raw tokens.
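Because the spans are given in subword indices with an exclusive end, mapping a prediction back to the original words requires the word-to-subword alignment built during tokenization. A minimal sketch with a hypothetical alignment follows: word_spans[i] holds the (start, end) subword range of word i, and the offset of 1 reflects an assumed leading [CLS] token (consistent with the example predictions above starting at index 1, but verify it against your own encoding). The helper bpe_span_to_words, the sentence, its subword split, and the "Location" string standing in for an EntityType object are all made up.

```python
def bpe_span_to_words(start, end, word_spans):
    """Hypothetical helper: map a subword span [start, end) to word indices
    [w_start, w_end) when it covers whole words exactly, else return None."""
    starts = {s: i for i, (s, _) in enumerate(word_spans)}
    ends = {e: i for i, (_, e) in enumerate(word_spans)}
    if start in starts and end in ends and starts[start] <= ends[end]:
        return starts[start], ends[end] + 1   # exclusive word end
    return None

# Hypothetical sentence where the last word is split into 3 subwords
# and subword position 0 is taken by [CLS].
words = ["Obama", "visited", "Ouagadougou"]
word_spans = [(1, 2), (2, 3), (3, 6)]

prediction = (3, 6, "Location", 0.91)   # (span start, span end, entity type, score)
w_start, w_end = bpe_span_to_words(prediction[0], prediction[1], word_spans)
print(words[w_start:w_end], prediction[2])
# → ['Ouagadougou'] Location
```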

valtheval commented 3 years ago

Thank you, Markus, for your help!