lucidrains / x-transformers

A concise but complete full-attention transformer with a set of promising experimental features from various papers
MIT License

Feature request: support return_mems in ContinuousTransformerWrapper #166

Open · pfeatherstone opened this issue 1 year ago

pfeatherstone commented 1 year ago

It would be great if ContinuousTransformerWrapper supported return_mems in the forward pass. Thank you for the awesome repo! Remarkably, it all works with torch.onnx.export()!
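
Concretely, the usage I have in mind would look something like this — a sketch only, assuming `ContinuousTransformerWrapper` mirrored the `mems=` / `return_mems=True` round-trip that `TransformerWrapper` already exposes (the constructor follows the existing continuous example; the memory arguments are the hypothetical part):

```python
import torch
from x_transformers import ContinuousTransformerWrapper, Decoder

model = ContinuousTransformerWrapper(
    dim_in = 32,
    dim_out = 32,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

x = torch.randn(1, 256, 32)          # inputs are already continuous features, no token ids

# hypothetical: the memory round-trip this issue is asking for
out, mems = model(x, return_mems = True)
out_next, mems = model(torch.randn(1, 256, 32), mems = mems, return_mems = True)
```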

lucidrains commented 1 year ago

@pfeatherstone oh sure! threw it in there quickly before starting my main work

how are you using it? 🧐

lucidrains commented 1 year ago

it is actually interesting how many people have told me they are using the continuous wrapper, although there's so little research on that. it works well?

hugofloresgarcia commented 1 year ago

We use a continuous transformer in our new paper (https://arxiv.org/pdf/2307.04686.pdf) for music generation and find that it works well! We use the continuous representation of the VQ-VAE latents as the input embeddings for the transformer. Awesome work w/ this repo btw @lucidrains!

lucidrains commented 1 year ago

@hugofloresgarcia congrats on the paper!

pfeatherstone commented 1 year ago

> @pfeatherstone oh sure! threw it in there quickly before starting my main work
>
> how are you using it? 🧐

Oh, it's just that my inputs are already in normalized floating-point format, not tokenized. I think Wav2Vec2 is basically like that, no?

pfeatherstone commented 1 year ago

@lucidrains How do you train a non-autoregressive continuous transformer with mems and return_mems? I can see in the code you have XLAutoregressiveWrapper, but that's only for autoregressive transformers, i.e. where the targets are simply the inputs left-shifted. I can also see NonAutoregressiveWrapper, but I can't quite tell if that's appropriate for training recurrent transformers. Thank you in advance.

lucidrains commented 1 year ago

@pfeatherstone ohh, this repository is not well suited for custom recurrent architectures

are you trying to do something like RMT, but non-autoregressive? does your idea resemble memformer?

pfeatherstone commented 1 year ago

Yes, it's similar. To be honest, I thought this repo would have done the job. Maybe I need to read up on this more to properly determine which architecture suits me best. To me, the mechanism provided in this repo (return_mems and mems=), RMT, and Memformer all look like they are doing roughly the same thing...

Basically, I want to take the mem outputs from running segment (t) and feed them, along with segment (t+1), into the next iteration, exactly how return_mems works here. The difficulty is how to train: do you need to train with segments or not? My architecture uses CTC loss.
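
As a sketch of what I mean (assuming the wrapper returned `(logits, mems)` as above): split each utterance into segments, carry the mems across segments — detached here, i.e. truncated backprop through time, which is exactly the part I'm unsure about — and compute a single CTC loss over the concatenated outputs. All names (`model`, `segment_len`, ...) are illustrative:

```python
import torch
import torch.nn.functional as F

def train_step(model, optimizer, features, feat_lens, targets, target_lens, segment_len):
    # features: (batch, time, dim) continuous inputs; targets: CTC label sequences
    mems = None
    optimizer.zero_grad()

    logits_per_segment = []
    for seg in features.split(segment_len, dim = 1):        # chunk the utterance into segments
        logits, mems = model(seg, mems = mems, return_mems = True)
        mems = [m.detach() for m in mems]                    # truncated BPTT across segments
        logits_per_segment.append(logits)

    logits = torch.cat(logits_per_segment, dim = 1)          # (batch, time, num_classes)
    log_probs = F.log_softmax(logits, dim = -1).transpose(0, 1)   # CTC wants (time, batch, classes)

    loss = F.ctc_loss(log_probs, targets, feat_lens, target_lens, blank = 0)
    loss.backward()
    optimizer.step()
    return loss.item()
```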

pfeatherstone commented 1 year ago

Basically I want a kind of stream-aware transformer with causal attention, non-autoregressive, trained with CTC loss, with an effective response that is infinite, a bit like how infinite impulse response (IIR) filters work. In the transformer world, if you constantly feed mems from previous iterations, it should be able to "remember" information from the infinite past.
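
At inference time I picture something like this streaming loop (again assuming the hypothetical `mems=` / `return_mems=True` interface), where the carried memory plays the role of the IIR filter's state:

```python
import torch

@torch.no_grad()
def stream(model, chunks):
    """Feed feature chunks as they arrive, carrying memories forward indefinitely."""
    mems = None
    for chunk in chunks:                       # each chunk: (1, chunk_len, dim)
        logits, mems = model(chunk, mems = mems, return_mems = True)
        yield logits.argmax(dim = -1)          # e.g. greedy CTC decoding per chunk
```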

lucidrains commented 1 year ago

@pfeatherstone yea, i'm a big fan of the RMT architecture too

lucidrains commented 1 year ago

let me think, yea, i think x-transformers is close, since it has the ability to prepend embeddings (like PaLI)

i can take a look at this later this week
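
roughly, something like this sketch: prepend learned memory tokens to each segment and read them back off the output as the next segment's memory. keyword names like `prepend_embeds` and the exact layout are just for illustration here (real RMT also appends write tokens at the end for causal attention):

```python
import torch
from torch import nn

class RMTStyleMemory(nn.Module):
    """Prepend learned memory tokens to each segment; the transformer's outputs at
    those positions become the next segment's memory. Assumes `wrapper` returns
    per-position embeddings of size `dim` (e.g. a continuous wrapper with dim_out = dim),
    and that its output includes the prepended positions."""

    def __init__(self, wrapper, dim, num_mem_tokens = 16):
        super().__init__()
        self.wrapper = wrapper
        self.num_mem = num_mem_tokens
        self.init_mem = nn.Parameter(torch.randn(num_mem_tokens, dim))

    def forward(self, x, mem = None):
        b = x.shape[0]
        if mem is None:
            mem = self.init_mem.unsqueeze(0).expand(b, -1, -1)

        out = self.wrapper(x, prepend_embeds = mem)          # keyword name assumed
        mem_out, out = out[:, :self.num_mem], out[:, self.num_mem:]
        return out, mem_out                                  # segment outputs, next memory
```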

pfeatherstone commented 1 year ago

So I can see three candidates:

I'm new to recurrent transformers. I'm pretty familiar with "normal" transformers (GPT-like, for example), where you basically feed in your entire input (text, image, or whatever). But though recurrence seems easy to design in the forward pass, I can't quite see how to train effectively (the backward pass). Do you need to randomly partition your input into segments during training and pretend you are "streaming", or is there a more elegant, less faffy, way of training?

lucidrains commented 1 year ago

@pfeatherstone i think your best bet is to modify the RMT architecture

i also included the memory-replay-backprop technique from memformer, so the network can learn to formulate its memories better with little cost to hardware memory
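
roughly, memory-replay-backprop does the following — a conceptual sketch against a generic `model(segment, mems) -> (loss, new_mems)` interface with a single memory tensor, not the actual implementation: forward through all segments without building graphs while caching each segment's input memory, then replay segments in reverse, chaining the memory gradient backwards, so memories can be learned without holding every segment's activations at once

```python
import torch

def memory_replay_backprop(model, segments, init_mems):
    # 1) cheap forward pass: no graphs, just cache the memory each segment received
    cached_mems = []
    mems = init_mems
    with torch.no_grad():
        for seg in segments:
            cached_mems.append(mems)
            _, mems = model(seg, mems)

    # 2) replay in reverse, re-running each segment with grad enabled and adding
    #    the gradient flowing into its output memory from the segment after it
    mem_grad = None
    for seg, mems_in in zip(reversed(segments), reversed(cached_mems)):
        mems_in = mems_in.detach().requires_grad_()
        loss, mems_out = model(seg, mems_in)

        outputs, grads = [loss], [torch.ones_like(loss)]
        if mem_grad is not None:
            outputs.append(mems_out)
            grads.append(mem_grad)

        torch.autograd.backward(outputs, grads)   # accumulates parameter gradients
        mem_grad = mems_in.grad                   # hand the memory gradient to the previous segment

    # (the gradient w.r.t. the very first memory is dropped in this simplified sketch)
```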

pfeatherstone commented 1 year ago

@lucidrains There is also this paper https://arxiv.org/pdf/2109.00301.pdf, which proposes the ∞-former. It's a bit daunting how many ways there are to cache and reuse hidden states over time. There must be a canonical way of doing this stuff...

pfeatherstone commented 1 year ago

We just need IIR filters in neural networks...

pfeatherstone commented 1 year ago

@lucidrains have you looked at the RWKV architecture? It looks like it's solving something similar. Surely all these RNN+Transformer architectures are going to converge.