abertsch72 / unlimiformer

Public repo for the NeurIPS 2023 paper "Unlimiformer: Long-Range Transformers with Unlimited Length Input"
MIT License
1.05k stars 77 forks

Use of other Encoder/Decoder Models #55

Open rdmerillat opened 10 months ago

rdmerillat commented 10 months ago

Hello, I've been using Unlimiformer as a comparison against current standard summarization methods, and I was wondering whether anything in particular is needed to convert, say, a Pegasus model to Unlimiformer, since it should work with "all encoder/decoder" models. I see several lines commented out in unlimiformer.py (here) for AutoModelForSeq2Seq; however, I currently don't see a direct way this has been implemented yet.

As Pegasus is BART-based, I set up a new model converter entry, PegasusForConditionalGeneration: UnlimiformerPegasus, and started a new Unlimiformer class for it:

class UnlimiformerPegasus(UnlimiformerBART):
    def __init__(self, model: PegasusModel, *args, **kwargs):
        super().__init__(model, *args, **kwargs)

However, I was wondering if you or anyone else had found additional tweaking that is needed to fully convert, say, a Pegasus model.

And, more generally, what is the procedure you use when setting up your own new Unlimiformer-converted models? I was unable to glean what is necessary to ensure consistent performance and/or results.

Thanks!

urialon commented 10 months ago

Hi @rdmerillat, thank you for your interest in our work!

We haven't tried Pegasus, but the solution you described sounds correct. Make sure to add the new class to the type_to_class dictionary.
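For example, the new entry would sit alongside the existing ones, roughly like this (a minimal sketch; the real mapping is the type_to_class dictionary in unlimiformer.py):

from transformers import PegasusForConditionalGeneration

type_to_class = {
    # ... the existing model-to-wrapper entries (e.g. the BART mapping to UnlimiformerBART) ...
    PegasusForConditionalGeneration: UnlimiformerPegasus,  # the new Pegasus wrapper
}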

Please let us know how it goes!

Best, Uri

patrickocal commented 10 months ago

Hi @urialon, I've been comparing your UnlimiformerBART and UnlimiformerLLaMa classes. Am I right in thinking that the latter projects queries whereas the former projects keys and values? Assuming I'm right, why the difference? The LLaMa approach seems in line with the paper. Many thanks, Patrick

urialon commented 10 months ago

Hi @patrickocal, what exactly do you mean by "projects keys" and "projects queries"?

I'd say that both of them project queries, as described in the paper, with some unrelated differences in LLaMA due to its Rotary Position Embeddings.

Which difference between them are you concerned about?

Best, Uri

rdmerillat commented 10 months ago

Please let us know how it goes!

Will do!

patrickocal commented 10 months ago

Hi @urialon, thanks for the quick reply and thanks for sharing all your good work. By way of background, I am taking Jure Leskovec's course on Machine Learning with Graphs and my team's project is: how to integrate Knowledge Graphs with your approach (any suggestions would be welcome).

Regarding:

Hi @patrickocal, what exactly do you mean by "projects keys" and "projects queries"?

I'd say that both of them project queries, as described in the paper, with some unrelated differences in LLaMA due to its Rotary Position Embeddings.

Which difference between them are you concerned about?

Best, Uri

I guess I was referring to your Python terminology, as per the methods below. Formally, by projected query I mean the first term in the inner product $\langle h_d W_q W_k^T, h_e \rangle$. By projected key I mean the second term in the inner product $\langle h_d, W_q^T W_k h_e \rangle$.

My understanding is that the step where the query is projected in UnlimiformerLLaMa is in the method:

    def preprocess_query(self, query, k_proj_weight):

        # ---------------------------------------------------------------------
        # This method projects the query using the weight matrix k_proj_weight.
        # This projection is well aligned with the paper's description, where
        # the query is projected using a product of weight matrices before the
        # kNN search.
        # ---------------------------------------------------------------------
        # query: (batch * time, head, dim)
        attention = self.model.base_model.layers[-1].self_attn
        num_generated = min(self.input_ids_size - self.prompt_input_ids.shape[1],
                            self.actual_model_window_size)
        cos, sin = attention.rotary_emb(query, seq_len=num_generated)
        cos = cos[:,:,-1]  # [1, 1, dim]
        sin = sin[:,:,-1]  # [1, 1, dim]
        # cos = cos[-1].unsqueeze(0).unsqueeze(0)  # [bs, 1, seq_len, dim]
        # sin = sin[-1].unsqueeze(0)  # [bs, 1, seq_len, dim]
        query = (query * cos) + (self.rotate_half(query) * sin)

        k_proj = k_proj_weight.view(1,
                                    self.num_heads,
                                    query.shape[-1],
                                    k_proj_weight.shape[0]
                                    ) # (1, num_heads, attn_dim, embed_dim)
        k_proj_l = k_proj[..., :k_proj.shape[-2] // 2, :]
        k_proj_r = k_proj[..., k_proj.shape[-2] // 2:, :]
        k_proj_rotated = torch.cat([-k_proj_l, k_proj_r], dim=-2)

        datastore_query = query.unsqueeze(-2) # (batch * beam, num_heads, 1, attn_dim)
        # ------------------------- THE FOLLOWING STEP -------------------------
        datastore_query = torch.matmul(datastore_query,
                                       k_proj + k_proj_rotated
                                       ) # (batch * beam, num_heads, 1, embed_dim)
        datastore_query = datastore_query.squeeze(-2)  # (batch * beam, num_heads, embed_dim)
        return datastore_query

As far as I can see, k_proj is only applied to the key in UnlimiformerBART, in the following method:

    def create_key_value(self, encoder_hidden_states, decoder_layer):
        # (batch, time, hidden_dim)
        attention = decoder_layer.encoder_attn #-------attention function here
        # key, value: (batch, heads, time, attn_dim)
        # ------------------------- THE FOLLOWING STEP -------------------------
        key = attention.k_proj(encoder_hidden_states)
        key = key.view(key.shape[0],
                       -1,
                       attention.num_heads,
                       attention.head_dim).transpose(1, 2).contiguous()
        value = attention.v_proj(encoder_hidden_states)
        value = value.view(value.shape[0],
                           -1,
                           attention.num_heads,
                           attention.head_dim).transpose(1, 2).contiguous()
        # key, value: (batch, heads, time, attn_dim)
        return key, value 

Thus, my understanding of UnlimiformerBART is that the $Q K^T$ of the attention mechanism is of the form $\langle h_d W_q, W_k h_e\rangle$. I'm not sure what implications this has for the datastore, or how it relates to the attention reformulation in your paper; I'm just doing my best to wrap my head around your work.

Thanks!

urialon commented 10 months ago

Hi @patrickocal,

The attention reformulation trick that we highlight in the paper is used mostly at inference time. At training time, we do compute the "standard" attention, where both the key and the query are projected. The function create_key_value is used only at training time, if I'm not mistaken.

At test time, we indeed project the query, and keep the key without projection in the datastore. In LLaMA, because of the RoPE embeddings, we need to do some processing to the keys after retrieving them.
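To spell out the reformulation (Section 2.3 of the paper) in the row-vector convention:

$$ (h_d W_q)(h_e W_k)^\top = \left(h_d W_q W_k^\top\right) h_e^\top $$

so the retrieval query at test time is $h_d W_q W_k^\top$, and the datastore only needs to hold the raw encoder states $h_e$; the RoPE handling in LLaMA is an extra step on top of this.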

Does that help? Please let us know if you have any more questions.

Best, Uri

patrickocal commented 10 months ago

Thanks again for the quick reply, @urialon. I'm afraid I do have more questions :) so thanks in advance for your patience. My understanding is that, within the reset_memory method of the Unlimiformer class (where I have deleted unrelated lines):

class Unlimiformer(Generic[ModelType]):
    ...
    def reset_memory(self, input_ids, attention_mask):
        ...
        for context_start_ind, context_end_ind, update_start_ind, update_end_ind in window_indices:
            logger.info(f'Encoding {context_start_ind} to {context_end_ind} out of {input_ids.shape[-1]}')
            if self.use_datastore:
                # TODO: verify with BART as well
                # hidden_states_to_index = [hidden_states.encoder_last_hidden_state]
                hidden_states_to_index = [layer_capturer.captured for layer_capturer in self.activation_capturer] 
                # hidden_states_to_index = list(hidden_states.hidden_states)[:-1][self.layer_begin:self.layer_end]
                to_add = [state[:, update_start_ind:update_end_ind].detach() for state in hidden_states_to_index]
                if not self.reconstruct_embeddings:
                    to_add_embeddings = to_add
                    for i, layer_states in enumerate(to_add_embeddings):
                        self.hidden_states[i].append(layer_states.to(self.datastore_device))

we see that states are added whenever self.use_datastore == True. There is a separate branch for the case where either self.use_datastore == False or self.test_datastore == True:

            if (not self.use_datastore) or self.test_datastore:
                layers_kv = [
                    self.process_key_value(layer_capturer) # (batch, head, time, dim)
                    for layer_capturer in self.activation_capturer
                    ] # list of pairs of (batch, head, time, dim)
                key = [layer[0][:, :, update_start_ind:update_end_ind] for layer in layers_kv]
                value = [layer[1][:, :, update_start_ind:update_end_ind] for layer in layers_kv]

My understanding is therefore that it is only a FAISS index of the states $h_e$ that gets added to the datastore. (Thus, as per the paper, neither the keys $W_k h_e$ nor the values $W_v h_e$ are added to the datastore, regardless of whether we are testing or training.)

So, please correct me if I'm wrong, but in all cases, when we conduct a $k$-nearest neighbors search of the datastore, we need to first project the decoder state $h_d$ to the space $\mathcal H_e$ (where the encoded states live) and conduct our search there. To do so, we need to project using $W_q W_k^T$ (formally it's the adjoint of this operator, but let's not worry about that).

On the other hand, I can see that it is keys and queries that get passed into the faiss.knn_gpu method within the Datastore class. (Though I am not sure whether keys and queries are indeed of the form $W_k h_e$ and $W_q h_d$ respectively.) I am still not fully understanding how the add_keys method of the Datastore class gets its contents from the steps above in the Unlimiformer class(es), but I'm getting there. Any pointers would be much appreciated.

I may be out of action on Friday, but I look forward to your response.

Thanks Uri!

Patrick

urialon commented 10 months ago

My understanding is therefore that it is only a FAISS index of the states $h_e$ that gets added to the datastore. (Thus, as per the paper, neither the keys $W_k h_e$ nor the values $W_v h_e$ are added to the datastore, regardless of whether we are testing or training.)

This is correct:

  1. We use a FAISS datastore only at test time
  2. Whenever we use a FAISS datastore, only $h_e$ states are added to the datastore - that's the main insight of the Attention Reformulation section (section 2.3) in the paper.
  3. We use the explicit keys $W_k h_e$ and values $W_v h_e$ either (a) at training time, or (b) at test time when use_datastore=False, which is a faster option but consumes more GPU memory

in all cases, when we conduct a $k$-nearest neighbors search of the datastore, we need to first project the decoder state $h_d$ to the space $\mathcal H_e$ (where the encoded states live) and conduct our search there. To do so, we need to project using $W_q W_k^T$ (formally it's the adjoint of this operator, but let's not worry about that).

This is correct - that's the main insight of the Attention Reformulation section in the paper.

On the other hand, I can see that it is keys and queries that get passed into the faiss.knn_gpu method within the Datastore class. (Though I am not sure whether keys and queries are indeed of the form $W_k h_e$ and $W_q h_d$ respectively.) I am still not fully understanding how the add_keys method of the Datastore class gets its contents from the steps above in the Unlimiformer class(es), but I'm getting there. Any pointers would be much appreciated.

Whenever keys and queries are passed into the Datastore class, they are in the "right space". I think about the Datastore class as a separate, generic module that cares about keys and queries; it's the Unlimiformer code's responsibility to project them and pass them in the right way. So, whenever "keys" are added to the Datastore class, they are basically $h_e$, and queries are $h_d W_q W_k^T$.

I am still not fully understanding how the add_keys method of the Datastore class gets its contents from the steps above in the Unlimiformer class(es), but I'm getting there. Any pointers would be much appreciated.

In the Unlimiformer code, we call train_index here: https://github.com/abertsch72/unlimiformer/blob/main/src/unlimiformer.py#L433 which calls index.train_index here: https://github.com/abertsch72/unlimiformer/blob/main/src/index_building.py#L29 which calls add_keys here: https://github.com/abertsch72/unlimiformer/blob/main/src/index_building.py#L79-L95
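If it helps to see the whole flow in one place, here is a minimal, self-contained sketch of the idea (this is not the repo's Datastore or index code; the flat index, the per-head weight matrices, and the sizes below are placeholder assumptions):

import faiss
import numpy as np

# Placeholder sizes, just for the sketch.
embed_dim, attn_dim, num_tokens, top_k = 1024, 64, 8192, 16

# 1) Datastore: only the encoder hidden states h_e are indexed (no W_k applied).
h_e = np.random.randn(num_tokens, embed_dim).astype('float32')
index = faiss.IndexFlatIP(embed_dim)  # exact inner-product index; the repo builds its own index in index_building.py
index.add(h_e)

# 2) Query: the decoder state is projected by W_q W_k^T before the search, so
#    <h_d W_q W_k^T, h_e> equals the usual attention score <h_d W_q, h_e W_k>.
W_q = np.random.randn(embed_dim, attn_dim).astype('float32')  # placeholder per-head weights
W_k = np.random.randn(embed_dim, attn_dim).astype('float32')
W_v = np.random.randn(embed_dim, attn_dim).astype('float32')
h_d = np.random.randn(1, embed_dim).astype('float32')
datastore_query = h_d @ W_q @ W_k.T  # (1, embed_dim)

# 3) Retrieval: the top-k neighbours are the h_e states whose implicit keys score
#    highest; keys and values are reconstructed only for those retrieved states.
scores, ids = index.search(datastore_query, top_k)
retrieved_h_e = h_e[ids[0]]             # (top_k, embed_dim)
retrieved_keys = retrieved_h_e @ W_k    # W_k h_e, only for the retrieved states
retrieved_values = retrieved_h_e @ W_v  # W_v h_e, only for the retrieved states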

I hope it helps, feel free to ask any questions! Uri