lucidrains / recurrent-memory-transformer-pytorch

Implementation of Recurrent Memory Transformer, NeurIPS 2022 paper, in PyTorch

Feature request: make JIT and ONNX export work #16

Open pfeatherstone opened 1 year ago

pfeatherstone commented 1 year ago
import torch
from recurrent_memory_transformer_pytorch import RecurrentMemoryTransformer

# assumed helper (not part of the library): True marks valid positions,
# False marks padding, matching the boolean mask the model expects
def lengths_to_padding_mask(seq_len, lengths):
    return torch.arange(seq_len)[None, :] < lengths[:, None]

net = RecurrentMemoryTransformer(
    seq_len=1024,
    num_tokens=256,
    num_memory_tokens=128,
    dim=512,
    depth=1,
    causal=True,
    heads=4,
    dim_head=128,
    use_flash_attn=True,
    rotary_pos_emb=True
).eval()

x = torch.randint(0, 256, (8, 1024))

# trace without memories or mask
jit = torch.jit.trace(net, (x,))

x = torch.randint(0, 256, (8, 1024))
l = torch.randint(100, x.shape[1], size=(x.shape[0],))
m = lengths_to_padding_mask(x.shape[1], l)

# eager and traced outputs should match, both with and without memories
l1, mems, _ = net(x, mask=m)
l2, mems, _ = net(x, mems, mask=m)
l3, mems, _ = jit(x, mask=m)
l4, mems, _ = jit(x, mems, mask=m)

torch.testing.assert_close(l1, l3)
torch.testing.assert_close(l2, l4)

It would be great if the above worked.
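
For the ONNX half of the request, the export would presumably be the standard torch.onnx.export call sketched below; the file name, input name, and opset version are arbitrary placeholders, and per this issue the export does not work yet:

# hypothetical export call for the same net/x as above; "rmt.onnx",
# "tokens", and opset 17 are placeholders, not library-defined names
torch.onnx.export(
    net,
    (x,),
    "rmt.onnx",
    input_names=["tokens"],
    dynamic_axes={"tokens": {0: "batch"}},
    opset_version=17
)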

lucidrains commented 1 year ago

@pfeatherstone ahh, yea, i can look into that

care to share what you are seeing on your dataset with this approach?

pfeatherstone commented 1 year ago

I haven't been able to train my models yet even with normal transformers using larger context lengths (my weird TTS + STT system); the CTC loss isn't converging at all. So I haven't attempted a proper run with the RMT architecture in the STT model, but I'm setting it up with RMT while debugging the other one. I will let you know if I find success. I'm worried that training two transformers in tandem simply doesn't work, whether because of stupidly slow convergence, too low a batch size, or something else... I don't know. I've been looking at shifted tokens, scale_norm, and other tricks to help with convergence, but I'm not having any luck. I'm tempted to try RWKV, as they claim really fast convergence. Either way, I'm going to need something like RMT in the end so I can have a well-defined streaming architecture on the STT side.

lucidrains commented 1 year ago

oh got it, makes sense

pfeatherstone commented 1 year ago

Gave this a go. It turns out that torch.jit.trace() doesn't accept None in example_inputs, so we cannot trace with mems not None and expect it to work when mems is None, or vice versa. My workaround is to pass mems = torch.zeros(B, num_memory_tokens, dim) on the first pass, which means you're attending to self.read_memory_emb ONLY on the first pass. I don't know if that's allowed.
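
For reference, a minimal sketch of that workaround, assuming the constructor arguments from the snippet above (B = 8, num_memory_tokens = 128, dim = 512) and that memories are accepted as the second positional argument:

# zero-memories workaround sketch (an assumption, not confirmed library behaviour):
# trace once with dummy zero memories so the graph always takes the memories path
B, num_memory_tokens, dim = 8, 128, 512
zero_mems = torch.zeros(B, num_memory_tokens, dim)

jit = torch.jit.trace(net, (x, zero_mems))

out, mems, _ = jit(x, zero_mems)  # first pass: zeros stand in for "no memories"
out, mems, _ = jit(x, mems)       # subsequent passes: feed back the real memories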