Closed by pfeatherstone 3 months ago
Presumably in `apply_rotary_pos_emb()` we need to add `scale = scale[-seq_len:, :]`?
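For concreteness, roughly where that slice would go. This is only a sketch of the helper, not the exact library code; `rotate_half` is assumed to be the usual half-rotation helper:

```python
import torch

def rotate_half(x):
    # standard rotary helper: split the last dimension in two and rotate
    x1, x2 = x.chunk(2, dim = -1)
    return torch.cat((-x2, x1), dim = -1)

def apply_rotary_pos_emb(t, freqs, scale = 1.):
    seq_len = t.shape[-2]
    # freqs may cover (mem_len + seq_len) positions; keep only the current tokens
    freqs = freqs[-seq_len:, :]
    if isinstance(scale, torch.Tensor):
        # proposed fix: slice scale the same way so its length matches freqs
        scale = scale[-seq_len:, :]
    return (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
```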
As an aside, why is all of RotaryEmbedding decorated with `@torch.cuda.amp.autocast(enabled = False)`? You can remove it with just a couple of tweaks, and it then supports `torch.bfloat16`.
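I don't know exactly which tweaks are meant, but one plausible shape of it (a sketch, not the library code) is to build the frequencies from the float32 buffer and cast to the caller's dtype at the end, so the module itself imposes no dtype constraint and needs no autocast guard:

```python
import torch
from torch import nn

class SimpleRotary(nn.Module):
    # minimal illustration only, not the x-transformers implementation
    def __init__(self, dim):
        super().__init__()
        inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, seq_len, device, dtype):
        # frequency math starts from the float32 buffer, then casts to the
        # working dtype (e.g. torch.bfloat16) at the end
        t = torch.arange(seq_len, device = device, dtype = torch.float32)
        freqs = torch.einsum('i, j -> i j', t, self.inv_freq)
        freqs = torch.cat((freqs, freqs), dim = -1)
        return freqs.to(dtype)
```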
Also, I think the scale calculation is incorrect when using mems, since the positions are off. You have to use the same trick of starting from a negative position.
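To spell out what I mean by the negative-position trick, here is a rough sketch with my own names (not the library's code): compute the xpos scale over positions that include the memory tokens, starting the arange at `-mem_len`:

```python
import torch

def xpos_scale(seq_len, mem_len, dim, scale_base = 512):
    # hypothetical helper: positions for (memory + current) tokens, with the
    # memory tokens sitting at negative indices relative to the current chunk
    pos = torch.arange(-mem_len, seq_len, dtype = torch.float32)
    # usual xpos per-dimension base
    base = (torch.arange(0, dim, 2, dtype = torch.float32) + 0.4 * dim) / (1.4 * dim)
    scale = base ** (pos / scale_base).unsqueeze(-1)   # (mem_len + seq_len, dim // 2)
    return torch.cat((scale, scale), dim = -1)         # (mem_len + seq_len, dim)
```

The last `seq_len` rows then line up with the current tokens, which is also why the scale (like freqs) would need the `scale[-seq_len:, :]` slice mentioned above.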
I believe https://github.com/lucidrains/x-transformers/pull/234 fixes it.
great job!
Repro:
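A minimal sketch along these lines exercises the combination of xpos rotary scaling with XL memories (sizes are arbitrary; `rotary_xpos`, `max_mem_len`, `mems` and `return_mems` are the documented options):

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    max_mem_len = 256,
    attn_layers = Decoder(
        dim = 512,
        depth = 2,
        heads = 8,
        rotary_xpos = True  # length-extrapolating (xpos) rotary embeddings
    )
)

x1 = torch.randint(0, 256, (1, 256))
x2 = torch.randint(0, 256, (1, 256))

# first forward pass returns memories, second pass feeds them back in
logits1, mems1 = model(x1, return_mems = True)
logits2, mems2 = model(x2, mems = mems1, return_mems = True)
```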
You get this error: