abertsch72 / unlimiformer

Public repo for the NeurIPS 2023 paper "Unlimiformer: Long-Range Transformers with Unlimited Length Input"
MIT License

Why are the keys and values of the decoder's cross-attention computed differently in the training and validation stages? #44

Closed jjkk123456 closed 11 months ago

jjkk123456 commented 11 months ago

For example, in the training stage you use the SLED context chunking method: the input passes only through the encoder, and you take the encoder's last-layer hidden states. You then use those hidden states to compute the long keys and long values for the cross-attention of each decoder layer. In the validation stage, however, you feed the input through the entire model and directly merge the cross-attention keys and values of each decoder layer into the long keys and long values. Why are the long keys and long values computed in different ways in the two stages?

abertsch72 commented 11 months ago

Hi, thanks for your interest in our work!

We use the same encoding method in both training and validation: we encode the input in overlapping chunks so that each embedding has sufficient context. We keep only the embeddings from the middle of each chunk, and we take the encoder's last hidden state for each embedding.
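
For readers following along, here is a minimal sketch of what overlapping chunking with "keep the middle" could look like. It is an illustration only, not the repository's code: the function name `chunk_windows`, its parameters, and the choice to split the overlap evenly between neighbouring chunks are all assumptions.

```python
def chunk_windows(n_tokens, chunk_len, overlap):
    """Return (win_start, win_end, keep_start, keep_end) tuples.

    Hypothetical helper: each window is encoded in full, but only the
    tokens in [keep_start, keep_end) are kept, so every kept token has
    context on both sides (except at the very start and end of the input).
    """
    if not 0 <= overlap < chunk_len:
        raise ValueError("overlap must satisfy 0 <= overlap < chunk_len")
    stride = chunk_len - overlap   # distance between window starts
    left = overlap // 2            # tokens dropped at the left edge
    right = overlap - left         # tokens dropped at the right edge
    windows = []
    start = 0
    while True:
        end = min(start + chunk_len, n_tokens)
        is_last = end == n_tokens
        # The first window keeps its left edge, the last keeps its right edge.
        keep_start = start if start == 0 else start + left
        keep_end = n_tokens if is_last else start + chunk_len - right
        windows.append((start, end, keep_start, keep_end))
        if is_last:
            break
        start += stride
    return windows


# Example: 10 tokens, windows of 6 tokens, 2 tokens of overlap.
for window in chunk_windows(10, 6, 2):
    print(window)
```

With these assumed settings the kept ranges tile the input exactly once, so each token contributes a single encoder hidden state even though it may be encoded in two windows.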

In the code, you might see a slight difference between training and validation: at training time, we store these hidden states in a single matrix, so that we can keep the computational graph and back-propagate correctly. At validation time, we don't need to preserve a computational graph, so we can use a datastore instead, which is a bit more memory-efficient.
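
The difference between the two storage modes can be sketched in PyTorch as follows. This is a toy illustration under stated assumptions, not the actual Unlimiformer implementation: a `torch.nn.Linear` layer stands in for the encoder, random tensors stand in for the kept chunk embeddings, and a plain Python list stands in for the datastore.

```python
import torch

torch.manual_seed(0)
encoder = torch.nn.Linear(4, 4)                 # hypothetical stand-in for the encoder
chunks = [torch.randn(3, 4) for _ in range(2)]  # two chunks of 3 "tokens" each

# Training-style: concatenate the hidden states into one matrix.
# The result stays attached to the autograd graph, so a loss computed
# from it can back-propagate into the encoder's parameters.
train_states = torch.cat([encoder(c) for c in chunks], dim=0)
train_states.sum().backward()
print(train_states.requires_grad, encoder.weight.grad is not None)

# Validation-style: no graph is needed, so the states can be computed
# under no_grad and moved into a separate store (here, a list on CPU).
with torch.no_grad():
    datastore = [encoder(c).cpu() for c in chunks]
print(any(s.requires_grad for s in datastore))
```

In both modes the hidden states come from the same encoder call; only where they are stored, and whether gradients are tracked, differs.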

Does this answer your question?

jjkk123456 commented 11 months ago

That's a good trick. Thanks for your reply.