You are giving batches to index.search and index.add, right? Faiss has good built-in support for parallelism (implemented in C++).
A thousand queries should take something like 10ms with HNSW.
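Roughly what batched usage looks like (a minimal sketch — the dimension and index parameters here are illustrative, not taken from any actual training code):

```python
import numpy as np
import faiss

d = 64                                 # vector dimension (illustrative)
index = faiss.IndexHNSWFlat(d, 32)     # HNSW index with 32 links per node

# add a whole batch in one call -- the per-vector loop runs in C++
xb = np.random.rand(10000, d).astype('float32')
index.add(xb)

# likewise, pass all queries at once instead of looping in Python
xq = np.random.rand(1000, d).astype('float32')
distances, ids = index.search(xq, 10)  # two (1000, 10) arrays back
```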
OK, that's fair, it should be batched. Different elements will have to take different paths in the HNSW, but at least you have the loop in C++ and not in Python.
Nevertheless, HNSW is still the bottleneck. I suspect it is the add() that's the bottleneck rather than the search, because we also have to add thousands of entries to the HNSW graph on every training iteration, and add in HNSW is a lot slower than search.
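One quick way to check where the time goes (a hypothetical micro-benchmark sketch, not a rigorous measurement):

```python
import time
import numpy as np
import faiss

d = 64  # illustrative dimension
index = faiss.IndexHNSWFlat(d, 32)
index.add(np.random.rand(50000, d).astype('float32'))  # pre-fill the graph

xb = np.random.rand(4096, d).astype('float32')

t0 = time.perf_counter()
index.add(xb)         # insertion: neighbor search plus graph link updates
t1 = time.perf_counter()
index.search(xb, 10)  # query: neighbor search only
t2 = time.perf_counter()

print(f'add: {t1 - t0:.3f}s  search: {t2 - t1:.3f}s')
```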
it was fast enough for me to complete one experiment, and the results don't look that great tbh
i'll keep playing around with it this week. i just need to see some spark before i dedicate more engineering effort to this idea
perhaps https://github.com/lucidrains/HTM-pytorch would be a better bet, with the memories kept on CPU, and then the compressed memories kept in GPU (and i can reach for https://github.com/lucidrains/memory-efficient-attention-pytorch if need be), but that memory transformer variant rightfully belongs in another repository, as a Deepmind contribution
on further reflection, i don't even think the attention layer for the long term memories should mix local with distant memories (using all the gating logic)
one should simply place a long term memory module smack dab in the middle of the network, and cross attend to it efficiently, once
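A minimal sketch of that shape of architecture (names and module choices here are hypothetical, just to illustrate a single mid-network cross-attention to memory):

```python
from torch import nn

class MidMemoryTransformer(nn.Module):
    def __init__(self, dim, depth, heads):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(dim, heads, batch_first=True)
            for _ in range(depth)
        ])
        self.mid = depth // 2
        self.mem_attn = nn.MultiheadAttention(dim, heads, batch_first=True)

    def forward(self, x, memories):
        # x: (batch, seq, dim), memories: (batch, num_memories, dim)
        for i, layer in enumerate(self.layers):
            if i == self.mid:
                # cross attend to the long term memories exactly once
                mem_out, _ = self.mem_attn(x, memories, memories)
                x = x + mem_out
            x = layer(x)
        return x
```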
A few comments regarding your implementation: have you considered giving the local and distant memories separate attention heads, rather than mixing them? And what is the reason for the null key / values?
On my end, I am experimenting with these ideas, but taking a different path than you. Instead of trying to implement the whole paper, I'm starting from a pre-trained model and trying to implement trivial memorization in the smallest incremental steps I can think of. Hopefully at least one of us gets to some kind of interesting result.
@igor0 ahh yes, i do like the idea of having separate heads for local vs distant
re: null key / values, the reason is that, in the paper, one can have documents spanning different lengths across different batch rows
it is just a cleaner solution than unbinding the batch dimension and doing all the conditional logic. the current way it is built is still not great, since ideally the number of memories present would also be accounted for in the gating (so the network attends locally when no distant memories are present for a given batch row)
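For illustration, the null key / value trick might look something like this (a hypothetical sketch, not the repo's actual implementation):

```python
import torch
from torch import nn

class NullKV(nn.Module):
    # one learned "null" memory slot per head, so batch rows that retrieved
    # no real memories still have something valid to softmax over
    def __init__(self, heads, dim_head):
        super().__init__()
        self.null_k = nn.Parameter(torch.randn(heads, 1, dim_head))
        self.null_v = nn.Parameter(torch.randn(heads, 1, dim_head))

    def forward(self, k, v):
        # k, v: (batch, heads, num_memories, dim_head)
        b = k.shape[0]
        nk = self.null_k.unsqueeze(0).expand(b, -1, -1, -1)
        nv = self.null_v.unsqueeze(0).expand(b, -1, -1, -1)
        return torch.cat((nk, k), dim=-2), torch.cat((nv, v), dim=-2)
```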
@igor0 do keep us posted about results from your approach :)
@igor0 i switched it to your splitting of attention heads idea and it looks a lot better now!! https://wandb.ai/lucidrains/memorizing-transformers/reports/memorizing-transformers--VmlldzoxODE4ODc1?accessToken=2ae4y9l100i3kfj3oa0udcsct0r495evo661lkk71w1mve1itlmphn20eiix1aul 🎉 🎉 thank you!
@rom1504 also thank you for all your help with faiss and the hnsw suggestion!
the effects are more pronounced when i turn off XL memories - will investigate for bugs in the XL memories code later this week
Interesting. I wonder whether the baseline got an unlucky initialization and so it had a poor start. I'm not sure I'd expect memories to be of that much help early on during the training. If the benefits of memorization show up later in the training, that may make the approach harder to evaluate.
On further thought, I'm not sure how exactly to make a dedicated distant attention head work. Since the memory access is non-differentiable, the backward pass won't reach the K and V transforms, so the K and V transforms for the distant heads won't be trained. At a minimum, the K/V transforms must be tied across the local and distant modes. Did you solve that in some way?
@igor0 i think that's the appeal of this paper. it seems the claim is that even with non differentiable memory, in addition to approximate nearest neighbors, the network would still perform well. i'm fairly confident the non-differentiable memories would do fine, as it lines up with transformer-xl. XL memories are also non-differentiable (they are detached for the next iteration)
however, one could augment the network with another projection of the non-differentiable memory key / values, perhaps conditioned on their age (which we can keep track of)
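A sketch of what that extra projection could look like (hypothetical — the age conditioning here is just one possible way to do it):

```python
import torch
from torch import nn

class MemoryReproject(nn.Module):
    # re-project retrieved (detached) memory keys / values through a learned
    # transform, conditioned on how old each memory is
    def __init__(self, dim_head, max_age=8192):
        super().__init__()
        self.age_emb = nn.Embedding(max_age, dim_head)
        self.to_k = nn.Linear(dim_head, dim_head)
        self.to_v = nn.Linear(dim_head, dim_head)

    def forward(self, mem_k, mem_v, ages):
        # mem_k, mem_v: (..., num_memories, dim_head); ages: (..., num_memories)
        age = self.age_emb(ages)
        return self.to_k(mem_k + age), self.to_v(mem_v + age)
```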
Well in the paper, each KV projection is used for both local and distant attention. So, even though the backward pass for distant attention doesn't reach the KV projection, that is still OK because the backward passes from local attention seem to be sufficient.
But, if you have a distant-only head, there will literally be no backward passes reaching the KV projections. So, their weights will forever keep their initial random values. That can't possibly work.
Anyway, it's quite possible that I'm missing something, but that's why it seems to me that distant-only attention heads won't work trivially. Using a hybrid local-distant head seems to be one solution.
ahhh got it, yes, we can try the hybrid approach too (which i assume just means concatenating the fetched distant key / values onto the local ones, prior to the softmax)
ultimately, i think we should just let empiricism speak. hard to know which option is best without running the experiments
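For concreteness, a minimal sketch of the hybrid attention described above (shapes and names illustrative; causal masking of the local scores is omitted):

```python
import torch

def hybrid_attention(q, local_k, local_v, mem_k, mem_v):
    # q, local_k, local_v: (b, h, n, d); mem_k, mem_v: (b, h, m, d)
    # concat the fetched distant key / values onto the local ones,
    # then take a single softmax over both
    scale = q.shape[-1] ** -0.5
    k = torch.cat((mem_k, local_k), dim=-2)
    v = torch.cat((mem_v, local_v), dim=-2)
    sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) * scale
    attn = sim.softmax(dim=-1)
    return torch.einsum('b h i j, b h j d -> b h i d', attn, v)
```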
I'll have some free time Sunday and will run some experiments comparing the two ways
@igor0 Hey Igor, I finally got some time to attend to memorizing transformers. Mainly I tried out the hybrid attention as you suggested and it is working the best! I've checked it in for 0.3.0
the reprojection of the memories did very little, possibly a little worse, just to share a negative result
finally seeing liftoff with 3 modifications to the KNN attention module
I'd suggest giving dynamic position bias a try instead of T5 relative position bias, as the latter is usually slow and DPB can give similar performance at better speed.
@gurvindersingh Oh hey Gurvinder! How is it faster than T5 rel pos? The way I see it, T5 rel pos is a zero-layered DPB (usually DPB is a 2-3 layered MLP)
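For reference, a typical dynamic position bias looks roughly like this (a hypothetical minimal version — a small MLP from relative distance to a per-head bias, versus T5's direct embedding lookup):

```python
import torch
from torch import nn

class DynamicPositionBias(nn.Module):
    def __init__(self, dim, heads):
        super().__init__()
        # 2-layer MLP mapping a scalar relative distance to one bias per head
        self.mlp = nn.Sequential(
            nn.Linear(1, dim), nn.SiLU(),
            nn.Linear(dim, dim), nn.SiLU(),
            nn.Linear(dim, heads),
        )

    def forward(self, n, device):
        # all relative distances -(n-1) .. (n-1), each run through the MLP once
        rel = torch.arange(-(n - 1), n, device=device, dtype=torch.float32)
        bias = self.mlp(rel.unsqueeze(-1))                   # (2n-1, heads)
        pos = torch.arange(n, device=device)
        idx = pos.unsqueeze(0) - pos.unsqueeze(1) + (n - 1)  # (n, n)
        return bias[idx].permute(2, 0, 1)                    # (heads, n, n), added to attn logits
```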
Very cool!
Yeah, full cosine-similarity attention makes sense. Normalizing the keys and values like in the paper sounds... odd. I'm sure the researchers behind the paper had their rationale.
On my end, I've been trying to add memorization to an existing pretrained model. The positive result is that if I add memory to one or more layers in a pretrained model as hybrid attention, I get a modest but measurable performance benefit. I've been trying to take such a model (with one KNN layer) and train it (e.g., with attentions unfrozen, or everything unfrozen) to better adapt the model to exchange info in the KNN layer.
I haven't had much success with the retraining so far. The retraining tends to be unstable and quickly diverges. I've tried a number of the things you are trying here (projecting the memorized keys/values, adding a trained bias to memorized keys, trying it with or without positional biases for memory, etc.). I can normalize keys/queries or keys/values in the KNN layer, but that degrades the pretrained model substantially. It seems to stabilize the retraining somewhat, but then I lose some of the benefit of starting from a pretrained model.
The only thing I found that sort of worked was training a bias for the memorized keys, which gives a modest but measurable benefit. But I haven't been able to stabilize the training sufficiently to get really compelling results. Still working on it and trying things, though! A couple of ideas I'd still like to try: suppressing the KNN attention early in the training (e.g., by applying a trained penalty coefficient to the distant attention scores before the distant+local softmax), or perhaps adding a dedicated layer just for distant attention.
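A sketch of that trained-penalty idea (hypothetical — a learned per-head scalar subtracted from the distant scores before the joint softmax, initialized high so distant attention starts suppressed):

```python
import torch
from torch import nn

class DistantPenalty(nn.Module):
    def __init__(self, heads, init=5.0):
        super().__init__()
        # one learned penalty per head, starting large
        self.penalty = nn.Parameter(torch.full((heads, 1, 1), init))

    def forward(self, local_scores, distant_scores):
        # scores: (batch, heads, n, *) -- penalize distant, softmax over both
        distant_scores = distant_scores - self.penalty
        scores = torch.cat((distant_scores, local_scores), dim=-1)
        return scores.softmax(dim=-1)
```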
@igor0 ahh ok, yes, i did have to do a lot of trial and error before i started to see something with the formula outlined above
i'll keep chipping away at it when i find some time
@lucidrains I haven't looked into the code part, but when I was testing T5 rel pos embedding against DPB, the model was taking less time per step with DPB. Hence my statement above :)
@gurvindersingh ohh interesting! i'll have to revisit the T5 rel pos bias class and make sure i'm caching some intermediates correctly
I'm gonna piggyback on this thread, because it is quite active.
A memory_list is only needed when I have multiple layers that contain a KNNAttention block, right? So if I have only one layer, I can just use a single instance of KNNMemory and don't need to organize it in a list.
Thanks in advance!
Hey! Cool repo. I like all the knn+lm methods. Did you do some runs yet? Anything interesting to report?