lucidrains / memorizing-transformers-pytorch

Implementation of Memorizing Transformers (ICLR 2022), attention net augmented with indexing and retrieval of memories using approximate nearest neighbors, in Pytorch
MIT License

Dimensionality of key and values for Attention #5

Open manestay opened 2 years ago

manestay commented 2 years ago

I have two questions about the key and value calculation in Attention (and similarly for KNNAttention).

The relevant line is: https://github.com/lucidrains/memorizing-transformers-pytorch/blob/83fa1479d6f7881dd977fbff55681e709e3b250e/memorizing_transformers_pytorch/memorizing_transformers_pytorch.py#L135

  1. Why is there only one Linear layer, to_kv, instead of two linear layers, to_k and to_v?
  2. Why is the last dimension dim_head * 2? I get that the 2 is for both k and v, but what about dim_head? I thought q, k, and v should all have the same final dimension (i.e. `inner_dim == dim_head * heads`). My understanding is that this means either a) there is only one attention head, or b) k and v are shared across all heads. Is there a reason this is done, or am I misunderstanding?

In your Attention class for Performer, q, k, v all have the same dimensions.

Thanks in advance!

manestay commented 2 years ago

I guess this commit cites the paper that uses one-headed attention: https://github.com/lucidrains/memorizing-transformers-pytorch/commit/9f77fd5e4e449d70c02b9cd25a98e1d5ef5f0a72

lucidrains commented 2 years ago

@manestay yup, one-headed keys / values comes from an old Noam Shazeer paper and is seeing a resurgence in usage by LLMs such as AlphaCode and the 540B-parameter PaLM model. it is usually used to save on the amount of keys / values cached during inference, but it is a good fit here since we don't need to keep heads times as many faiss indices
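For illustration, here is a minimal sketch of the idea (the class name `OneHeadedKVAttention` and the exact dims are mine, not the repo's actual module): queries keep `heads` heads, while a single key/value head of size `dim_head` is shared across all of them.

```python
import torch
from torch import nn
from einops import rearrange

class OneHeadedKVAttention(nn.Module):
    # sketch of multi-query-style attention: many query heads, one shared key/value head
    def __init__(self, dim, heads = 8, dim_head = 64):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = heads * dim_head
        self.to_q  = nn.Linear(dim, inner_dim, bias = False)      # per-head queries
        self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)   # a single k/v head, fused
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x):
        q = rearrange(self.to_q(x), 'b n (h d) -> b h n d', h = self.heads)
        k, v = self.to_kv(x).chunk(2, dim = -1)                   # each: (b, n, dim_head)
        sim = torch.einsum('b h i d, b j d -> b h i j', q, k) * self.scale
        attn = sim.softmax(dim = -1)
        out = torch.einsum('b h i j, b j d -> b h i d', attn, v)  # k/v broadcast over heads
        return self.to_out(rearrange(out, 'b h n d -> b n (h d)'))
```

The payoff here is that the faiss memory only ever has to store `dim_head`-sized keys and values, independent of the number of query heads.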

manestay commented 2 years ago

Thanks! What about this question: why is there only one Linear layer, to_kv, instead of two linear layers, to_k and to_v?

lucidrains commented 2 years ago

@manestay it comes out to be faster if you do one matrix multiplication and then break it up later
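To make the "break it up later" concrete, a small sketch (dims are illustrative): one fused projection followed by a `chunk` gives the same k and v as two separate projections whose weights are the two halves of the fused weight matrix, but it launches a single matmul.

```python
import torch
from torch import nn

dim, dim_head = 512, 64
x = torch.randn(2, 16, dim)

# one fused projection, split afterwards
to_kv = nn.Linear(dim, dim_head * 2, bias = False)
k, v = to_kv(x).chunk(2, dim = -1)

# an equivalent pair of projections built from the two halves of the fused weight
to_k = nn.Linear(dim, dim_head, bias = False)
to_v = nn.Linear(dim, dim_head, bias = False)
to_k.weight.data, to_v.weight.data = to_kv.weight.data.chunk(2, dim = 0)

assert torch.allclose(k, to_k(x), atol = 1e-6)
assert torch.allclose(v, to_v(x), atol = 1e-6)
```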

manestay commented 2 years ago

I see, thanks!

An unrelated question, just to confirm my understanding, regarding the following line: https://github.com/lucidrains/memorizing-transformers-pytorch/blob/83fa1479d6f7881dd977fbff55681e709e3b250e/memorizing_transformers_pytorch/knn_memory.py#L153

Can I ask why each KNNMemory has num_indices (i.e. batch_size) KNN objects? What does each KNN hold that is different from the others? And how does this interact with KNNMemory.add and KNNMemory.search, which add/search each key/query against a different KNN?

Thanks in advance again @lucidrains

manestay commented 2 years ago

~To provide more context on the KNNMemory.add function, here's an example of my understanding:~

https://github.com/lucidrains/memorizing-transformers-pytorch/blob/83fa1479d6f7881dd977fbff55681e709e3b250e/memorizing_transformers_pytorch/knn_memory.py#L201

~Suppose we have in the batch (size=4) the key vectors corresponding to the sentences:~

<s> hello
<s> goodbye and
<s> ok .
<s>

~When the above line is called, it will add each key to a different KNN. I don't get why this is the case -- don't we want all of the keys in the memory? If a later batch has a query vector corresponding to <s> goodbye and good, it seems that the most relevant entry is in the 2nd KNN, but we have no guarantee that this query will be in the 2nd position of the batch.~

~Is my above understanding correct? If so, then I don't see why we have multiple memories, as the paper did not mention that. If not, then please correct me. Thank you.~

EDIT: I see your comment https://github.com/lucidrains/memorizing-transformers-pytorch/issues/1#issuecomment-1086901581 . You said that "they are doing separate indexed memories per batch, which is why I have to instantiate the number of faiss indices equal to the batch size." I guess I am missing something fundamental in my understanding of the whole memorizing transformers approach, since I don't see where they are doing that. Can you point me to the place in the original text? Sorry for asking so many questions.

adumit commented 2 years ago

@manestay I think Figure 3 in the paper answers your question. Each batch appears to stream documents so as to maintain a consistent within-document memory. A batch of size B then contains chunks from B distinct documents, each with its own memory.
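To make the bookkeeping concrete, here is a toy sketch of that reading (the class and variable names are mine, and the repo's actual KNNMemory uses a more elaborate faiss setup, so treat this as illustrative only): one index per batch position, and add/search at row b only ever touch index b.

```python
import numpy as np
import faiss

class PerDocumentMemories:
    # toy sketch: one flat index per batch position, since each position streams its own document
    def __init__(self, batch_size, dim):
        self.indices = [faiss.IndexFlatIP(dim) for _ in range(batch_size)]

    def add(self, keys):
        # keys: (batch, num_new_memories, dim) -- row b is added only to index b
        for index, k in zip(self.indices, keys):
            index.add(np.ascontiguousarray(k, dtype = np.float32))

    def search(self, queries, topk = 3):
        # queries: (batch, num_queries, dim) -- row b is searched only against index b
        return [index.search(np.ascontiguousarray(q, dtype = np.float32), topk)
                for index, q in zip(self.indices, queries)]
```

This matches the per-batch-index behavior of KNNMemory.add and KNNMemory.search discussed above: a query at batch position b only ever sees memories that were written at position b.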

manestay commented 2 years ago

> @manestay I think Figure 3 in the paper answers your question. Each batch appears to stream documents so as to maintain a consistent within-document memory. A batch of size B then contains chunks from B distinct documents, each with its own memory.

Okay, I see what you're saying. I was having some trouble interpreting this figure, but your explanation makes sense. Thanks!

I guess I was confused by the train.py script in this repo, since it doesn't handle anything at the document level and just loads enwik8 sequentially in chunks. But I do see that this is a WIP repo, so maybe that is yet to be implemented. Appreciated once again!
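For reference, a document-aware loader in the spirit of Figure 3 might look something like the sketch below (purely illustrative and not part of the repo; the function name and the clear_memory convention are mine): each of the batch_size slots walks through its own document chunk by chunk, and a flag marks where that slot's KNN memory should be cleared.

```python
def batched_document_chunks(documents, batch_size, chunk_len):
    # illustrative only: walk batch_size documents in parallel, chunk by chunk;
    # clear_memory[b] is True at the start of slot b's document, which is where
    # that slot's KNN memory would be wiped
    for start in range(0, len(documents) - batch_size + 1, batch_size):
        group = documents[start:start + batch_size]
        num_steps = -(-max(len(doc) for doc in group) // chunk_len)  # ceil division
        for step in range(num_steps):
            chunks = [doc[step * chunk_len:(step + 1) * chunk_len] for doc in group]
            clear_memory = [step == 0] * batch_size
            yield chunks, clear_memory
```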