Hi @egbertYeah,
Have you tried running it with the default setup (two GPUs only) ? It should get you a trained model in a few hours (less than a day).
Are you sure the error is related to adding more GPUs ?
I wouldn't advise changing that setting since it was tuned to handle the model computation on GPU 1, and the expensive loss on GPU 0. Any other setup might require a fair amount of change in the code. Also, if you have 4 GPUs in total, it might be better to give two to Jax, and two to PyTorch. The config above would then look like :
gpus:
- 2
- 3
and Jax should automatically use the first two GPUs.
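If you want to double-check which devices JAX actually sees before launching a full run, a quick sanity check (assuming jax with a CUDA-enabled jaxlib is installed) is:

import jax

# List the accelerators visible to JAX. Without a CUDA_VISIBLE_DEVICES restriction,
# ids 0 and 1 correspond to the "first two" GPUs mentioned above.
print(jax.devices())          # e.g. [cuda(id=0), cuda(id=1), ...]
print(jax.default_backend())  # prints "gpu" when the CUDA backend is active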
Hi @egbertYeah,
Have you tried running it with the default setup (two GPUs only) ? It should get you a trained model in a few hours (less than a day).
Yes, I have trained on the COCO dataset with the default setup and it works correctly.
Are you sure the error is related to adding more GPUs ?
I wouldn't advise changing that setting since it was tuned to handle the model computation on GPU 1, and the expensive loss on GPU 0. Any other setup might require a fair amount of change in the code. Also, if you have 4 GPUs in total, it might be better to give two to Jax, and two to PyTorch. The config above would then look like :
gpus:
- 2
- 3
and Jax should automatically use the first two GPUs.
Thanks for your advice. Since the dataset is large, I want to use multi-GPU training. I will try your suggestion.
I wouldn't advise changing that setting since it was tuned to handle the model computation on GPU 1, and the expensive loss on GPU 0. Any other setup might require a fair amount of change in the code. Also, if you have 4 GPUs in total, it might be better to give two to Jax, and two to PyTorch. The config above would then look like :
gpus:
- 2
- 3
and Jax should automatically use the first two GPUs.
I think there may be an error with "Jax should automatically use the first two GPUs". I changed the config like this:
gpus:
- 2
- 3
but I get this error:
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 7383586816 bytes.
I see that JAX uses a fixed GPU with the config below. How can I change JAX to use the first and second GPUs?
loss:
  _target_: silk.losses.info_nce.Loss
  block_size: 5400
  jax_device: "cuda:0"
  temperature: 0.1
Oh yes, I forgot about that part. Looking at the code again, the jax_device argument, given to the jax2torch wrapper layer, can only handle one device. So there is currently no simple way to run the loss on multiple devices.
However, the error that you got from jax indicates your GPU 0 doesn't have enough memory to compute the loss. You can solve that by lowering the block_size argument. That argument corresponds to the size of the blocks processed in memory when computing the large similarity matrix. A smaller block size will use less memory, but will take more time to compute. So you can try to find the highest value of that argument that doesn't trigger memory exhaustion.
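For example, keeping the rest of the loss config unchanged and only lowering block_size would look like this (the value 2700 below is just an illustration, not a tuned number; pick the largest value that fits in your GPU memory):

loss:
  _target_: silk.losses.info_nce.Loss
  block_size: 2700
  jax_device: "cuda:0"
  temperature: 0.1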
Hi, I have 3 GPUs in total. How can I change JAX to use my second GPU, since my first GPU is not available? Thanks for your reply.
Hi @Juizai,
Setting CUDA_VISIBLE_DEVICES and calling the training like this should work:
CUDA_VISIBLE_DEVICES=1,2 python ...
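Note that CUDA_VISIBLE_DEVICES renumbers the visible GPUs inside the process, so your physical GPUs 1 and 2 become cuda:0 and cuda:1, and the jax_device: "cuda:0" setting in the loss config should then land on physical GPU 1. To confirm what JAX sees under that restriction (again assuming jax with a CUDA-enabled jaxlib is installed), you can run:

CUDA_VISIBLE_DEVICES=1,2 python -c "import jax; print(jax.devices())"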
Hi @gleize, It helps! Thank you very much!
I changed this config to support multi-GPU training, but I get the error below about NCCL. How can I solve this problem? Thanks for your reply.