Error while fine-tuning with multiple GPUs. When running with a single GPU instead, it runs out of memory. Any clue why this error is being thrown?
Our repo does not support multi-GPU training correctly at the moment; we will look to fix it soon.
If your training runs out of memory on a single GPU, you can try reducing the batch size. On Colab (K80 GPU) it works with a batch size of 8.
@andrelmfarias Thanks for responding. However, this is happening during the fine-tuning process when I run fit_reader. Do I need to update the batch size for that as well?
Yes, actually the batch_size only affects the reader training / fine-tuning. There's no real training with fit_retriever.
To change the batch_size, you can do:
cdqa_pipeline.reader.train_batch_size = 8
# proceed with reader fine-tuning
cdqa_pipeline.fit_reader('path-to-custom-squad-like-dataset.json')
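(If a batch size of 8 still runs out of memory, smaller values such as 4 or 2 can be tried; smaller batches use less GPU memory at the cost of slower training.)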
That worked. Will look forward to multi-GPU support :) Thanks.
Hey @andrelmfarias, the same error is seen while running cdqa_pipeline.predict as well. Just to note: the process is running on a cloud GPU machine and the packages are in a Python virtualenv. If I don't use the virtualenv, cdqa_pipeline.predict works fine. Any idea?
The resolution is to set torch.cuda.set_device(0) (not best practice, but workable since there is no multi-GPU support).
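A minimal sketch of that workaround, assuming the same cdqa_pipeline object as in the earlier comments (the example question is hypothetical, and the exact predict signature may vary across cdQA versions):

import torch

# Pin all CUDA work to the first GPU; cdQA has no multi-GPU support yet,
# so forcing a single device avoids device-mismatch errors on multi-GPU hosts.
torch.cuda.set_device(0)

# The call that previously failed; the query string is a hypothetical example.
prediction = cdqa_pipeline.predict(query='Since when does the pipeline support GPU inference?')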
If I would like to run on CPU, how should I do it? I trained on a GPU on AWS, and running cdqa_pipeline.predict gave the same error, only with cuda:3 replaced by cpu.
I did something like this to check for a GPU and fall back to the CPU if one is not available:
import torch

if torch.cuda.is_available():
    print('**** Great, CUDA is Available ****')
    reader.device = torch.device('cuda')
    reader.model.to('cuda')
    # Currently GPU count = 1, as multi-GPU is NOT supported
    # reader.n_gpu = torch.cuda.device_count()  # enable this when multi-GPU support arrives
    reader.n_gpu = 1
else:
    print('*** Training with CPU... might take a while ****')
    reader.device = torch.device('cpu')
    reader.model.to('cpu')
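A more compact sketch of the same fallback, assuming reader is the cdQA reader object from the snippet above (treating n_gpu = 0 as the CPU case is an assumption and may differ between versions):

import torch

# Choose the device in one expression, then keep the reader's state consistent with it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
reader.device = device
reader.model.to(device)
# At most one GPU, since multi-GPU is not supported; 0 signals CPU-only (assumption).
reader.n_gpu = 1 if device.type == 'cuda' else 0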
Thank you! I found the answer at https://github.com/cdqa-suite/cdQA/issues/238