Open kangzhao2 opened 1 year ago
@kangzhao2 so there's actually a -1
in the padding, which removes one row from the sequence dimension; the +1 and the -1 cancel, so x_shift keeps its shape and still concatenates with x_pass. The one-token offset between the two halves is intentional.
It's the "smeared key" architecture mentioned in https://transformer-circuits.pub/2022/in-context-learning-and-induction-heads/index.html
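A minimal sketch of what the pad does, using toy shapes (4 tokens, feature dim 4; the variable names follow the snippet in the issue body):

```python
import torch
import torch.nn.functional as F

# Toy input: 4 tokens, feature dim 4 (shapes are illustrative only)
normed_x = torch.arange(16.).reshape(4, 4)

# Split the features in half; shift one half forward by one token
x_shift, x_pass = normed_x.chunk(2, dim=-1)
x_shift = F.pad(x_shift, (0, 0, 1, -1), value=0.)  # prepend a zero row, drop the last row
out = torch.cat((x_shift, x_pass), dim=-1)

# Sequence length is unchanged: the +1 pad and the -1 pad cancel
print(out.shape)  # torch.Size([4, 4])

# Row i now holds token (i-1)'s first half next to token i's second half,
# so each position mixes in information from its predecessor:
# a cheap, causal one-token shift along the sequence.
print(out)
```

Row 0 gets zeros in its shifted half (there is no previous token), and every later row pairs its own second half with the previous token's first half, which is exactly the "smeared" behavior.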
Thank you for your amazing code.
In the FLASH class, I found a flag, shift_tokens, with the corresponding code:
```python
if self.shift_tokens:
    x_shift, x_pass = normed_x.chunk(2, dim = -1)
    x_shift = F.pad(x_shift, (0, 0, 1, -1), value = 0.)
    normed_x = torch.cat((x_shift, x_pass), dim = -1)
```
Assume normed_x has shape [1024, 512]; then x_shift and x_pass each have shape [1024, 256]. The pad adds a row of zeros at the top of x_shift and removes its last row, and then x_shift and x_pass are concatenated to form the new normed_x.
In my opinion, the F.pad operation makes the rows of x_shift and x_pass no longer correspond to the same tokens.
May I know why this works?
Kang