yfzhang114 opened 3 years ago

I have a question about the length token: I cannot find the length token in the input source sentences.
Hello, I have the same problem now. Have you solved it yet?
I am also digging into this codebase. My current understanding is that the length token isn't part of the `src_tokens` input sequence but gets created during the encoder forward pass, here:

```python
# embed a placeholder length index (0) for every sentence in the batch -> (B, 1, C)
len_tokens = self.embed_lengths(src_tokens.new(src_tokens.size(0), 1).fill_(0))
# prepend it to the token embeddings, so position 0 is the length token -> (B, T+1, C)
x = torch.cat([len_tokens, x], dim=1)
```
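To make the shapes concrete, here is a minimal, self-contained sketch of that step. It is a toy reconstruction, not the repo's actual code: `embed_lengths`, `max_len`, and the dimensions are all assumed stand-ins.

```python
import torch
import torch.nn as nn

batch_size, src_len, embed_dim, max_len = 2, 5, 8, 100

# stand-in for the encoder's length embedding table: one row per
# candidate target length, with row 0 reserved for the placeholder
embed_lengths = nn.Embedding(max_len, embed_dim)

src_tokens = torch.randint(1, 50, (batch_size, src_len))  # (B, T) token ids
x = torch.randn(batch_size, src_len, embed_dim)           # token embeddings, (B, T, C)

# one placeholder index (0) per sentence, embedded -> (B, 1, C)
len_tokens = embed_lengths(src_tokens.new(src_tokens.size(0), 1).fill_(0))

# prepend it, so position 0 of every sequence is the length token -> (B, T+1, C)
x = torch.cat([len_tokens, x], dim=1)
print(x.shape)  # torch.Size([2, 6, 8])
```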
The length predictions are then generated later in the same method, here:
```python
# dot the length-token hidden state with every row of the (tied) length
# embedding matrix to get one logit per candidate length -> (B, max_len)
predicted_lengths_logits = torch.matmul(x[0, :, :], self.embed_lengths.weight.transpose(0, 1)).float()
predicted_lengths_logits[:, 0] += float('-inf')  # cannot predict the len_token placeholder itself
predicted_lengths = F.log_softmax(predicted_lengths_logits, dim=-1)
```
(Note that the first token, i.e. the newly appended length token, is indexed as `x[0, :, :]` because the sequence gets transposed beforehand at the line `x = x.transpose(0, 1)`, which puts the sequence dimension first.)
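Continuing the toy sketch above (reusing the assumed `embed_lengths` and `x`), here is a hedged reconstruction of that prediction step; in the real model `x` would be the encoder output rather than raw embeddings:

```python
import torch.nn.functional as F

# fairseq-style transpose: (B, T+1, C) -> (T+1, B, C), sequence dim first
x = x.transpose(0, 1)

# x[0] is the length-token state, (B, C); dotted against the embedding
# table (C, max_len) it yields one logit per candidate length, (B, max_len)
predicted_lengths_logits = torch.matmul(
    x[0, :, :], embed_lengths.weight.transpose(0, 1)
).float()

# mask out index 0: it is the placeholder, never a valid target length
predicted_lengths_logits[:, 0] += float('-inf')
predicted_lengths = F.log_softmax(predicted_lengths_logits, dim=-1)

print(predicted_lengths.argmax(dim=-1))  # most likely length per sentence
```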