lucidrains / FLASH-pytorch

Implementation of the Transformer variant proposed in "Transformer Quality in Linear Time"
MIT License

About the "shift_tokens" #5

Open kangzhao2 opened 1 year ago

kangzhao2 commented 1 year ago

Thank you for your amazing code.

In the FLASH class, I found a flag, shift_tokens, and the corresponding code is as follows:

```python
if self.shift_tokens:
    x_shift, x_pass = normed_x.chunk(2, dim = -1)
    x_shift = F.pad(x_shift, (0, 0, 1, -1), value = 0.)
    normed_x = torch.cat((x_shift, x_pass), dim = -1)
```

Assume normed_x has shape [1024, 512]; then x_shift and x_pass each have shape [1024, 256]. The pad prepends a row of zeros to x_shift and removes its last row, and the two halves are concatenated back into normed_x.

In my opinion, the F.pad operation makes the rows of x_shift and x_pass no longer line up with each other.

May I know why it works?

Kang

lucidrains commented 1 year ago

@kangzhao2 so there's actually a -1 in the padding, which trims one position from the end of the sequence dimension, so the shifted half keeps the same length as x_pass.
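
A minimal sketch of the shape arithmetic (toy tensors, not taken from the repo). It assumes a (seq_len, dim) input; in FLASH there is also a batch dimension, but the pad spec behaves the same:

```python
import torch
import torch.nn.functional as F

# toy "normed_x" of shape (seq_len, dim) = (4, 6)
normed_x = torch.arange(24.).reshape(4, 6)

x_shift, x_pass = normed_x.chunk(2, dim = -1)   # each (4, 3)

# pad spec reads (dim_left, dim_right, seq_left, seq_right):
# prepend one zero row along the sequence axis, and the -1
# trims one row off the end, so the length stays 4
x_shift = F.pad(x_shift, (0, 0, 1, -1), value = 0.)

out = torch.cat((x_shift, x_pass), dim = -1)    # still (4, 6)
print(out)
# row i now holds token i-1's features in its first half of
# channels and token i's own features in the second half
```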

liujuncn commented 1 year ago

It's the "smeared key" architecture mentioned in https://transformer-circuits.pub/2022/in-context-learning-and-induction-heads/index.html
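
To connect the two answers, here is a hedged check (toy tensors, not the repo's code) that the "mismatch" is exactly a deliberate one-token shift: after the pad, the first half of row i is row i-1 of the original, so each position cheaply mixes in its predecessor's features, which is the smeared-key effect:

```python
import torch
import torch.nn.functional as F

seq_len, dim = 8, 4
x = torch.randn(seq_len, dim)

x_shift, x_pass = x.chunk(2, dim = -1)
shifted = F.pad(x_shift, (0, 0, 1, -1), value = 0.)

# the shifted half at position i equals position i-1's original features
assert torch.equal(shifted[1:], x_shift[:-1])
# position 0 sees zeros, since it has no previous token
assert torch.equal(shifted[0], torch.zeros(dim // 2))
```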