lucidrains / x-transformers

A simple but complete full-attention transformer with a set of promising experimental features from various papers
MIT License

Seq len missing in rotary embedding #226

Closed: raganato closed this issue 6 months ago

raganato commented 6 months ago

https://github.com/lucidrains/x-transformers/blob/49b196e8a9da707c9bf16a59f9d09ed6200dc0e7/x_transformers/x_transformers.py#L437

The `forward` of `RotaryEmbedding` lacks the sequence-length input argument. I think we just need to add it, and then at line 1273 (`rotary_pos_emb = self.rotary_pos_emb(pos)`) pass it in as `x.shape[1]`.
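A missing length only matters on the xpos path. Plain rotary angles depend on each position alone. The xpos decay, by contrast, is centred on the middle of the sequence, so it needs the sequence length.

Below is a minimal pure-Python sketch of that idea. It rests on a few assumptions:

- It follows the xPos formulation, with a base scale of `(i + 0.4 * dim) / (1.4 * dim)` for even `i` and an exponent of `(pos - seq_len // 2) / scale_base`.
- `rotary_freqs` and `xpos_scale` are hypothetical names, not x-transformers' API.
- Inferring `seq_len` from the positions passed in is only one possible fix, not necessarily the one applied upstream.

```python
def rotary_freqs(positions, dim, base=10000.0):
    """Rotation angle for each (position, frequency) pair: pos * base**(-i/dim), i even."""
    inv_freq = [1.0 / (base ** (i / dim)) for i in range(0, dim, 2)]
    # Depends only on each position, so no sequence length is needed here.
    return [[p * f for f in inv_freq] for p in positions]


def xpos_scale(positions, dim, scale_base=512.0):
    """Per-position, per-frequency xpos decay factors (hypothetical helper).

    The decay is centred on the middle of the sequence, so it needs seq_len.
    Here seq_len is inferred from the positions instead of being read from a
    name that was never passed in, which is the failure described in the issue.
    """
    seq_len = len(positions)  # assumption: positions are 0..seq_len-1
    base_scale = [(i + 0.4 * dim) / (1.4 * dim) for i in range(0, dim, 2)]
    rows = []
    for p in positions:
        # Exponent is 0 at position seq_len // 2, so that position gets scale 1.
        power = (p - seq_len // 2) / scale_base
        rows.append([s ** power for s in base_scale])
    return rows


if __name__ == "__main__":
    scales = xpos_scale(list(range(20)), dim=8)
    print(scales[0][0], scales[10][0], scales[19][0])
```

With 20 positions, position 10 gets a scale of exactly 1.0. Because every base scale is below 1, earlier positions are scaled up and later ones are scaled down.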

lucidrains commented 6 months ago

@raganato are you looking for https://github.com/lucidrains/x-transformers/blob/main/x_transformers/x_transformers.py#L430

raganato commented 6 months ago

It should break with the following settings, i.e. when `rotary_xpos` is set to True and execution reaches line 450: https://github.com/lucidrains/x-transformers/blob/main/x_transformers/x_transformers.py#L450

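```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 10,
    max_seq_len = 20,
    attn_layers = Decoder(
        dim = 512,
        depth = 2,
        heads = 8,
        rotary_xpos = True,   # modified rotary to extrapolate well beyond length at which it was trained
    )
)

# the forward pass reaches the xpos branch of RotaryEmbedding, where seq_len is undefined
x = torch.randint(0, 10, (1, 20))
logits = model(x)
```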
lucidrains commented 6 months ago

@raganato oops, you are correct

should be fixed, thank you!