lucidrains / x-transformers

A simple but complete full-attention transformer with a set of promising experimental features from various papers

XL-recurrence with RotaryEmbedding and mems not working correctly. #223

Closed pfeatherstone closed 6 months ago

pfeatherstone commented 6 months ago

Note: this follows on from https://github.com/lucidrains/x-transformers/issues/216

I am trying to do XL-recurrence with the configuration shown below. As a test, I check that the outputs are identical when passing mems=None versus mems=torch.zeros(...). They are not. Repro code:

import torch
from x_transformers import ContinuousTransformerWrapper, Encoder

lm = ContinuousTransformerWrapper(
    dim_in              = 2,
    dim_out             = 36,
    max_seq_len         = 0,
    max_mem_len         = 100,
    attn_layers = Encoder(
        dim             = 512,
        depth           = 4,
        heads           = 4,
        rotary_pos_emb  = True,
        attn_flash      = True,
        attn_num_mem_kv = 20
    )
)

B, M, D, depth = 1, lm.max_mem_len, lm.attn_layers.dim, lm.attn_layers.depth

x       = torch.randn(B, 1024, 2)
length  = torch.randint(100, x.shape[1], size=(x.shape[0],))            # random valid length per batch item
mask    = torch.arange(x.shape[1]).unsqueeze(0) < length.unsqueeze(-1)  # padding mask built from the lengths
mems    = [torch.zeros(x.shape[0], M, D) for _ in range(depth)]         # all-zero memories, one per layer

out1, mems1 = lm(x, mask=mask, return_mems=True)
out2, mems2 = lm(x, mask=mask, mems=mems, return_mems=True)
torch.testing.assert_close(out1, out2)
for m1, m2 in zip(mems1, mems2):
    torch.testing.assert_close(m1, m2)

I also tried changing https://github.com/lucidrains/x-transformers/blob/583c19dc0eb80182b0fa8ed2bfd3b22bcecbc374/x_transformers/x_transformers.py#L882-L884

to

if exists(input_mask) and exists(mem):
    attend = torch.any(mem)
    input_mask = pad_at_dim(input_mask, (mem.shape[-2], 0), dim = -1, value = attend)

but that doesn't help. Any ideas?

pfeatherstone commented 6 months ago

I also tried changing:

https://github.com/lucidrains/x-transformers/blob/583c19dc0eb80182b0fa8ed2bfd3b22bcecbc374/x_transformers/x_transformers.py#L465

to

freqs = freqs[:seq_len, :]

That made more sense to me. I think this makes the results match a bit better but not perfectly.

pfeatherstone commented 6 months ago

If I set:

use_abs_pos_emb=True,
rotary_pos_emb=False

and keep the suggested change, i.e.

https://github.com/lucidrains/x-transformers/blob/583c19dc0eb80182b0fa8ed2bfd3b22bcecbc374/x_transformers/x_transformers.py#L882-L884

changed to

if exists(input_mask) and exists(mem):
    attend = torch.any(mem)
    input_mask = pad_at_dim(input_mask, (mem.shape[-2], 0), dim = -1, value = attend)

then it works. My understanding was that RotaryEmbedding should work in this case. Maybe not. @lucidrains, can you confirm?

lucidrains commented 6 months ago

@pfeatherstone hey, does the equality work if you turn off rotary embeddings?

pfeatherstone commented 6 months ago

If I use

use_abs_pos_emb=True,
rotary_pos_emb=False

with the suggested change it works.

If I use:

rotary_pos_emb=False

it attempts to use AbsolutePositionalEmbedding, which I don't really want.

lucidrains commented 6 months ago

Nice, yeah, I think I may know what's up. Will look into it when I find a stretch of free time.

pfeatherstone commented 6 months ago

Can you give me a hint? I can try to figure out the details.

lucidrains commented 6 months ago

@pfeatherstone I think the memories should be kept at negative positions. Say you have 2 memory tokens and 5 main tokens: the positions should be [-2, -1, 0, 1, 2, 3, 4] instead of [0..7). Could be wrong, I need to reread my code.
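
Concretely, the indexing would be something like this (a minimal sketch; the torch.arange(-M, T) change further down does exactly this):

import torch

num_mem, seq_len = 2, 5
positions = torch.arange(-num_mem, seq_len)
print(positions)  # tensor([-2, -1,  0,  1,  2,  3,  4])
# these positions get handed to the rotary embedding, rather than torch.arange(num_mem + seq_len)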

pfeatherstone commented 6 months ago

I will give it a go

lucidrains commented 6 months ago

@pfeatherstone what is the magnitude of the error?

pfeatherstone commented 6 months ago

The absolute error is around 0.008 on average

lucidrains commented 6 months ago

@pfeatherstone ok, it is likely what I said then, if you meant 'max' instead of 'average'

pfeatherstone commented 6 months ago

> @pfeatherstone ok, it is likely what I said then, if you meant 'max' instead of 'average'

Sorry, it's actually larger: more like a 0.4 max absolute difference.

pfeatherstone commented 6 months ago

I'll still try what you suggested.

pfeatherstone commented 6 months ago

@lucidrains Yes it worked!

So the total changes are:

if exists(input_mask) and exists(mem):
    attend = torch.any(mem)  # attend to the memories only if they contain any non-zero values
    input_mask = pad_at_dim(input_mask, (mem.shape[-2], 0), dim = -1, value = attend)  # extend the mask over the prepended memory positions

at line 882 of x_transformers.py

and

if not exists(rotary_pos_emb) and exists(self.rotary_pos_emb):
    M = max(list(map(lambda m: m.shape[1] if exists(m) else 0, mems)))  # longest memory length across layers
    T = x.shape[1]                                                      # current sequence length
    t = torch.arange(-M, T)                                             # memories at negative positions, main tokens at 0..T-1
    rotary_pos_emb = self.rotary_pos_emb.forward(t)

at line 1257 of x_transformers.py

lucidrains commented 6 months ago

@pfeatherstone šŸ‘ šŸ’Æ you mvp

want to try submitting a PR?

pfeatherstone commented 6 months ago

> @pfeatherstone šŸ‘ šŸ’Æ you mvp
>
> want to try submitting a PR?

I can do. Without unit tests, PRs are easy ;) The only thing is that some of the changes aren't ONNX-export friendly...

pfeatherstone commented 6 months ago

Also, the line:

attend = torch.any(mem)

doesn't work per batch item, since torch.any reduces over the whole tensor: if any single batch item has non-zero mems, the memories end up being attended to for every item in the batch. So you would need to pad the mask differently for each batch item. I'm looking into a fix.
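
A per-item version might look something like this (just a sketch against plain torch, not the library's actual fix):

import torch

def extend_mask_per_item(input_mask: torch.Tensor, mem: torch.Tensor) -> torch.Tensor:
    # input_mask: (batch, seq_len) bool, mem: (batch, mem_len, dim)
    attend = (mem != 0).flatten(1).any(dim = -1)           # (batch,) True where that item's memories are non-zero
    mem_mask = attend[:, None].expand(-1, mem.shape[-2])   # (batch, mem_len) per-item memory mask
    return torch.cat((mem_mask, input_mask), dim = -1)     # (batch, mem_len + seq_len)

Functionally this is close to what an explicit mem_mask input would provide, which is the direction suggested further down.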

lucidrains commented 6 months ago

ok, at the very least you got it working for your case

this isn't really that big of a deal

lucidrains commented 6 months ago

I'll make the correction for rotary when I find some time.

Thanks for taking the initiative and working it out.

pfeatherstone commented 6 months ago

https://github.com/lucidrains/x-transformers/pull/224

pfeatherstone commented 6 months ago

So I've fixed two things: zero mems now behave the same as not attending to mems at all, and the rotary embeddings are now correct. The second issue I've come across is that mems are recorded before the pre-norm layer normalization, yet on the next iteration they are prepended after it. I tested it and was getting gibberish. I've fixed this by recording the new mems at exactly the point where the old mems are prepended; now I get sensible results. FYI, I'm using sandwich norm, which uses pre-LN.
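
To make the ordering concrete, here is a toy sketch of one pre-LN block (made-up names, with a stock nn.MultiheadAttention standing in for the library's attention, so not the actual x-transformers code): the memory has to be recorded from the same tensor it will later be concatenated with.

import torch
import torch.nn as nn

dim, heads, max_mem_len = 512, 4, 100
norm = nn.LayerNorm(dim)
attn = nn.MultiheadAttention(dim, num_heads = heads, batch_first = True)

def block_step(x, mem):
    h = norm(x)                                              # pre-LN
    kv = torch.cat((mem, h), dim = 1) if mem is not None else h
    out, _ = attn(h, kv, kv)                                 # queries from h, keys/values include the mems
    out = out + x                                            # residual

    # buggy ordering: new_mem = x[:, -max_mem_len:]   (recorded *before* the norm)
    # fix described above: record mems at the same point where old mems are prepended
    new_mem = h[:, -max_mem_len:].detach()
    return out, new_mem

out1, mem = block_step(torch.randn(1, 256, dim), None)
out2, _   = block_step(torch.randn(1, 256, dim), mem)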

lucidrains commented 6 months ago

@pfeatherstone ah interesting, you find sandwich norms helpful? I tried it for a contracting project and saw worse results

lucidrains commented 6 months ago

@pfeatherstone yeah, I saw your PR, but I think it may need to be broken up. I think the zero mems case is better dealt with via a mem_mask input

pfeatherstone commented 6 months ago

> @pfeatherstone ah interesting, you find sandwich norms helpful? I tried it for a contracting project and saw worse results

To be honest, I don't know anymore. The first time I tried it I thought it helped, but it could have been a coincidence.

pfeatherstone commented 6 months ago

> @pfeatherstone yeah, I saw your PR, but I think it may need to be broken up. I think the zero mems case is better dealt with via a mem_mask input

OK, cool. Though the code does create an appropriate mask: it assumes that all-zero mems shouldn't be attended to. I think that's a sensible default. Would someone ever want to explicitly attend to zeros?

lucidrains commented 6 months ago

@pfeatherstone I don't think there would be any issue, just that a mem_mask would give more flexibility, and it would solve your problem of needing an initial zero mems, which I assume is ONNX-related

pfeatherstone commented 6 months ago

> @pfeatherstone ah interesting, you find sandwich norms helpful? I tried it for a contracting project and saw worse results
>
> To be honest, I don't know anymore. The first time I tried it I thought it helped, but it could have been a coincidence.
>
> just rerun twice with a change of a boolean and you'll have your answer

Yeah, it takes a couple of days for my models to train. There is a lot of augmentation and therefore randomness: every time I run an experiment without changing any parameters, convergence happens at different times and of course I get wildly different results. So when I'm looking at convergence, it's hard to know whether an improvement was sheer luck or a genuine model enhancement. Stability, on the other hand, is pretty tied to the architecture, and in my case stability is the same with or without sandwich norm.

lucidrains commented 6 months ago

@pfeatherstone sg

Could you try the latest version? The below runs fine for me now:

import torch
from x_transformers import ContinuousTransformerWrapper, Encoder
from x_transformers import ContinuousAutoregressiveWrapper

lm = ContinuousTransformerWrapper(
    dim_in              = 2,
    dim_out             = 36,
    max_seq_len         = 0,
    max_mem_len         = 100,
    attn_layers = Encoder(
        dim             = 512,
        depth           = 4,
        heads           = 4,
        rotary_pos_emb  = True,
        attn_flash      = True,
        attn_num_mem_kv = 20
    )
)

B, M, D, depth = 1, lm.max_mem_len, lm.attn_layers.dim, lm.attn_layers.depth

x       = torch.randn(B, 1024, 2)
length  = torch.randint(100, x.shape[1], size=(x.shape[0],))
mask    = torch.arange(x.shape[1]).unsqueeze(0) < length.unsqueeze(-1)
mems    = [torch.zeros(x.shape[0], M, D) for _ in range(depth)]
mem_masks = [torch.zeros(x.shape[0], M, dtype = torch.bool) for _ in range(depth)] # memory mask

out1, mems1 = lm(x, mask=mask, return_mems=True)
out2, mems2 = lm(x, mask=mask, mems=mems, mem_masks = mem_masks, return_mems=True)
torch.testing.assert_close(out1, out2)

for m1, m2 in zip(mems1, mems2):
    torch.testing.assert_close(m1, m2)

lucidrains commented 6 months ago

@pfeatherstone let me know what you see when you rerun the sandwich norm experiments. I'm thinking about removing it

lucidrains commented 6 months ago

OK, I'm going to close this issue; I think it is good now

pfeatherstone commented 6 months ago

> @pfeatherstone noticed you are using an Encoder instead of a Decoder in your example code. you have a working model based on this idea?

I'm actually using a Decoder; I used an Encoder in the repro to keep things simpler.

lucidrains commented 6 months ago

@pfeatherstone ahh got it, you are using it correctly then, just checking

pfeatherstone commented 6 months ago

> @pfeatherstone ahh got it, you are using it correctly then, just checking

Out of interest, why would it not be OK to use this with an Encoder? The only difference between Encoder and Decoder is whether the attention mask is causal (triangular) or not. I use a Decoder mainly because I don't want to attend to "future" tokens, which is desirable in a streaming architecture.
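
For reference, and paraphrasing the library source rather than quoting it verbatim, both classes are thin wrappers over AttentionLayers that just fix the causal flag:

from x_transformers.x_transformers import AttentionLayers

class Encoder(AttentionLayers):
    def __init__(self, **kwargs):
        super().__init__(causal = False, **kwargs)   # full bidirectional attention

class Decoder(AttentionLayers):
    def __init__(self, **kwargs):
        super().__init__(causal = True, **kwargs)    # triangular / causal mask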

lucidrains commented 6 months ago

@pfeatherstone depends on how you are sampling it