ShannonAI / fast-knn-nmt

How do you bypass memory limitations on the WMT dataset? #7

Closed · SirRob1997 closed this 2 years ago

SirRob1997 commented 2 years ago

This isn't really related to your code, but: since the WMT datastore for vanilla kNN-MT is pretty big (I think in the ballpark of 900M keys), the corresponding FAISS index is quite large, and loading it onto a single GPU together with the model is a serious problem. Naturally, you'd want to shard the index over multiple GPUs, but FAISS has trouble searching a sharded index with GPU tensors; see https://github.com/facebookresearch/faiss/issues/2074 or https://github.com/facebookresearch/faiss/issues/1997. Loading the index onto a separate GPU, transferring tensors there for searching, and moving the results back also runs into memory deallocation issues inside FAISS.
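
For context, a typical sharding attempt looks roughly like the sketch below (the index path is made up); per the linked issues, it's the search over the resulting sharded index with GPU tensors that breaks down:

```python
# Rough sketch: shard a large CPU-built FAISS index across all visible GPUs.
# "datastore.index" is a hypothetical file name, not from fast-knn-nmt.
import faiss

cpu_index = faiss.read_index("datastore.index")

co = faiss.GpuMultipleClonerOptions()
co.shard = True  # split the index across GPUs instead of replicating it
sharded_index = faiss.index_cpu_to_all_gpus(cpu_index, co=co)

# Searching with float32 NumPy queries works; passing GPU torch tensors to
# this sharded index is what triggers the problems in the issues above.
# distances, ids = sharded_index.search(queries, 64)
```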

I wonder how you ran the experiments on that, or does your codebase have some neat trick that I haven't caught yet?

From my understanding you don't shard either, according to this line, and you also transfer the tensors you search on to CPU here (by the way, you can pass GPU tensors if you import faiss.contrib.torch_utils, which changes the function signatures to accept GPU tensors instead of NumPy arrays). Does your index for WMT with vanilla kNN-MT simply fit in GPU memory?
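
For reference, the faiss.contrib.torch_utils route looks like this (a minimal sketch with made-up dimensions; the import monkey-patches the index methods):

```python
# Minimal sketch: after importing faiss.contrib.torch_utils, index methods
# accept torch tensors (CUDA tensors for GPU indexes) instead of NumPy arrays.
import torch
import faiss
import faiss.contrib.torch_utils  # patches add()/search() signatures

d = 1024                             # hypothetical key dimension
index = faiss.IndexFlatIP(d)
index.add(torch.randn(100_000, d))   # torch tensors are accepted directly

res = faiss.StandardGpuResources()
gpu_index = faiss.index_cpu_to_gpu(res, 0, index)

queries = torch.randn(32, d, device="cuda:0")
distances, ids = gpu_index.search(queries, 8)  # no .cpu() round-trip needed
```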

For the domain datasets, the datastores are relatively small, so this isn't a big issue and I got them set up just fine!

YuxianMeng commented 2 years ago

@SirRob1997 Hi, as you may have noticed, we run the kNN search on the source side before running generate.py, which makes tuning hyperparameters much faster. We can therefore load the index and find the kNNs for each token in the vocabulary one by one, instead of loading all of them onto the GPU. The relevant code is [here](https://github.com/ShannonAI/fast-knn-nmt/blob/main/fast_knn_nmt/knn/find_knn_neighbors.py).
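
In pseudocode, the idea is roughly the following (the helper and file layout here are illustrative, not our actual API; see find_knn_neighbors.py for the real implementation):

```python
# Illustrative sketch of the per-token strategy: one small faiss index per
# vocabulary token, loaded and searched one at a time, so the full ~900M-key
# datastore never has to sit in GPU memory at once. Paths are hypothetical.
import faiss
import numpy as np

def precompute_source_knn(vocab_size: int, k: int = 64) -> None:
    for token_id in range(vocab_size):
        index = faiss.read_index(f"indexes/token_{token_id}.index")
        queries = np.load(f"queries/token_{token_id}.npy")  # float32, (n, d)
        distances, ids = index.search(queries, k)
        np.save(f"neighbors/token_{token_id}.npy", ids)     # reused by generate.py
        del index                                           # free before the next token
```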

YuxianMeng commented 2 years ago

@SirRob1997 BTW, since we build a faiss index for each token in the vocabulary, I think it's much easier to load those onto multiple GPUs than to shard one big faiss index across multiple GPUs.
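
Loading them onto multiple GPUs could then be as simple as round-robin placement (again just a sketch, not our actual code):

```python
# Sketch: assign each small per-token index to a GPU round-robin; since every
# index is small, none of them needs to be sharded. Paths are hypothetical.
import faiss

num_gpus = faiss.get_num_gpus()
resources = [faiss.StandardGpuResources() for _ in range(num_gpus)]

def load_token_index_on_gpu(token_id: int):
    device = token_id % num_gpus  # round-robin GPU assignment
    cpu_index = faiss.read_index(f"indexes/token_{token_id}.index")
    return faiss.index_cpu_to_gpu(resources[device], device, cpu_index)
```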

SirRob1997 commented 2 years ago

Cool, thanks for the input!

In case anyone finds this issue: I got it running by following the pointers outlined in the FAISS issues linked above.