abertsch72 / unlimiformer

Public repo for the NeurIPS 2023 paper "Unlimiformer: Long-Range Transformers with Unlimited Length Input"
MIT License

Encoder Only Unlimiformer #21

Closed YHL04 closed 1 year ago

YHL04 commented 1 year ago

I'm currently trying to modify the idea so that it can work with encoder-only models. How would Faiss work to make it efficient? Does the index have to be retrained at each timestep? Thank you.

urialon commented 1 year ago

Hi @YHL04 , Thank you for your interest in our work!

In principle, the idea of offloading the attention to a kNN-index can be performed on the self-attention of the encoder.
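For concreteness, here is a minimal sketch of what that could look like with a Faiss index over the encoder's key states, assuming a single head, placeholder data, and placeholder shapes; none of the names below come from this repo's code:

```python
import faiss
import numpy as np
import torch

d = 64   # per-head key dimension (assumed)
k = 16   # number of retrieved keys per query (assumed)

# Exact inner-product index over the encoder's key states; no training step required.
index = faiss.IndexFlatIP(d)

keys = np.random.randn(10_000, d).astype("float32")   # all encoder key states (placeholder data)
values = torch.randn(10_000, d)                        # matching value states (placeholder data)
index.add(keys)

queries = np.random.randn(512, d).astype("float32")    # query states for one chunk (placeholder data)
scores, ids = index.search(queries, k)                  # top-k dot products and key ids, shape (512, k)

# Attend only over the k retrieved keys instead of all 10,000 positions.
attn = torch.softmax(torch.from_numpy(scores), dim=-1)            # (512, k)
retrieved_values = values[torch.from_numpy(ids).long()]           # (512, k, d)
out = torch.einsum("qk,qkd->qd", attn, retrieved_values)          # (512, d)
```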

But I don't understand your question about "retrained at each timestep", or the question in your email about "Can it work like a recurrent neural network where at each timestep, the keys and values are added to the index?" -- how are encoder-only models like RNNs?

Best, Uri

YHL04 commented 1 year ago

I'm currently implementing a modified version where an encoder-only BERT reads a book one sequence length at a time (e.g. 512 tokens) and predicts the [MASK] tokens at each timestep (like a Transformer-XL with ideas from Memorizing Transformers and Unlimiformer), so the hidden states have to be added to the index after every timestep. I'm not very familiar with Faiss: since in this case the hidden states are added on the fly, rather than all at once as in the encoder-decoder setting, does that mean the index has to be retrained every timestep?

YHL04 commented 1 year ago

What I have so far: https://github.com/StockAIHub/transformergallery/blob/main/transformer/attention/knnattention.py
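For reference, a rough sketch of that kind of on-the-fly setup (this is not the linked implementation; the chunk size, hidden size, and function name are placeholders). A flat Faiss index has no training step, so key states can simply be appended after each chunk:

```python
import faiss
import numpy as np

d = 768                       # encoder hidden size (assumed)
index = faiss.IndexFlatIP(d)  # exact search; add() works incrementally with no (re)training

def process_chunk(hidden_states: np.ndarray, k: int = 32):
    """hidden_states: (512, d) float32 key states for the current 512-token chunk."""
    results = None
    # Retrieve top-k memories from all previously seen chunks, if any exist.
    if index.ntotal > 0:
        results = index.search(hidden_states, min(k, index.ntotal))
        # ...use the returned (scores, ids) to gather memory keys/values for attention...
    # Append this chunk's key states so that later chunks can retrieve them.
    index.add(hidden_states)
    return results
```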

urialon commented 1 year ago

I'm not sure, actually. From my experience with Faiss, you can train an initial index on an initial part of the data, and then add keys without training. However, this requires the initial part of the data to be sufficiently "representative".

Maybe you could aggregate, say, 512x4 keys, use "exact kNN search" until that point, then train an index, and start adding keys afterwards?
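If it helps, a hedged sketch of that suggestion: buffer roughly 512x4 keys and answer queries with exact search, then train an IVF index on the buffer once and only call add() afterwards (nlist, nprobe, and the buffer size here are illustrative guesses, not tuned values):

```python
import faiss
import numpy as np

d = 768                 # key dimension (assumed)
buffer_size = 512 * 4   # how many keys to collect before training (per the suggestion above)
nlist = 64              # number of IVF cells (illustrative)

buffer = []             # keys searched exactly until the IVF index is trained
ivf = None

def add_keys(keys: np.ndarray):
    """keys: (n, d) float32 key states from the latest chunk."""
    global ivf
    if ivf is None:
        buffer.append(keys)
        if sum(b.shape[0] for b in buffer) >= buffer_size:
            train_data = np.concatenate(buffer, axis=0)
            quantizer = faiss.IndexFlatIP(d)
            ivf = faiss.IndexIVFFlat(quantizer, d, nlist, faiss.METRIC_INNER_PRODUCT)
            ivf.train(train_data)   # one-time training on the (hopefully representative) buffer
            ivf.add(train_data)
            ivf.nprobe = 8          # search several cells to trade speed for recall (illustrative)
    else:
        ivf.add(keys)               # later keys are added with no retraining

def search(queries: np.ndarray, k: int = 32):
    """queries: (m, d) float32; returns (scores, ids) as from Faiss's search()."""
    if ivf is not None:
        return ivf.search(queries, k)
    assert buffer, "add_keys() must be called at least once before searching"
    flat = faiss.IndexFlatIP(d)            # exact kNN over the buffered keys
    flat.add(np.concatenate(buffer, axis=0))
    return flat.search(queries, min(k, flat.ntotal))
```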

YHL04 commented 1 year ago

Thanks, I'm going to look at Memorizing Transformers to see how they do it on the fly; they use approximate search, which I assume helps a lot.