abertsch72 / unlimiformer

Public repo for the NeurIPS 2023 paper "Unlimiformer: Long-Range Transformers with Unlimited Length Input"
MIT License

About the method `attention_forward_hook` #30

Closed seunghyukoh closed 1 year ago

seunghyukoh commented 1 year ago

Hi, I've been reading the code for a few days, and I have a question about how the code works.

As far as I know, `attention_forward_hook` is the only method that looks up the relevant keys in the datastore.

However, the method only manipulates local variables, except for `cur_layer_key_value_placeholder`, which does not seem to actually affect the output of the attention layer. It also does not return anything.

Can you explain how the stored data is used?

urialon commented 1 year ago

Hi @jake-seunghyukoh, thank you for your interest in our work!

This is a great question, and I'd be happy to explain, since we spent a lot of time on the engineering here: we developed this by injecting "hooks" rather than changing the architecture inline (that is, instead of modifying the Llama code itself).

You are right: `attention_forward_hook` is the function that looks up the relevant keys in the datastore. The variable `self.cur_layer_key_value_placeholder` is a pointer to the cross-attention's local `past_key_value` variable.

By placing tensors in `self.cur_layer_key_value_placeholder`, we are actually changing the vectors that will be available to the base model's attention in its local `past_key_value` variable. So when the model's attention function attends to its `past_key_value` vectors, it is in fact attending to the vectors that we placed there in `attention_forward_hook`.
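In other words, the trick relies on Python's reference semantics: the hook and the base model's attention hold the same list object, so in-place changes made by the hook are visible to the attention code. Here is a minimal standalone illustration (a toy, not the repo's code; all names are illustrative):

```python
# Toy illustration of the shared-list ("placeholder") mechanism.

def base_model_attention(past_key_value):
    # The base model reads whatever is in the list at the moment it runs.
    keys, values = past_key_value
    return keys, values

# The model's (keys, values), already converted to a mutable list of length 2.
past_key_value = ["original_keys", "original_values"]
cur_layer_key_value_placeholder = past_key_value  # second name, same object

# The hook mutates the list through its own reference...
cur_layer_key_value_placeholder[0] = "retrieved_keys"
cur_layer_key_value_placeholder[1] = "retrieved_values"

# ...and the attention, holding the same object, sees the injected values.
print(base_model_attention(past_key_value))
# -> ('retrieved_keys', 'retrieved_values')
```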

This is done as follows:

In these lines: https://github.com/abertsch72/unlimiformer/blob/main/src/unlimiformer.py#L561-L562 we capture, in `self.cur_layer_key_value_placeholder`, a pointer to the `past_key_value` variable of the model's attention function. Originally `past_key_value` is a 2-tuple containing two tensors (keys and values), but a tuple is immutable, so in these lines we convert the 2-tuple into a list of length 2. It's important to convert it to a list, because you can't replace a tuple's elements.
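As a quick aside on why the list conversion matters, here is a small demonstration (shapes and names are illustrative, not taken from the repo):

```python
import torch

keys, values = torch.randn(1, 4, 8), torch.randn(1, 4, 8)
past_key_value = (keys, values)          # the original immutable 2-tuple

try:
    past_key_value[0] = torch.randn(1, 16, 8)
except TypeError as e:
    print(e)                             # 'tuple' object does not support item assignment

cur_layer_key_value_placeholder = list(past_key_value)       # mutable, length 2
cur_layer_key_value_placeholder[0] = torch.randn(1, 16, 8)   # replacement now works
```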

Then, in these lines: https://github.com/abertsch72/unlimiformer/blob/main/src/unlimiformer.py#L710-L711 in the `attention_forward_hook` function, we use that same list, but we change its first and second elements to new tensors that we created. This is the "injection": we place tensors there that will later be consumed by the attention function of the base model (e.g., Llama).
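A hedged sketch of what this injection amounts to, assuming the retrieved keys and values have already been reconstructed from the datastore (all names, shapes, and the toy attention below are illustrative, not the repo's actual code):

```python
import torch

batch, heads, head_dim, k = 1, 2, 8, 16
retrieved_keys = torch.randn(batch, heads, k, head_dim)
retrieved_values = torch.randn(batch, heads, k, head_dim)

# The list previously captured from the model's attention layer.
cur_layer_key_value_placeholder = [None, None]

# Inside the hook: replace both elements in place.
cur_layer_key_value_placeholder[0] = retrieved_keys
cur_layer_key_value_placeholder[1] = retrieved_values

# Later, the base model's attention unpacks the very same list as usual
# and attends to the injected tensors without any change to its own code.
keys, values = cur_layer_key_value_placeholder
query = torch.randn(batch, heads, 1, head_dim)
attn = torch.softmax(query @ keys.transpose(-1, -2) / head_dim ** 0.5, dim=-1)
output = attn @ values
print(output.shape)  # torch.Size([1, 2, 1, 8])
```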

The base model (e.g., Llama) doesn't "know" that anything has changed; it attends to the keys and values supplied to it in the `past_key_value` variable, as usual.

I hope that helps! Let me know if you have any more questions; I'd love to explain more.

Best, Uri

seunghyukoh commented 1 year ago

Thank you so much! Amazing engineering 👍