lucidrains / x-transformers

A simple but complete full-attention transformer with a set of promising experimental features from various papers
MIT License

[Bug] XL-recurrence with AlibiPositionalBias and mems not working correctly #242

Closed pfeatherstone closed 3 months ago

pfeatherstone commented 5 months ago

This is a similar issue to https://github.com/lucidrains/x-transformers/issues/223 but with Alibi.

So, I am trying to do XL-recurrence with ALiBi positional bias, memories (mems / mem_masks), flash attention and memory key / values.

The repro is:

import torch
from x_transformers import ContinuousTransformerWrapper, Decoder

lm = ContinuousTransformerWrapper(
    dim_in              = 2,
    dim_out             = 36,
    max_seq_len         = 0,
    max_mem_len         = 100,
    attn_layers = Decoder(
        dim             = 512,
        depth           = 4,
        heads           = 4,

        disable_abs_pos_emb = True,
        alibi_pos_bias  = True,
        alibi_num_heads = 2,

        attn_flash      = True,
        attn_num_mem_kv = 20
    )
)

B, M, D, depth = 1, lm.max_mem_len, lm.attn_layers.dim, lm.attn_layers.depth

x           = torch.randn(B, 1024, 2)
length      = torch.randint(100, x.shape[1], size=(x.shape[0],))
mask        = torch.arange(x.shape[1]).unsqueeze(0) < length.unsqueeze(-1)
mems        = [torch.randn(x.shape[0], M, D) for _ in range(depth)]
mem_masks   = [torch.zeros(x.shape[0], M).bool() for _ in range(depth)]   # memories fully masked out

# with the memories entirely masked out, both calls should produce identical results
out1, mems1 = lm(x, mask=mask, return_mems=True)
out2, mems2 = lm(x, mask=mask, mems=mems, mem_masks=mem_masks, return_mems=True)
torch.testing.assert_close(out1, out2)
for m1, m2 in zip(mems1, mems2):
    torch.testing.assert_close(m1, m2)

I imagine the issue and fix are similar to the RoPE one.
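For context, here is a minimal sketch (not the library's actual implementation) of what an ALiBi-style bias has to account for once memories are prepended: the keys span the memory segment plus the current segment, so the bias matrix is (seq_len, mem_len + seq_len) rather than square, with the memory positions receiving the largest distance penalty.

import torch

def alibi_bias_with_mem(slopes, seq_len, mem_len = 0):
    # slopes: (heads,) tensor of per-head ALiBi slopes
    # queries cover only the current segment; keys also cover the prepended memory
    q_pos = torch.arange(mem_len, mem_len + seq_len)
    k_pos = torch.arange(mem_len + seq_len)
    rel   = k_pos[None, :] - q_pos[:, None]             # (seq_len, mem_len + seq_len)
    return -rel.abs()[None] * slopes[:, None, None]     # (heads, seq_len, mem_len + seq_len)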

lucidrains commented 5 months ago

@pfeatherstone yea i could make it work

however, i think alibi is really bad and probably should be removed. i would never use it for any serious model

lucidrains commented 5 months ago

@pfeatherstone maybe i'll just make it all work as a personal challenge. i can also fix dynamic pos bias in the presence of memory as well

pfeatherstone commented 5 months ago

Basically I found rotary positional embedding to not length-extrapolate at all. Like, it's really bad. I'm doing some tests with XPOS and, though I haven't got a complete model yet (still training), it looks a bit better. However, I'm nervous about limiting the context length.

pfeatherstone commented 5 months ago

> @pfeatherstone yea i could make it work
>
> however, i think alibi is really bad and probably should be removed. i would never use it for any serious model

Is this based on personal tests? According to the paper it's the best thing since sliced bread.

lucidrains commented 5 months ago

> Basically I found rotary positional embedding to not length-extrapolate at all. Like, it's really bad. I'm doing some tests with XPOS and, though I haven't got a complete model yet (still training), it looks a bit better. However, I'm nervous about limiting the context length.

that's well known, i think i even mention it in the readme. however, there's a lot of research going into fine tuning trained rotary models to longer context, so it is not a big deal

lucidrains commented 5 months ago

i wouldn't use xpos either.. it suffers from the same issues as alibi. i really should start removing features i no longer believe in

lucidrains commented 5 months ago

> @pfeatherstone yea i could make it work however, i think alibi is really bad and probably should be removed. i would never use it for any serious model
>
> Is this based on personal tests? According to the paper it's the best thing since sliced bread.

which paper?

pfeatherstone commented 5 months ago

https://ofir.io/train_short_test_long.pdf, the one you reference in your readme. I have to admit I haven't read it in great detail, but they suggest ALiBi is great.

pfeatherstone commented 5 months ago

Basically I need a positional embedding that length-extrapolates well, works with memories, and flash attention. Do you have any suggestions?

lucidrains commented 5 months ago

@pfeatherstone that's from the author of alibi. of course they would say it is great

lucidrains commented 5 months ago

> Basically I need a positional embedding that length-extrapolates well, works with memories, and flash attention. Do you have any suggestions?

these days, i would stick with rotary, given the amount of research now going into it

curriculum learn to longer sequence lengths while tuning the rotary theta value (and whatever new tricks recent papers have discovered)
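
As a rough illustration (standard rotary embedding math, not any particular x-transformers kwarg), theta is the base of the frequency schedule; raising it stretches the rotation wavelengths, which is the knob long-context fine-tuning of rotary models typically adjusts.

import torch

def rope_inv_freq(dim, theta = 10000.0):
    # larger theta -> slower-varying positional frequencies -> more room before positions alias at long range
    return 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))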

pfeatherstone commented 5 months ago

What do you mean by "curriculum learn to longer sequence lengths"? Sorry if my questions are dumb.

lucidrains commented 5 months ago

@pfeatherstone ah, curriculum learning is just a fancy way of saying making training increasingly harder over time, like how you design a curriculum for a student. so start with a small sequence length and slowly increase to your desired length
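
In code, the schedule could look something like this (a generic sketch; seq_len_schedule, sample_batch and train_step are hypothetical placeholders, not part of x-transformers):

# hypothetical curriculum: grow the training sequence length in stages
seq_len_schedule = [(0, 512), (10_000, 1024), (20_000, 2048), (30_000, 4096)]   # (start step, seq len)

def seq_len_for_step(step):
    length = seq_len_schedule[0][1]
    for start_step, seq_len in seq_len_schedule:
        if step >= start_step:
            length = seq_len
    return length

# inside the training loop, each batch is cropped or sampled to the current length:
# for step in range(total_steps):
#     seq_len = seq_len_for_step(step)
#     batch   = sample_batch(seq_len)       # placeholder data loader
#     loss    = train_step(model, batch)    # placeholder optimisation step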

lucidrains commented 3 months ago

@pfeatherstone do you want to see if 1.28.0 fixes the issue?

MarcusLoppe commented 3 months ago

> @pfeatherstone do you want to see if 1.28.0 fixes the issue?

I think that update broke something; I got this error while training the transformer in MeshGPT.

Was it the update, or did the attention args start to kick in? attn_kwargs: dict = dict( ff_glu = True, num_mem_kv = 4 ), https://github.com/lucidrains/meshgpt-pytorch/blob/main/meshgpt_pytorch/meshgpt_pytorch.py#L1011C7-L1014C11

File /opt/conda/lib/python3.10/site-packages/x_transformers/x_transformers.py:950, in Attention.forward(self, x, context, mask, context_mask, attn_mask, rel_pos, rotary_pos_emb, prev_attn, mem, mem_mask, return_intermediates, cache)
    947 # append with no bias for memory key / values
    949 if has_mem_kv:
--> 950     attn_bias = pad_at_dim(attn_bias, (self.num_mem_kv, 0), value = 0.)
    952 # attention is all we need
    954 out, intermediates = self.attend(
    955     q, k, v,
    956     mask = final_attn_mask,
    957     attn_bias = attn_bias,
    958     prev_attn = prev_attn
    959 )

File /opt/conda/lib/python3.10/site-packages/x_transformers/x_transformers.py:97, in pad_at_dim(t, pad, dim, value)
     95 dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
     96 zeros = ((0, 0) * dims_from_right)
---> 97 return F.pad(t, (*zeros, *pad), value = value)

TypeError: pad(): argument 'input' (position 1) must be Tensor, not NoneType

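For reference, the failing call pads attn_bias to make room for the memory key / values, but attn_bias is None when no relative positional bias is configured, as in MeshGPT's decoder. The fix presumably amounts to a guard along these lines (a sketch only, reusing the library's existing exists() and pad_at_dim helpers):

if has_mem_kv and exists(attn_bias):
    # only pad the bias when one actually exists; the memory key / values get no bias otherwise
    attn_bias = pad_at_dim(attn_bias, (self.num_mem_kv, 0), value = 0.)
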
lucidrains commented 3 months ago

@MarcusLoppe oh hey Marcus! fancy seeing you here 😄

should be fixed!

MarcusLoppe commented 3 months ago

> @MarcusLoppe oh hey Marcus! fancy seeing you here 😄
>
> should be fixed!

I get around 😄 I'm trying to find some attention that can deal with long context lengths effectively; maybe I'll try that ring attention that you have 😄

Awesome, I'll confirm later on if it's fixed.

Edit: All good!