lucidrains / memorizing-transformers-pytorch

Implementation of Memorizing Transformers (ICLR 2022), attention net augmented with indexing and retrieval of memories using approximate nearest neighbors, in Pytorch

Any interesting results? #1

Open rom1504 opened 2 years ago

rom1504 commented 2 years ago

Hey! Cool repo. I like all the kNN+LM methods. Did you do some runs yet? Anything interesting to report?

rom1504 commented 2 years ago

You are giving batches to index.search and index.add, right? Faiss has good built-in support for parallelism (implemented in C++)

rom1504 commented 2 years ago

A thousand queries should take something like 10 ms with HNSW

igor0 commented 2 years ago

OK, that's fair, it should be batched. Different elements will have to take different paths in the HNSW, but at least you have the loop in C++ and not in Python.

Nevertheless, HNSW is still the bottleneck. I'd guess it's the add() that's the bottleneck rather than the search, because we also have to add thousands of entries into the HNSW graph for each training iteration, and add in HNSW is a lot slower than search.
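
(For reference, a minimal sketch of the batched faiss usage being discussed; the dimensions and HNSW parameters are illustrative, not the repo's actual settings.)

```python
import numpy as np
import faiss

dim = 64                                  # key dimension (illustrative)
index = faiss.IndexHNSWFlat(dim, 32)      # 32 neighbors per node in the HNSW graph
index.hnsw.efSearch = 32                  # search-time beam width

# add a whole batch of keys in one call, so the loop runs in C++, not Python
keys = np.random.randn(4096, dim).astype('float32')
index.add(keys)

# likewise, search all queries at once rather than one by one
queries = np.random.randn(1024, dim).astype('float32')
distances, indices = index.search(queries, 32)    # both arrays are (1024, 32)
```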

lucidrains commented 2 years ago

it was fast enough for me to complete one experiment, and the results don't look that great tbh

i'll keep playing around with it this week. i just need to see some spark before i dedicate more engineering effort to this idea

lucidrains commented 2 years ago

perhaps https://github.com/lucidrains/HTM-pytorch would be a better bet, with the memories kept on CPU and the compressed memories kept on GPU (and i can reach for https://github.com/lucidrains/memory-efficient-attention-pytorch if need be), but that memory transformer variant rightfully belongs in another repository, as a Deepmind contribution

lucidrains commented 2 years ago

on further reflection, i don't even think the attention layer for the long term memories should mix local with distant memories (using all the gating logic)

one should simply place a long term memory module smack dab in the middle of the network, and cross attend to it efficiently once

igor0 commented 2 years ago

A few comments:

Regarding your implementation:

On my end, I am experimenting with these ideas, but taking a different path than you. Instead of trying to implement the whole paper, I'm starting from a pre-trained model and trying to implement trivial memorization in the smallest incremental steps I can think of. Hopefully at least one of us gets to some kind of interesting result.

lucidrains commented 2 years ago

@igor0 ahh yes, i do like the idea of having separate heads for local vs distant

re: null key / values, the reason is that in the paper, one can have documents spanning different lengths across different batch rows


it is just a cleaner solution compared to unbinding the batch dimension and doing all the conditional logic. the current way it is built is still not great, since ideally the number of memories present would also be accounted for in the gating (so the network attends locally when no distant memories are present for a certain batch row)
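
(A rough sketch of the null key / value idea, assuming per-head learned null vectors; names and shapes are illustrative, not the repo's actual code.)

```python
import torch

def attend_with_null_kv(q, k, v, null_k, null_v):
    # q, k, v  : (batch, heads, seq, dim_head)
    # null_k/v : (heads, 1, dim_head) learned parameters, shared across the batch
    b = q.shape[0]
    null_k = null_k.unsqueeze(0).expand(b, -1, -1, -1)
    null_v = null_v.unsqueeze(0).expand(b, -1, -1, -1)

    # every batch row gets at least one (null) memory to attend to, so rows with
    # no real distant memories degrade gracefully instead of needing special-casing
    k = torch.cat((null_k, k), dim = -2)
    v = torch.cat((null_v, v), dim = -2)

    sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) * (q.shape[-1] ** -0.5)
    attn = sim.softmax(dim = -1)
    return torch.einsum('b h i j, b h j d -> b h i d', attn, v)
```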

lucidrains commented 2 years ago

@igor0 do keep us posted about results from your approach :)

lucidrains commented 2 years ago

@igor0 i switched it to your splitting of attention heads idea and it looks a lot better now!! https://wandb.ai/lucidrains/memorizing-transformers/reports/memorizing-transformers--VmlldzoxODE4ODc1?accessToken=2ae4y9l100i3kfj3oa0udcsct0r495evo661lkk71w1mve1itlmphn20eiix1aul 🎉 🎉 thank you!

lucidrains commented 2 years ago

@rom1504 also thank you for all your help with faiss and the hnsw suggestion!

lucidrains commented 2 years ago

(screenshot from 2022-04-11)

the effects are more pronounced when i turn off XL memories - will investigate for bugs in the XL memories code later this week

igor0 commented 2 years ago

Interesting. I wonder whether the baseline got an unlucky initialization and so it had a poor start. I'm not sure I'd expect memories to be of that much help early on during the training. If the benefits of memorization show up later in the training, that may make the approach harder to evaluate.

igor0 commented 2 years ago

On further thought, I'm not sure how exactly to make a dedicated distant attention head work. Since the memory access is non-differentiable, the backward pass won't reach the K and V transforms, so the K and V transforms for the distant heads won't be trained. That means the K/V transforms must at least be tied across the local and distant modes. Did you solve that in some way?

lucidrains commented 2 years ago

@igor0 i think that's the appeal of this paper. it seems the claim is that even with non-differentiable memory, in addition to approximate nearest neighbors, the network would still perform well. i'm fairly confident the non-differentiable memories would do fine, as it lines up with transformer-xl. XL memories are also non-differentiable (they are detached for the next iteration)

however, one could augment the network with another projection of the non-differentiable memory key / values, perhaps conditioned by their age (which we can keep track of)
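
(That augmentation could look something like the following hypothetical sketch; the module name, the age embedding, and the separate key / value projections are all assumptions, not code from this repo.)

```python
import torch
from torch import nn

class MemoryReproject(nn.Module):
    """Hypothetical module: re-project retrieved (non-differentiable) memory
    keys / values, conditioned on the age of each memory."""
    def __init__(self, dim_head, max_age = 8192):
        super().__init__()
        self.age_emb = nn.Embedding(max_age, dim_head)
        self.to_k = nn.Linear(dim_head, dim_head, bias = False)
        self.to_v = nn.Linear(dim_head, dim_head, bias = False)

    def forward(self, mem_k, mem_v, ages):
        # mem_k, mem_v : (batch, num_retrieved, dim_head), fetched from the index (detached)
        # ages         : (batch, num_retrieved) long tensor of how old each memory is
        age = self.age_emb(ages.clamp(max = self.age_emb.num_embeddings - 1))
        return self.to_k(mem_k + age), self.to_v(mem_v + age)
```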

igor0 commented 2 years ago

Well, in the paper, each KV projection is used for both local and distant attention. So, even though the backward pass for distant attention doesn't reach the KV projection, that is still OK because the backward passes from local attention seem to be sufficient.

But, if you have a distant-only head, there will literally be no backward passes reaching the KV projections. So, their weights will forever keep their initial random values. That can't possibly work.

Anyways, it's quite possible that I'm missing something, but that's why it seems to me that distant-only attention heads won't work trivially. Using a hybrid local-distant head seems to be one solution.

lucidrains commented 2 years ago

ahhh got it, yes, we can try the hybrid approach too (which i assume is just concatenating the fetched distant key / values to the local ones, prior to softmax)
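
(A minimal sketch of that hybrid attention, assuming the top-k memory keys / values have already been fetched per query position; the names and shapes are illustrative, and causal masking is omitted for brevity.)

```python
import torch

def hybrid_attention(q, k, v, mem_k, mem_v):
    # q, k, v      : (b, h, i, d)       local queries / keys / values
    # mem_k, mem_v : (b, h, i, knn, d)  top-k memories fetched per query position
    scale = q.shape[-1] ** -0.5

    local_sim   = torch.einsum('b h i d, b h j d -> b h i j', q, k) * scale
    distant_sim = torch.einsum('b h i d, b h i k d -> b h i k', q, mem_k) * scale

    # one softmax over local and distant scores together
    sim  = torch.cat((distant_sim, local_sim), dim = -1)
    attn = sim.softmax(dim = -1)
    distant_attn, local_attn = attn[..., :mem_k.shape[-2]], attn[..., mem_k.shape[-2]:]

    out  = torch.einsum('b h i j, b h j d -> b h i d', local_attn, v)
    out += torch.einsum('b h i k, b h i k d -> b h i d', distant_attn, mem_v)
    return out
```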

lucidrains commented 2 years ago

ultimately, i think we should just let empiricism speak. hard to know which option is best without running the experiments

lucidrains commented 2 years ago

I'll have some free time Sunday and will run some experiments comparing the two ways

lucidrains commented 2 years ago

@igor0 Hey Igor, I finally got some time to attend to memorizing transformers. Mainly I tried out the hybrid attention as you suggested and it is working the best! I've checked it in for 0.3.0

lucidrains commented 2 years ago

the reprojection of the memories did very little, and possibly made things a little worse - just to share a negative result

lucidrains commented 2 years ago

(screenshot from 2022-04-23)

finally seeing liftoff with 3 modifications to the KNN attention module

  1. use cosine-sim attention with a low initial temperature (a rough sketch follows below)
  2. hybrid attention (thanks for the suggestion 🙏 )
  3. use a separate relative positional bias, rather than sharing the one in the regular attention layers - this should also allow the network to turn off attending locally later in training, if need be
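
(A rough sketch of the first modification, cosine-sim attention with a learnable temperature initialized low; the class name, dimensions, and init value are illustrative, and causal masking is omitted.)

```python
import torch
import torch.nn.functional as F
from torch import nn

class CosineSimAttention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, init_temperature = 0.1):
        super().__init__()
        inner = heads * dim_head
        self.heads = heads
        self.to_qkv = nn.Linear(dim, inner * 3, bias = False)
        self.to_out = nn.Linear(inner, dim, bias = False)
        # learned temperature, initialized low so the cosine-sim logits start out sharp
        self.temperature = nn.Parameter(torch.tensor(init_temperature))

    def forward(self, x):
        b, n, _ = x.shape
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: t.view(b, n, self.heads, -1).transpose(1, 2), (q, k, v))

        # l2-normalize queries and keys, so the similarities are cosine similarities
        q, k = F.normalize(q, dim = -1), F.normalize(k, dim = -1)
        sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) / self.temperature

        attn = sim.softmax(dim = -1)
        out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
        out = out.transpose(1, 2).reshape(b, n, -1)
        return self.to_out(out)
```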

gurvindersingh commented 2 years ago

I'd suggest giving dynamic position bias (DPB) a try instead of T5 rel pos, as the latter is usually slow and DPB can give similar performance at better speed.

lucidrains commented 2 years ago

@gurvindersingh Oh hey Gurvinder! How is it faster than T5 rel pos? The way I see it, T5 rel pos is a zero-layered DPB (usually DPB is a 2-3 layered MLP)
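
(For context, a dynamic position bias is typically a small MLP over the relative distance that produces a per-head bias added to the attention logits. The sketch below is illustrative, with assumed hidden size and depth; it evaluates the MLP once over the unique distances and gathers into the full grid.)

```python
import torch
from torch import nn

class DynamicPositionBias(nn.Module):
    """Small MLP mapping relative distance -> per-head attention bias."""
    def __init__(self, dim, heads, depth = 2):
        super().__init__()
        layers = [nn.Linear(1, dim), nn.SiLU()]
        for _ in range(depth - 1):
            layers += [nn.Linear(dim, dim), nn.SiLU()]
        layers += [nn.Linear(dim, heads)]
        self.mlp = nn.Sequential(*layers)

    def forward(self, n, device):
        # evaluate the MLP on the 2n - 1 unique relative distances
        dist = torch.arange(-(n - 1), n, device = device).float().unsqueeze(-1)  # (2n-1, 1)
        bias = self.mlp(dist)                                                    # (2n-1, heads)

        # gather into the full (n, n) grid of j - i offsets
        pos = torch.arange(n, device = device)
        rel = pos[None, :] - pos[:, None] + (n - 1)                              # values in [0, 2n-2]
        return bias[rel].permute(2, 0, 1)                                        # (heads, n, n), added to logits
```

Whether this ends up faster than the T5 bucketed bias presumably comes down to how well each one is cached across steps, which is the caching question raised further down.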

igor0 commented 2 years ago

Very cool!

Yeah, a full cosine-similarity attention makes sense. Normalizing keys and values like in the paper sounds... odd. I'm sure the researchers behind the paper had their rationale.

On my end, I've been trying to add memorization to an existing pretrained model. The positive result is that if I add memory to one or more layers in a pretrained model as hybrid attention, I get a modest but measurable performance benefit. I've been trying to take such a model (with one KNN layer) and train it (e.g., with attentions unfrozen, or everything unfrozen) to better adapt the model to exchange info in the KNN layer.

I haven't had much success with the retraining so far. The retraining tends to be unstable and quickly diverges. I've tried a number of the things you are trying here (projecting the memorized keys/values, adding a trained bias to memorized keys, trying it with and without positional biases for memory, etc.). I can normalize keys/queries or keys/values in the KNN layer, but then that degrades the pretrained model substantially. It seems to stabilize the retraining somewhat, but then I lose some of the benefit of starting from a pretrained model.

The only thing I found that sort of worked was that training a bias for memorized keys gives a modest but measurable benefit. But I haven't been able to stabilize the training sufficiently to get really compelling results. Still working on it and trying things, though! A couple of ideas I'd still like to try include suppressing the KNN attention early in the training (e.g., by using a trained penalty coefficient applied to the distant attention scores before the distant+local softmax) or perhaps adding a dedicated layer just for distant attention.
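
(One way that penalty idea could look, as a hypothetical sketch; the per-head parameter and its initialization are assumptions, not something reported as working in this thread.)

```python
import torch
from torch import nn

class DistantScorePenalty(nn.Module):
    """Learned per-head penalty subtracted from the distant attention scores before
    the joint local+distant softmax, so distant attention starts out suppressed."""
    def __init__(self, heads, init_penalty = 5.0):
        super().__init__()
        self.penalty = nn.Parameter(torch.full((heads, 1, 1), init_penalty))

    def forward(self, local_sim, distant_sim):
        # local_sim: (b, h, i, j)    distant_sim: (b, h, i, knn)
        distant_sim = distant_sim - self.penalty   # large penalty early -> distant attn near zero
        sim = torch.cat((distant_sim, local_sim), dim = -1)
        return sim.softmax(dim = -1)
```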

lucidrains commented 2 years ago

@igor0 ahh ok, yes, i did have to do a lot of trial and error before i started to see something with the formula outlined above

i'll keep chipping away at it when i find some time

gurvindersingh commented 2 years ago

@lucidrains I haven't looked into the code part, but when I was testing T5 rel pos embedding and DPB, the model was taking less time per step with DPB. Hence my statement above :)

lucidrains commented 2 years ago

@gurvindersingh ohh interesting! i'll have to revisit the T5 rel pos bias class and make sure i'm caching some intermediates correctly

kevinpl07 commented 2 years ago

I'm gonna piggyback this thread, because it is quite active.

A memory_list is only needed when I have multiple layers that contain a KNNAttention block, right? So if I have only one layer, I can just use one instance of KNNMemory and don't need to organize them in a list.

Thanks in advance!