Open · Smu-Tan opened this issue 2 years ago
Thanks for your interest in my work!
As a sanity check, can you try training `bert-base-multilingual-uncased` with grad cache disabled and compare its memory usage against some of the other models? I think we first need to make sure that the regression does not come from mBERT itself.
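In case it is useful, here is a minimal sketch of such a comparison in plain `transformers`/`torch` (not the GC-DPR training script; the batch size, sequence length, and dummy loss are just placeholders):

```python
import torch
from transformers import AutoModel, AutoTokenizer

def peak_memory_mb(model_name, batch_size=32, seq_len=256):
    """One forward/backward pass, then report peak GPU memory in MB."""
    tok = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name).cuda().train()
    batch = tok(["some passage text"] * batch_size,
                padding="max_length", max_length=seq_len,
                truncation=True, return_tensors="pt").to("cuda")
    torch.cuda.reset_peak_memory_stats()
    out = model(**batch).last_hidden_state
    out.mean().backward()          # dummy loss, just to materialize gradients
    return torch.cuda.max_memory_allocated() / 2**20

for name in ["bert-base-uncased", "bert-base-multilingual-uncased"]:
    print(name, f"{peak_memory_mb(name):.0f} MB")
```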
Hi, indeed, I tried without grad_cache and it requires a smaller batch_size, so grad_cache does seem to work for mBERT as well.
Another question: does `train_dense_retriever.py` support multi-GPU training as well? Since mBERT requires more memory, I think using multiple GPUs might help. I tried `python -m torch.distributed.launch --nproc_per_node=4 train_dense_retriever.py`, both with and without grad_cache, and I got `RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn`.
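For reference, that RuntimeError is PyTorch's generic message for calling `.backward()` on a tensor with no autograd history; a minimal repro (nothing GC-DPR-specific) looks like:

```python
import torch

x = torch.ones(4)        # requires_grad=False by default
loss = (x * 2).sum()     # loss has no grad_fn because no input requires grad
loss.backward()          # -> RuntimeError: element 0 of tensors does not require grad ...
```

With gradient caching, one way to end up there is if the loss is computed only from representations produced in the no-grad caching pass without the grad-enabled re-forward, so the error may be a hint that the distributed code path skips part of that second pass.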
It was working the last time we checked. Maybe take a look at https://github.com/luyug/GC-DPR/issues/2, where we had a discussion on parallelism.
If that does not help, please paste the entire error log here and I will take a look. Also, make sure this is not an issue unique to mBERT.
Alternatively, we have been developing another dense retrieval toolkit called Tevatron, which has native multi-GPU/TPU grad cache support built in. Maybe take a look at it to see if it better aligns with your needs.
Hi,
I found a weird thing: when using multilingual BERT (e.g. `bert-base-multilingual-uncased`), it seems like grad_cache doesn't work. I know it sounds strange, since switching between BERT models shouldn't affect it, but I tried normal BERT, German BERT, and mBERT, and only the last one needs a very small batch_size (like 4) to run successfully. Other models such as German BERT run with batch_size=128 without problems. Do you happen to know the reason for this? Btw, great paper and code, extremely helpful! Thanks in advance!
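One factor that may partly explain the gap, independent of grad_cache: mBERT's vocabulary (~105k tokens) is far larger than English or German BERT's (~30k), so its weights, gradients, and Adam states already take noticeably more GPU memory before any activations are allocated. A rough check (the numbers in the comment are approximate):

```python
from transformers import AutoModel

for name in ["bert-base-uncased", "bert-base-multilingual-uncased"]:
    model = AutoModel.from_pretrained(name)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"{name}: {n_params / 1e6:.0f}M parameters")

# Roughly 110M vs. 167M; the difference is almost entirely the token embedding table.
```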