clovaai / bros

Apache License 2.0

Correct implementation of RelationExtractor #10

Open ndgnuh opened 2 years ago

ndgnuh commented 2 years ago

I found that the implementation of RelationExtractor in this repository is incorrect (compared to the original one). I'm aware that the implementation is (more or less) the same as the one described in 2005.00642. But after digging into the code in clovaai/spade, I realized that the original implementation differs from the paper. I'll refer to the original implementation as SPADE, this repository as BROS, and the paper as the SPADE paper.

  1. SPADE and the SPADE paper use two score matrices for each relation; BROS has only one.
  2. The SPADE paper and BROS use a threshold to binarize the scores into an adjacency matrix; SPADE takes the element-wise argmax over the two score matrices, so the pair acts like scores for "no edge" vs. "edge".
  3. The loss function in SPADE is a weighted cross entropy, with a heavy weight on the second score matrix (the "edge" class).
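The two binarization schemes in point 2 can be compared on toy tensors. This is a minimal sketch, assuming the single-matrix variant applies a sigmoid before the 0.5 threshold; random tensors stand in for real model scores. It also shows that, for two classes, the element-wise argmax is exactly the decision "edge score > no-edge score".

```python
import torch

torch.manual_seed(0)
rows, cols = 4, 4

# BROS / SPADE paper style (assumed): one score matrix per relation,
# binarized by thresholding the sigmoid probability at 0.5.
score = torch.randn(rows, cols)
adj_threshold = (torch.sigmoid(score) > 0.5).long()

# SPADE (clovaai/spade) style: two score matrices per relation,
# index 0 = "no edge", index 1 = "edge", binarized with argmax.
scores = torch.randn(2, rows, cols)
adj_argmax = scores.argmax(dim=0)

# For two classes, argmax picks "edge" exactly where the edge score
# is larger, i.e. sigmoid(s1 - s0) > 0.5 in exact arithmetic.
adj_compare = (scores[1] > scores[0]).long()
```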

My version of RelationExtractor (which I have tested, and which achieves results roughly equivalent to the original SPADE):

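```python
import torch
from torch import nn


class RelationTagger(nn.Module):
    def __init__(self, n_fields, hidden_size):
        super().__init__()
        # Separate projections for head (source) and tail (target) nodes
        self.head = nn.Linear(hidden_size, hidden_size)
        self.tail = nn.Linear(hidden_size, hidden_size)
        # One learnable embedding per field, used as extra head nodes
        self.field_embeddings = nn.Parameter(
            torch.rand(1, n_fields, hidden_size))
        # Bilinear weights for the "no edge" (0) and "edge" (1) scores
        self.W_label_0 = nn.Linear(hidden_size, hidden_size, bias=False)
        self.W_label_1 = nn.Linear(hidden_size, hidden_size, bias=False)

    def forward(self, enc):
        # enc: encoder output of shape (batch, seq_len, hidden_size)
        enc_head = self.head(enc)
        enc_tail = self.tail(enc)

        # Prepend the field embeddings to the head nodes: (b, f + n, h)
        batch_size = enc_tail.size(0)
        field_embeddings = self.field_embeddings.expand(batch_size, -1, -1)
        enc_head = torch.cat([field_embeddings, enc_head], dim=1)

        # Each score matrix has shape (b, f + n, n)
        score_0 = torch.matmul(
            enc_head, self.W_label_0(enc_tail).transpose(1, 2))
        score_1 = torch.matmul(
            enc_head, self.W_label_1(enc_tail).transpose(1, 2))

        # Stack the two score matrices along dim 1: (b, 2, f + n, n)
        score = torch.cat([score_0.unsqueeze(1), score_1.unsqueeze(1)], dim=1)
        return score
```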

This implementation handles a single relation; for multiple relations, use one instance of this layer per relation. The output shape is (b, s, f + n, n), where b is the batch size, s = 2 is the number of score matrices, n is the sequence length, and f is the number of fields (the f field rows come first along the third dimension). The final relation matrix is obtained with score.argmax(dim=1).
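Because the two score matrices are stacked along dim 1, the output can be passed straight to torch.nn.CrossEntropyLoss, which treats dim 1 as the class dimension for inputs of shape (N, C, d1, d2). Below is a minimal sketch of the weighted loss from point 3 together with the argmax decoding. A random tensor stands in for the layer output, and the class weights [0.1, 1.0] and the toy labels are hypothetical, not the values used in clovaai/spade.

```python
import torch
from torch import nn

torch.manual_seed(0)

b, f, n = 2, 3, 5  # batch size, number of fields, sequence length
# Stand-in for the layer output: shape (b, 2, f + n, n)
score = torch.randn(b, 2, f + n, n)

# Hypothetical sparse gold adjacency: 1 = edge, 0 = no edge
labels = torch.zeros(b, f + n, n, dtype=torch.long)
labels[:, 0, 1] = 1      # field 0 -> token 1
labels[:, f + 1, 2] = 1  # token 1 -> token 2

# Weighted cross entropy over dim 1 ("no edge" vs. "edge").
# The weights are an assumption; the edge class gets the larger one.
loss_fn = nn.CrossEntropyLoss(weight=torch.tensor([0.1, 1.0]))
loss = loss_fn(score, labels)

# Decoding: 1 wherever the "edge" score beats the "no edge" score
adjacency = score.argmax(dim=1)  # shape (b, f + n, n), values in {0, 1}
```

With several relations, each layer instance yields its own score tensor, which is trained and decoded the same way.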

tghong commented 2 years ago

Hi, thank you for sharing your code. As we noted in the paper, the SPADE decoder implementation in this repo is slightly different from the original paper.

ndgnuh commented 2 years ago

Well, I can't say which one is "more correct" since I'm not the author. I just want to note that the implementation in this repository can produce very large logits, so using a threshold (0.5, as in the paper) might not work well for sparse labels.