Greetings,
For the residual transformer, the goal is to predict the residual-layer tokens, so I am a little confused about why you take the tokens of the first $n - 1$ layers here, i.e., the base layer plus 4 residual layers. Shouldn't it be `all_indices[..., 1:]` to get the residual tokens?
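To make the difference concrete, here is a toy NumPy example of the two slices. It assumes `all_indices` stores one token index per quantizer layer along its last axis, with layer 0 as the base layer; the sizes (6 layers, batch of 2, 4 time steps) are made up for illustration.

```python
import numpy as np

# Hypothetical sizes: 6 quantizer layers (index 0 = base layer, 1..5 = residual layers).
batch, seq_len, num_layers = 2, 4, 6
all_indices = np.arange(batch * seq_len * num_layers).reshape(batch, seq_len, num_layers)

# First n - 1 layers: base layer + first 4 residual layers.
first_n_minus_1 = all_indices[..., :-1]

# Residual layers only: everything except the base layer.
residual_only = all_indices[..., 1:]

print(first_n_minus_1.shape)  # (2, 4, 5)
print(residual_only.shape)    # (2, 4, 5)
print(first_n_minus_1[0, 0])  # [0 1 2 3 4]
print(residual_only[0, 0])    # [1 2 3 4 5]
```

Both slices have the same shape, so a shape check alone would not catch the mix-up; only the layers they contain differ.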