lavis-nlp / spert

PyTorch code for SpERT: Span-based Entity and Relation Transformer
MIT License

What is the meaning of the dataset tensors? #56

celsofranssa closed this issue 3 years ago

celsofranssa commented 3 years ago

When iterating over the dataset samples, each sample is the following dictionary of tensors:

for sample in train_dataset:
    for key in sample.keys():
        print(f"\n{key} ({sample[key].shape}):\n {sample[key]}")

# encodings (torch.Size([35])):
#  tensor([  101,  3780,  1036,  7607,  1005,  1057,  1012,  1055,  1012,  5426,
#          2930,  2824, 13109, 16932,  2692, 28332, 15136,  2683,  2549, 15278,
#          2557,  2128,  4135,  3501,  2897,  1999,  3009, 12875,  2692, 13938,
#          2102,  2410, 13114,  6365,   102])

# context_masks (torch.Size([35])):
#  tensor([True, True, True, True, True, True, True, True, True, True, True, True,
#         True, True, True, True, True, True, True, True, True, True, True, True,
#         True, True, True, True, True, True, True, True, True, True, True])

# entity_masks (torch.Size([105, 35])):
#  tensor([[False, False, False,  ..., False, False, False],
#         [False, False, False,  ..., False, False, False],
#         [False, False, False,  ..., False, False, False],
#         ...,
#         [False, False, False,  ..., False, False, False],
#         [False, False, False,  ..., False, False, False],
#         [False,  True,  True,  ..., False, False, False]])

# entity_sizes (torch.Size([105])):
#  tensor([ 1,  1,  3,  2,  3,  8,  9,  8,  2,  4,  4,  1,  9,  1,  4, 10, 10,  1,
#          1, 10,  8,  2,  4,  2,  3,  3,  6,  1,  2,  6,  1, 10, 10,  9,  3,  5,
#          3,  8,  8,  5,  1,  3,  1,  3,  5,  7,  8,  1,  3,  5,  2,  7,  8,  6,
#         10,  4,  4,  7,  3,  5,  5,  8,  6,  5,  8,  2,  6,  4,  6,  9,  9,  9,
#         10,  7,  1,  7,  9, 10,  5,  5,  2,  3,  1,  3,  7,  3,  5,  2,  6,  2,
#          7,  8,  3,  1,  6,  2,  1,  4,  6, 10,  4,  1,  7,  6,  9])

# entity_types (torch.Size([105])):
#  tensor([1, 1, 2, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
#         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
#         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
#         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
#         0, 0, 0, 0, 0, 0, 0, 0, 0])

# rels (torch.Size([20, 2])):
#  tensor([[2, 1],
#         [0, 1],
#         [1, 0],
#         [1, 2],
#         [2, 0],
#         [1, 4],
#         [1, 3],
#         [0, 2],
#         [3, 1],
#         [2, 4],
#         [2, 3],
#         [3, 4],
#         [3, 2],
#         [4, 3],
#         [4, 2],
#         [0, 4],
#         [4, 1],
#         [4, 0],
#         [0, 3],
#         [3, 0]])

# rel_masks (torch.Size([20, 35])):
#  tensor([[False, False, False, False, False, False, False, False, False, False,
#          False, False, False, False, False, False, False, False, False, False,
#          False, False, False, False, False, False, False, False, False, False,
#          False, False, False, False, False],
#         [False, False, False, False, False, False, False, False, False,  True,
#           True,  True,  True,  True,  True,  True,  True,  True,  True, False,
#          False, False, False, False, False, False, False, False, False, False,
#          False, False, False, False, False],
#         [False, False, False, False, False, False, False, False, False,  True,
#           True,  True,  True,  True,  True,  True,  True,  True,  True, False,
#          False, False, False, False, False, False, False, False, False, False,
#          False, False, False, False, False],
#         [False, False, False, False, False, False, False, False, False, False,
#          False, False, False, False, False, False, False, False, False, False,
#          False, False, False, False, False, False, False, False, False, False,
#          False, False, False, False, False],
#         [False, False, False, False, False, False, False, False, False,  True,
#           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
#          False, False, False, False, False, False, False, False, False, False,
#          False, False, False, False, False],
#         [False, False, False, False, False, False, False, False, False, False,
#          False, False, False, False, False, False, False, False, False, False,
#           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
#           True, False, False, False, False],
#         [False, False, False, False, False, False, False, False, False, False,
#          False, False, False, False, False, False, False, False, False, False,
#           True,  True,  True,  True,  True,  True,  True, False, False, False,
#          False, False, False, False, False],
#         [False, False, False, False, False, False, False, False, False,  True,
#           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
#          False, False, False, False, False, False, False, False, False, False,
#          False, False, False, False, False],
#         [False, False, False, False, False, False, False, False, False, False,
#          False, False, False, False, False, False, False, False, False, False,
#           True,  True,  True,  True,  True,  True,  True, False, False, False,
#          False, False, False, False, False],
#         [False, False, False, False, False, False, False, False, False, False,
#          False, False, False, False, False, False, False, False, False, False,
#          False, False, False, False, False,  True,  True,  True,  True,  True,
#           True, False, False, False, False],
#         [False, False, False, False, False, False, False, False, False, False,
#          False, False, False, False, False, False, False, False, False, False,
#          False, False, False, False, False,  True,  True, False, False, False,
#          False, False, False, False, False],
#         [False, False, False, False, False, False, False, False, False, False,
#          False, False, False, False, False, False, False, False, False, False,
#          False, False, False, False, False, False, False, False, False, False,
#          False, False, False, False, False],
#         [False, False, False, False, False, False, False, False, False, False,
#          False, False, False, False, False, False, False, False, False, False,
#          False, False, False, False, False,  True,  True, False, False, False,
#          False, False, False, False, False],
#         [False, False, False, False, False, False, False, False, False, False,
#          False, False, False, False, False, False, False, False, False, False,
#          False, False, False, False, False, False, False, False, False, False,
#          False, False, False, False, False],
#         [False, False, False, False, False, False, False, False, False, False,
#          False, False, False, False, False, False, False, False, False, False,
#          False, False, False, False, False,  True,  True,  True,  True,  True,
#           True, False, False, False, False],
#         [False, False, False, False, False, False, False, False, False,  True,
#           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
#           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
#           True, False, False, False, False],
#         [False, False, False, False, False, False, False, False, False, False,
#          False, False, False, False, False, False, False, False, False, False,
#           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
#           True, False, False, False, False],
#         [False, False, False, False, False, False, False, False, False,  True,
#           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
#           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
#           True, False, False, False, False],
#         [False, False, False, False, False, False, False, False, False,  True,
#           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
#           True,  True,  True,  True,  True,  True,  True, False, False, False,
#          False, False, False, False, False],
#         [False, False, False, False, False, False, False, False, False,  True,
#           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
#           True,  True,  True,  True,  True,  True,  True, False, False, False,
#          False, False, False, False, False]])

# rel_types (torch.Size([20, 5])):
#  tensor([[0., 0., 1., 0., 0.],
#         [0., 0., 0., 0., 0.],
#         [0., 0., 0., 0., 0.],
#         [0., 0., 0., 0., 0.],
#         [0., 0., 0., 0., 0.],
#         [0., 0., 0., 0., 0.],
#         [0., 0., 0., 0., 0.],
#         [0., 0., 0., 0., 0.],
#         [0., 0., 0., 0., 0.],
#         [0., 0., 0., 0., 0.],
#         [0., 0., 0., 0., 0.],
#         [0., 0., 0., 0., 0.],
#         [0., 0., 0., 0., 0.],
#         [0., 0., 0., 0., 0.],
#         [0., 0., 0., 0., 0.],
#         [0., 0., 0., 0., 0.],
#         [0., 0., 0., 0., 0.],
#         [0., 0., 0., 0., 0.],
#         [0., 0., 0., 0., 0.],
#         [0., 0., 0., 0., 0.]])

# entity_sample_masks (torch.Size([105])):
#  tensor([True, True, True, True, True, True, True, True, True, True, True, True,
#         True, True, True, True, True, True, True, True, True, True, True, True,
#         True, True, True, True, True, True, True, True, True, True, True, True,
#         True, True, True, True, True, True, True, True, True, True, True, True,
#         True, True, True, True, True, True, True, True, True, True, True, True,
#         True, True, True, True, True, True, True, True, True, True, True, True,
#         True, True, True, True, True, True, True, True, True, True, True, True,
#         True, True, True, True, True, True, True, True, True, True, True, True,
#         True, True, True, True, True, True, True, True, True])

# rel_sample_masks (torch.Size([20])):
#  tensor([True, True, True, True, True, True, True, True, True, True, True, True,
#         True, True, True, True, True, True, True, True])

Could you explain what these tensors mean? For instance, encodings and context_masks map directly to the input_ids and attention_mask arguments of BERT's forward method. What are the semantics of the other tensors?

markus-eberts commented 3 years ago

Hi,

entity_masks is an E×C tensor (E := number of positive + negative entity mention samples, C := context size). Each row is used to access the tokens belonging to one entity span (and to mask every other token).
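Here is a rough pure-Python sketch of what one row of entity_masks does. The helper names create_entity_mask and masked_max_pool are hypothetical. The sketch assumes spans are half-open (start, end) token positions, and plain lists stand in for tensors. A tensor implementation would more likely push masked positions to a large negative value before taking the max, rather than filtering them out.

```python
def create_entity_mask(start, end, context_size):
    """Hypothetical helper: one row of entity_masks.

    True for token positions inside the half-open span [start, end),
    False for every other token in the context.
    """
    return [start <= i < end for i in range(context_size)]


def masked_max_pool(hidden, mask):
    """Max-pool token vectors over the positions selected by mask.

    hidden: C token vectors (lists of floats); mask: C booleans.
    Returns the element-wise max over the selected tokens.
    """
    selected = [vec for vec, keep in zip(hidden, mask) if keep]
    return [max(column) for column in zip(*selected)]


# Toy context of C = 4 tokens with 2-dimensional "hidden states".
hidden = [[0.0, 1.0], [5.0, -1.0], [2.0, 3.0], [9.0, 9.0]]
mask = create_entity_mask(1, 3, len(hidden))  # span covers tokens 1 and 2
print(mask)                           # [False, True, True, False]
print(masked_max_pool(hidden, mask))  # [5.0, 3.0]
```

Token 3 has the largest values, but it lies outside the span, so the mask keeps it out of the pooled vector.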

entity_sizes is a tensor of size E containing the size of each entity mention span (each size is later mapped to an embedding).
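Below is a minimal sketch of the size lookup. It assumes a learned table indexed by span size, whose vector is concatenated to the pooled span representation. The table values are toy placeholders for learned weights. MAX_SPAN_SIZE = 10 is an assumption, chosen only because 10 is the largest size in the sample above.

```python
MAX_SPAN_SIZE = 10  # assumption: largest span size seen in the sample above

# Toy stand-in for a learned embedding table: row k is the vector for size k.
size_embeddings = [[float(k), -float(k)] for k in range(MAX_SPAN_SIZE + 1)]


def embed_sizes(entity_sizes):
    """Look up one size embedding per entity mention sample."""
    return [size_embeddings[size] for size in entity_sizes]


def entity_representation(pooled_span, size):
    """Concatenate a pooled span vector with its size embedding."""
    return pooled_span + size_embeddings[size]


print(embed_sizes([1, 3]))                   # [[1.0, -1.0], [3.0, -3.0]]
print(entity_representation([5.0, 3.0], 2))  # [5.0, 3.0, 2.0, -2.0]
```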

entity_types is a tensor of size E containing the id of each entity mention's type (also mapped to an embedding).
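In the sample printed in the question, only the first five entries of entity_types are non-zero (1, 1, 2, 4, 4), and every index in rels lies in 0..4. This suggests that id 0 marks negative (non-entity) spans and that the positive mentions come first. That reading is an inference from the output, not something stated above. Here is a quick pure-Python check on those values:

```python
# entity_types from the sample above: 5 typed mentions followed by 100 zeros.
entity_types = [1, 1, 2, 4, 4] + [0] * 100

positive_idx = [i for i, t in enumerate(entity_types) if t != 0]
num_negative = sum(1 for t in entity_types if t == 0)

print(len(entity_types))  # 105, i.e. E
print(positive_idx)       # [0, 1, 2, 3, 4]
print(num_negative)       # 100
```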

rels is an R×2 tensor (R := number of positive + negative relation samples, i.e. pairs of related or unrelated entity mentions). Each row contains the indices of the two corresponding entity mentions in entity_masks (and likewise in entity_sizes and entity_types). It is used to retrieve the entity mention representations for each pair after max-pooling has been applied via entity_masks.
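The gather step can be sketched in plain Python as follows. gather_pair_reprs is a hypothetical name. The sketch covers only retrieving and concatenating the two pooled mention vectors for each row of rels; anything else the relation classifier consumes is left out. The sample contains both [2, 1] and [1, 2], so the pairs appear to be ordered (head, tail).

```python
def gather_pair_reprs(entity_reprs, rels):
    """For each (head, tail) row of rels, concatenate the two entity vectors.

    entity_reprs: E pooled entity mention vectors (indexed like entity_masks).
    rels: R pairs of indices into entity_reprs.
    """
    return [entity_reprs[head] + entity_reprs[tail] for head, tail in rels]


# Toy pooled representations for three entity mention samples.
entity_reprs = [[0.0, 0.5], [1.0, 1.5], [2.0, 2.5]]
rels = [[2, 1], [0, 1], [1, 2]]  # (2, 1) and (1, 2) are distinct, ordered pairs

for row in gather_pair_reprs(entity_reprs, rels):
    print(row)
# [2.0, 2.5, 1.0, 1.5]
# [0.0, 0.5, 1.0, 1.5]
# [1.0, 1.5, 2.0, 2.5]
```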

rel_masks is an R×C tensor used to access the tokens between the two entity mentions of each pair (and to mask every other token).
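Here is a pure-Python sketch of building one row of rel_masks from two entity spans. It assumes half-open (start, end) token spans, and create_rel_mask is a hypothetical helper name. If the two mentions are adjacent or overlap, no token lies between them and the row is all False. That would be consistent with the all-False rows in the sample output. The result is also symmetric, which matches the identical rows for the pairs [0, 1] and [1, 0] in the sample.

```python
def create_rel_mask(span1, span2, context_size):
    """Hypothetical helper: one row of rel_masks.

    True for the tokens strictly between two half-open spans (start, end),
    False everywhere else. Adjacent or overlapping spans give an all-False row.
    """
    (start1, end1), (start2, end2) = span1, span2
    if end1 <= start2:  # span1 lies before span2
        start, end = end1, start2
    else:               # span2 lies before span1 (or the spans overlap)
        start, end = end2, start1
    return [start <= i < end for i in range(context_size)]


mask = create_rel_mask((1, 3), (6, 8), 10)
print([i for i, v in enumerate(mask) if v])           # [3, 4, 5]
print(create_rel_mask((6, 8), (1, 3), 10) == mask)    # True: order does not matter
print(any(create_rel_mask((1, 3), (3, 5), 10)))       # False: adjacent spans
```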

rel_types is an R×T tensor (T := number of relation types) containing the multi-hot encoding of the relation types for each pair. A row of all zeros marks a strong negative sample.
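Decoding the rel_types printed in the question makes the multi-hot encoding concrete. Only the first row has a 1 (at relation type index 2). Read together with rels, this means the first pair (entity sample 2 as head, 1 as tail) is the single positive relation, and the other 19 rows are strong negatives. decode_rel_types is a hypothetical helper.

```python
def decode_rel_types(rel_types):
    """Return, for each pair, the list of relation type indices set to 1."""
    return [[t for t, value in enumerate(row) if value > 0.5] for row in rel_types]


# rels and rel_types from the sample above (R = 20, T = 5).
rels = [[2, 1], [0, 1], [1, 0], [1, 2], [2, 0], [1, 4], [1, 3], [0, 2], [3, 1], [2, 4],
        [2, 3], [3, 4], [3, 2], [4, 3], [4, 2], [0, 4], [4, 1], [4, 0], [0, 3], [3, 0]]
rel_types = [[0.0, 0.0, 1.0, 0.0, 0.0]] + [[0.0] * 5 for _ in range(19)]

decoded = decode_rel_types(rel_types)
for (head, tail), types in zip(rels, decoded):
    if types:
        print(f"positive: head={head}, tail={tail}, types={types}")
print(sum(1 for types in decoded if not types), "strong negatives")
# positive: head=2, tail=1, types=[2]
# 19 strong negatives
```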

entity_sample_masks is a tensor of size E used to mask 'padding' entity mention samples. These padding mentions have to be introduced because samples are batched over sentences that contain different numbers of mentions.

rel_sample_masks is a tensor of size R used to mask 'padding' relation samples. These padding relations have to be introduced for the same reason, because samples are batched over sentences.
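The pure-Python sketch below shows why these masks are needed when batching. Each sentence's sample list is padded to the largest count in the batch, and the mask records which entries are real. A per-sample loss is then averaged over the real entries only. pad_samples, masked_mean and the padding value are illustrative assumptions, not SpERT's own functions.

```python
def pad_samples(per_sentence, pad_value):
    """Pad each sentence's samples to the batch maximum and build sample masks."""
    max_len = max(len(samples) for samples in per_sentence)
    padded, masks = [], []
    for samples in per_sentence:
        n_pad = max_len - len(samples)
        padded.append(samples + [pad_value] * n_pad)
        masks.append([True] * len(samples) + [False] * n_pad)
    return padded, masks


def masked_mean(losses, masks):
    """Average per-sample losses, ignoring padding samples (mask == False)."""
    total = sum(loss for row_loss, row_mask in zip(losses, masks)
                for loss, keep in zip(row_loss, row_mask) if keep)
    count = sum(keep for row_mask in masks for keep in row_mask)
    return total / max(count, 1)


# Two sentences with 3 and 1 entity samples (toy entity sizes as the payload).
padded, masks = pad_samples([[1, 3, 2], [4]], pad_value=0)
print(padded)  # [[1, 3, 2], [4, 0, 0]]
print(masks)   # [[True, True, True], [True, False, False]]

# Losses at padding positions (100.0) do not affect the result.
print(masked_mean([[1.0, 2.0, 3.0], [4.0, 100.0, 100.0]], masks))  # 2.5
```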