Open rdmerillat opened 10 months ago
Hi @rdmerillat , Thank you for your interest in our work!
We haven't tried Pegasus, but the solution you described sounds correct. Make sure to add the new class to the `type_to_class` dictionary.
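For illustration, the registration might look something like the sketch below. This is hypothetical: the real dictionary lives in `src/unlimiformer.py`, and the classes here are stand-ins for the Hugging Face model classes and the Unlimiformer wrapper classes.

```python
# Hypothetical sketch of the type_to_class mapping; the stand-in classes
# below replace the real Hugging Face and Unlimiformer classes so the
# example is self-contained.

class UnlimiformerBART:          # stand-in for the existing wrapper class
    pass

class UnlimiformerPegasus:       # the new wrapper class to be added
    pass

class BartForConditionalGeneration:      # stand-in for the HF model class
    pass

class PegasusForConditionalGeneration:   # stand-in for the HF model class
    pass

# The conversion step looks up the model's class in this dictionary to
# pick the right Unlimiformer wrapper:
type_to_class = {
    BartForConditionalGeneration: UnlimiformerBART,
    PegasusForConditionalGeneration: UnlimiformerPegasus,  # the new entry
}

wrapper_cls = type_to_class[PegasusForConditionalGeneration]
print(wrapper_cls.__name__)  # UnlimiformerPegasus
```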
Please let us know how it goes!
Best, Uri
Hi @urialon, I've been comparing your UnlimiformerBART and UnlimiformerLLaMa classes. Am I right in thinking that the latter projects queries whereas the former projects keys and values? Assuming I'm right, why the difference? The LLaMa approach seems in line with the paper. Many thanks, Patrick
Hi @patrickocal , What exactly do you mean by "projects keys" and "projects queries"?
I'd say that both of them project queries, as described in the paper, with some unrelated difference in LLaMA due to its Rotary Position Embeddings.
Which difference between them are you concerned about?
Best, Uri
> Please let us know how it goes!
Will do!
Hi @urialon, thanks for the quick reply and thanks for sharing all your good work. By way of background, I am taking Jure Leskovec's course on Machine Learning with Graphs and my team's project is: how to integrate Knowledge Graphs with your approach (any suggestions would be welcome).
Regarding:
> Hi @patrickocal , What exactly do you mean by "projects keys" and "projects queries"?
> I'd say that both of them project queries, as described in the paper, with some unrelated difference in LLaMA due to its Rotary Position Embeddings.
> Which difference between them are you concerned about?
> Best, Uri
I guess I was referring to your Python terminology as per the methods below. Formally, by projected query, I mean the first term in the inner product $\langle h_d W_q W_k^T, h_e \rangle$. By projection of a key, I mean the second term in the inner product $\langle h_d, W_q^T W_k h_e \rangle$.
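The equivalence between projecting the query and projecting the key can be checked numerically. A toy sketch with random matrices (dimensions are illustrative):

```python
import torch

# Numerical check of the attention-reformulation identity:
#   <h_d W_q, h_e W_k>  ==  <h_d W_q W_k^T, h_e>
torch.manual_seed(0)
hidden, attn_dim = 16, 8
h_d = torch.randn(hidden)            # decoder hidden state
h_e = torch.randn(hidden)            # encoder hidden state
W_q = torch.randn(hidden, attn_dim)  # query projection
W_k = torch.randn(hidden, attn_dim)  # key projection

standard     = torch.dot(h_d @ W_q, h_e @ W_k)    # project both sides
reformulated = torch.dot(h_d @ W_q @ W_k.T, h_e)  # fold W_k into the query

assert torch.allclose(standard, reformulated, atol=1e-3)
```

This is why the datastore can hold the raw encoder states $h_e$: folding $W_k$ into the query reproduces the same attention scores without storing projected keys per layer and head.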
My understanding is that the step where the query is projected in UnlimiformerLLaMa is in the method:
def preprocess_query(self, query, k_proj_weight):
    """
    This method projects the query using the weight matrix k_proj_weight.
    This projection is well aligned with the paper's description, where
    the query is projected using a product of weight matrices before the
    kNN search.
    """
    # query: (batch * time, head, dim)
    attention = self.model.base_model.layers[-1].self_attn
    num_generated = min(self.input_ids_size - self.prompt_input_ids.shape[1],
                        self.actual_model_window_size)
    cos, sin = attention.rotary_emb(query, seq_len=num_generated)
    cos = cos[:, :, -1]  # [1, 1, dim]
    sin = sin[:, :, -1]  # [1, 1, dim]
    # cos = cos[-1].unsqueeze(0).unsqueeze(0)  # [bs, 1, seq_len, dim]
    # sin = sin[-1].unsqueeze(0)  # [bs, 1, seq_len, dim]
    query = (query * cos) + (self.rotate_half(query) * sin)
    k_proj = k_proj_weight.view(1,
                                self.num_heads,
                                query.shape[-1],
                                k_proj_weight.shape[0]
                                )  # (1, num_heads, attn_dim, embed_dim)
    k_proj_l = k_proj[..., :k_proj.shape[-2] // 2, :]
    k_proj_r = k_proj[..., k_proj.shape[-2] // 2:, :]
    k_proj_rotated = torch.cat([-k_proj_l, k_proj_r], dim=-2)
    datastore_query = query.unsqueeze(-2)  # (batch * beam, num_heads, 1, attn_dim)
    # >>> THE FOLLOWING STEP <<<
    datastore_query = torch.matmul(datastore_query,
                                   k_proj + k_proj_rotated
                                   )  # (batch * beam, num_heads, 1, embed_dim)
    datastore_query = datastore_query.squeeze(-2)  # (batch * beam, num_heads, embed_dim)
    return datastore_query
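The `rotate_half` helper used above follows the standard RoPE rotation. A minimal sketch, assuming it matches the common LLaMA-style implementation (the repo's version may differ in detail):

```python
import torch

def rotate_half(x):
    # Split the last dimension in half and rotate: (x1, x2) -> (-x2, x1).
    # Together with the (x * cos + rotate_half(x) * sin) step, this applies
    # a rotation in each 2-D sub-plane of the head dimension.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)

# Applying the rotation twice negates the input, as expected for a
# 90-degree rotation in each sub-plane:
x = torch.randn(2, 4, 8)
assert torch.allclose(rotate_half(rotate_half(x)), -x)
```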
As far as I can see, `k_proj` is only applied to the key in UnlimiformerBART, in the following method:
def create_key_value(self, encoder_hidden_states, decoder_layer):
    # encoder_hidden_states: (batch, time, hidden_dim)
    attention = decoder_layer.encoder_attn  # <-- the attention module
    # key, value: (batch, heads, time, attn_dim)
    # >>> THE FOLLOWING STEP <<<
    key = attention.k_proj(encoder_hidden_states)
    key = key.view(key.shape[0],
                   -1,
                   attention.num_heads,
                   attention.head_dim).transpose(1, 2).contiguous()
    value = attention.v_proj(encoder_hidden_states)
    value = value.view(value.shape[0],
                       -1,
                       attention.num_heads,
                       attention.head_dim).transpose(1, 2).contiguous()
    # key, value: (batch, heads, time, attn_dim)
    return key, value
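The `view`/`transpose` pattern in `create_key_value` can be seen in isolation (shapes illustrative):

```python
import torch

batch, time, heads, head_dim = 2, 5, 4, 8
hidden = heads * head_dim

# (batch, time, hidden) -> (batch, heads, time, head_dim),
# mirroring the per-head reshape used for keys and values above.
states = torch.randn(batch, time, hidden)
per_head = states.view(batch, -1, heads, head_dim).transpose(1, 2).contiguous()
assert per_head.shape == (batch, heads, time, head_dim)
```

The `-1` lets the time dimension be inferred, and `contiguous()` makes the transposed tensor safe for subsequent `view` calls.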
Thus, my understanding of UnlimiformerBART is that the $Q K^T$ of the attention mechanism is of the form $\langle h_d W_q, W_k h_e\rangle$. I am not sure what implications this has for storage, or how it relates to the attention reformulation in your paper; I'm just doing my best to wrap my head around your work.
Thanks!
Hi @patrickocal ,
The attention reformulation trick that we highlight in the paper is used mostly at inference time. At training time, we do compute the "standard" attention, where both the key and the query are projected. The function `create_key_value` is used only at training time, if I'm not mistaken.
At test time, we indeed project the query, and keep the key without projection in the datastore. In LLaMA, because of the RoPE embeddings, we need to do some processing to the keys after retrieving them.
Does that help? Please let us know if you have any more questions.
Best, Uri
Thanks again for the quick reply, @urialon. I'm afraid I do have more questions :) So thanks in advance for your patience. My understanding is that, within the `reset_memory` method of the `Unlimiformer` class (where I have deleted unrelated lines):
class Unlimiformer(Generic[ModelType]):
    ...
    def reset_memory(self, input_ids, attention_mask):
        ...
        for context_start_ind, context_end_ind, update_start_ind, update_end_ind in window_indices:
            logger.info(f'Encoding {context_start_ind} to {context_end_ind} out of {input_ids.shape[-1]}')
            if self.use_datastore:
                # TODO: verify with BART as well
                # hidden_states_to_index = [hidden_states.encoder_last_hidden_state]
                hidden_states_to_index = [layer_capturer.captured for layer_capturer in self.activation_capturer]
                # hidden_states_to_index = list(hidden_states.hidden_states)[:-1][self.layer_begin:self.layer_end]
                to_add = [state[:, update_start_ind:update_end_ind].detach() for state in hidden_states_to_index]
                if not self.reconstruct_embeddings:
                    to_add_embeddings = to_add
                    for i, layer_states in enumerate(to_add_embeddings):
                        self.hidden_states[i].append(layer_states.to(self.datastore_device))
we see that states are added whenever `self.use_datastore == True`. There is a separate toggle for the case where either `self.use_datastore == False` or `self.test_datastore == True`:
if (not self.use_datastore) or self.test_datastore:
    layers_kv = [
        self.process_key_value(layer_capturer)  # (batch, head, time, dim)
        for layer_capturer in self.activation_capturer
    ]  # list of pairs of (batch, head, time, dim)
    key = [layer[0][:, :, update_start_ind:update_end_ind] for layer in layers_kv]
    value = [layer[1][:, :, update_start_ind:update_end_ind] for layer in layers_kv]
My understanding is therefore that it is only a FAISS index of the states $h_e$ that gets added to the datastore. (Thus, as per the paper, neither the keys $W_k h_e$ nor the values $W_v h_e$ are added to the datastore, regardless of whether we are testing or training.)
So, please correct me if I'm wrong, but, in all cases, when we conduct a $k$-nearest neighbors search of the datastore, we need to first project the decoder state $h_d$ to the space $\mathcal H_e$ (where the encoded states live) and conduct our search there. To do so, we need to project using $W_q W_k^T$ (formally, it's the adjoint of this operator, but let's not worry about that).
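Under that reading, the retrieval step can be sketched as follows. This is a toy illustration, not the repo's code: a brute-force inner-product search over random vectors stands in for the FAISS index, and all dimensions are made up.

```python
import torch

torch.manual_seed(0)
hidden, attn_dim, n_states, k = 16, 8, 100, 4

W_q = torch.randn(hidden, attn_dim)   # query projection
W_k = torch.randn(hidden, attn_dim)   # key projection
h_e = torch.randn(n_states, hidden)   # encoder states held in the datastore
h_d = torch.randn(hidden)             # current decoder state

# Project the decoder state into the encoder-state space with W_q W_k^T ...
datastore_query = h_d @ W_q @ W_k.T   # (hidden,)

# ... so an inner-product search over the raw h_e scores each state exactly
# like the standard attention logit (h_d W_q)(h_e W_k)^T would.
scores = h_e @ datastore_query        # (n_states,)
top_scores, top_idx = scores.topk(k)

# Equivalent computation with explicit key projection, for comparison:
full_scores = (h_e @ W_k) @ (h_d @ W_q)
assert torch.allclose(scores, full_scores, atol=1e-3)
```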
On the other hand, I can see that it is keys and queries that get passed into the `faiss.knn_gpu` method within the `Datastore` class. (Though I am not sure whether the keys and queries are indeed of the form $W_k h_e$ and $W_q h_d$ respectively.) I am still not fully understanding how the `add_keys` method of the `Datastore` class gets its contents from the steps above in the `Unlimiformer` class(es), but I'm getting there. Any pointers would be much appreciated.
I may be out of action on Friday, but I look forward to your response.
Thanks Uri!
Patrick
> My understanding is therefore that it is only a FAISS index of the states $h_e$ that gets added to the datastore. (Thus, as per the paper, neither the keys $W_k h_e$ nor the values $W_v h_e$ are added to the datastore, regardless of whether we are testing or training.)

This is correct (when using a datastore; there is also the option `use_datastore=False`, which is a faster option but more GPU-memory consuming).

> in all cases, when we conduct a $k$-nearest neighbors search of the datastore, we need to first project the decoder state $h_d$ to the space $\mathcal H_e$ (where the encoded states live) and conduct our search there. To do so, we need to project using $W_q W_k^T$ (formally, it's the adjoint of this operator, but let's not worry about that).
This is correct - that's the main insight of the Attention Reformulation section in the paper.
> On the other hand, I can see that it is keys and queries that get passed into the `faiss.knn_gpu` method within the `Datastore` class. (Though I am not sure whether keys and queries are indeed of the form $W_k h_e$ and $W_q h_d$ respectively.) I am still not fully understanding how the `add_keys` method of the `Datastore` class gets its contents from the steps above in the `Unlimiformer` class(es), but I'm getting there. Any pointers would be much appreciated.
Whenever keys and queries are passed into the `Datastore` class, they are in the "right space". I think of the `Datastore` class as a separate generic module that cares only about keys and queries; it's the Unlimiformer code's responsibility to project them and pass them in the right way. So, whenever "keys" are added to the `Datastore` class, they are basically $h_e$, and queries are $h_d W_q W_k^T$.
> I am still not fully understanding how the `add_keys` method of the `Datastore` class gets its contents from the steps above in the `Unlimiformer` class(es), but I'm getting there. Any pointers would be much appreciated.
In the `Unlimiformer` code, we call `train_index` here:
https://github.com/abertsch72/unlimiformer/blob/main/src/unlimiformer.py#L433
which calls `index.train_index` here:
https://github.com/abertsch72/unlimiformer/blob/main/src/index_building.py#L29
which in turn calls `add_keys` here:
https://github.com/abertsch72/unlimiformer/blob/main/src/index_building.py#L79-L95
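That flow can be summarized with a minimal stand-in for the `Datastore` class. This is an illustrative sketch, not the repo's implementation: a brute-force torch search replaces the FAISS index, and the method names mimic, but need not match, `src/index_building.py`.

```python
import torch

class ToyDatastore:
    """Toy stand-in for the Datastore: it stores raw keys (the encoder
    states h_e) and answers queries with a brute-force inner-product
    search instead of a trained FAISS index."""

    def __init__(self, dim):
        self.dim = dim
        self.keys = torch.empty(0, dim)

    def add_keys(self, keys):
        # keys: (n, dim) -- appended verbatim; no projection happens here,
        # matching the point above that the datastore holds h_e directly
        # and the Unlimiformer code projects the queries.
        self.keys = torch.cat([self.keys, keys], dim=0)

    def search(self, query, k):
        # query: (dim,) -- assumed already projected with W_q W_k^T.
        scores = self.keys @ query
        return scores.topk(k)

ds = ToyDatastore(dim=8)
ds.add_keys(torch.randn(50, 8))   # one encoded window
ds.add_keys(torch.randn(30, 8))   # another encoded window
scores, idx = ds.search(torch.randn(8), k=5)
assert ds.keys.shape == (80, 8) and idx.shape == (5,)
```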
I hope this helps; feel free to ask any more questions! Uri
Hello, I've been using Unlimiformer as a comparison with current standard methods of summarization, and I was wondering whether anything in particular is needed to convert, say, a Pegasus model to Unlimiformer, since it should work with "all encoder/decoder" models. I see several lines commented out in `unlimiformer.py` (here) for AutoModelForSeq2Seq; however, I currently don't see a direct way this has been implemented yet. As Pegasus is BART-based, I set up a new model converter:
PegasusForConditionalGeneration: UnlimiformerPegasus,
and started a new Unlimiformer class for it. However, I was wondering if you or anyone else had found additional tweaking that is needed to fully convert, say, a Pegasus model.
And, more generally, what procedure do you use when setting up your own newly converted Unlimiformer models? I was unable to glean what is necessary to ensure consistent performance and/or results.
Thanks!