raganato closed this issue 6 months ago
@raganato are you looking for https://github.com/lucidrains/x-transformers/blob/main/x_transformers/x_transformers.py#L430
It should break with the following setting: when `rotary_xpos` is set to `True`, execution reaches line 450 https://github.com/lucidrains/x-transformers/blob/main/x_transformers/x_transformers.py#L450
```python
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 10,
    max_seq_len = 20,
    attn_layers = Decoder(
        dim = 512,
        depth = 2,
        heads = 8,
        rotary_xpos = True,  # modified rotary to extrapolate well beyond length at which it was trained
    )
)
```
@raganato oops, you are correct
should be fixed, thank you!
https://github.com/lucidrains/x-transformers/blob/49b196e8a9da707c9bf16a59f9d09ed6200dc0e7/x_transformers/x_transformers.py#L437
The forward of `RotaryEmbedding` lacks a seq len input argument. I think we just need to add it, and then at line 1273 (`rotary_pos_emb = self.rotary_pos_emb(pos)`) pass it in as `x.shape[1]`.
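To illustrate why the sequence length is needed, here is a minimal sketch of a rotary embedding with xpos-style scaling. This is not the library's actual class (names, buffer layout, and the `scale_base` default are assumptions for illustration); it only shows that the xpos scale factor is computed relative to the sequence midpoint, which requires knowing `seq_len` inside `forward`:

```python
import torch
from torch import nn

class RotaryEmbeddingSketch(nn.Module):
    """Hypothetical minimal rotary embedding with optional xpos scaling.
    Not the x-transformers implementation; for illustration only."""

    def __init__(self, dim, use_xpos=False, scale_base=512):
        super().__init__()
        # standard rotary inverse frequencies
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
        self.use_xpos = use_xpos
        self.scale_base = scale_base
        if use_xpos:
            # per-dimension base for the xpos scale (assumed form)
            scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
            self.register_buffer('scale', scale)

    def forward(self, pos, seq_len=None):
        # pos: 1-D tensor of positions, shape (n,)
        freqs = torch.einsum('i,j->ij', pos.float(), self.inv_freq)
        freqs = torch.cat((freqs, freqs), dim=-1)
        if not self.use_xpos:
            return freqs, torch.ones(1, device=pos.device)
        # xpos scaling is centered on the middle of the sequence,
        # so seq_len must be passed in by the caller
        power = (pos - seq_len // 2) / self.scale_base
        scale = self.scale ** power.unsqueeze(-1)
        scale = torch.cat((scale, scale), dim=-1)
        return freqs, scale
```

The caller would then invoke it roughly as `freqs, scale = rotary(pos, seq_len=x.shape[1])`, matching the proposed change at line 1273.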