facebookresearch / Mask-Predict

A masked language modeling objective to train a model to predict any subset of the target words, conditioned on both the input text and a partially masked target translation.

Does the [length] token exist? #17

Open yfzhang114 opened 3 years ago

yfzhang114 commented 3 years ago

I have a question about the length token. I cannot find the length token in the input source sentences.

hjx999222 commented 2 years ago

Hello, I have the same problem now. Have you solved it yet?

yair-schiff commented 1 year ago

I am also digging into this codebase. My current understanding is that the length token isn't part of the src input sequence but gets created during the encoder forward pass here.

len_tokens = self.embed_lengths(src_tokens.new(src_tokens.size(0), 1).fill_(0))  # embed index 0 once per batch element: one length token of shape (batch, 1, dim)
x = torch.cat([len_tokens, x], dim=1)  # prepend it, so the length token sits at position 0 of the source sequence
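
For intuition, here is a minimal self-contained sketch of that prepend step with toy dimensions (the names embed_lengths, src_tokens, and x mirror the snippet above; the concrete sizes and random tensors are made up for illustration):

import torch
import torch.nn as nn

batch_size, src_len, embed_dim, max_len = 2, 5, 8, 100

embed_lengths = nn.Embedding(max_len, embed_dim)          # stands in for self.embed_lengths
src_tokens = torch.randint(1, 50, (batch_size, src_len))  # stands in for the source token ids
x = torch.randn(batch_size, src_len, embed_dim)           # stands in for the embedded source, (B, T, C)

# src_tokens.new(B, 1) makes a LongTensor on the same device; fill_(0) selects embedding index 0
len_tokens = embed_lengths(src_tokens.new(src_tokens.size(0), 1).fill_(0))  # (B, 1, C)
x = torch.cat([len_tokens, x], dim=1)                     # (B, T+1, C): length token at position 0
print(x.shape)                                            # torch.Size([2, 6, 8])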

The length predictions are then generated later in the same method, here.

predicted_lengths_logits = torch.matmul(x[0, :, :], self.embed_lengths.weight.transpose(0, 1)).float()  # project the length token's output state onto all length embeddings
predicted_lengths_logits[:, 0] += float('-inf')   # cannot predict the len_token itself (index 0)
predicted_lengths = F.log_softmax(predicted_lengths_logits, dim=-1)  # log-distribution over candidate target lengths

(Note that the first token, i.e. the newly created and prepended length token, is indexed as x[0, :, :] because the sequence gets transposed earlier in the method at this line: x = x.transpose(0, 1).)
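
Putting the two snippets together, here is a hedged end-to-end sketch of the mechanism as I understand it (toy dimensions; LengthPredictingEncoder is a made-up name, not the actual fairseq module, and the transformer layers are elided):

import torch
import torch.nn as nn
import torch.nn.functional as F

class LengthPredictingEncoder(nn.Module):
    """Toy stand-in for the length-prediction path of the Mask-Predict encoder."""
    def __init__(self, vocab_size=1000, embed_dim=8, max_target_len=100):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, embed_dim)
        self.embed_lengths = nn.Embedding(max_target_len, embed_dim)

    def forward(self, src_tokens):
        x = self.embed_tokens(src_tokens)                            # (B, T, C)
        len_tokens = self.embed_lengths(
            src_tokens.new(src_tokens.size(0), 1).fill_(0))          # (B, 1, C)
        x = torch.cat([len_tokens, x], dim=1)                        # (B, T+1, C)
        x = x.transpose(0, 1)                                        # (T+1, B, C): why x[0] is the length token
        # ... the transformer layers would run here on the (T+1, B, C) tensor ...
        logits = torch.matmul(
            x[0, :, :], self.embed_lengths.weight.transpose(0, 1)).float()  # (B, max_target_len)
        logits[:, 0] += float('-inf')                                # mask out the len_token's own index
        return F.log_softmax(logits, dim=-1)

enc = LengthPredictingEncoder()
src = torch.randint(1, 1000, (2, 5))
predicted_lengths = enc(src)            # (2, 100) log-probabilities over target lengths
print(predicted_lengths.argmax(dim=-1))

Index 0 doubles as the len_token's own embedding slot, which is why the code adds float('-inf') to it before the softmax; the remaining indices then correspond to the candidate target lengths.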