facebookresearch / silk

SiLK (Simple Learned Keypoint) is a self-supervised deep learning keypoint model.
GNU General Public License v3.0

multi-GPU training #44

Closed · egbertYeah closed this 1 year ago

egbertYeah commented 1 year ago

I changed this config to support multi-GPU training, but I get the error below about NCCL. How can I solve this problem? Thanks for your reply.

# need two GPUs, PyTorch use GPU 1, and Jax uses GPU 0.
  gpus:
    - 1
    - 2
    - 3
RuntimeError: NCCL error in: ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:1169, invalid usage, NCCL version 21.0.3
ncclInvalidUsage: This usually reflects invalid usage of NCCL library (such as too many async ops, too many collectives at once, mixing streams in a group, etc).
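
For context, here is a hedged sketch of why NCCL shows up at all, assuming the gpus list ends up in a PyTorch Lightning Trainer (the exact wiring inside SiLK may differ):

import pytorch_lightning as pl

# Hypothetical wiring, using the pre-2.0 Lightning `gpus` argument: with more
# than one PyTorch device, Lightning runs a distributed (DDP) strategy, and
# DDP synchronises gradients through NCCL, which is where the error above is
# raised.
trainer = pl.Trainer(gpus=[1, 2, 3])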
gleize commented 1 year ago

Hi @egbertYeah,

Have you tried running it with the default setup (two GPUs only)? It should get you a trained model in a few hours (less than a day).

Are you sure the error is related to adding more GPUs ?

I wouldn't advise changing that setting since it was tuned to handle the model computation on GPU 1, and the expensive loss on GPU 0. Any other setup might require a fair amount of change in the code. Also, if you have 4 GPUs in total, it might be better to give two to Jax, and two to PyTorch. The config above would then look like:

  gpus:
    - 2
    - 3

and Jax should automatically use the first two GPUs.
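
A quick way to sanity-check which devices each framework actually sees before launching training (just a sketch, not part of the SiLK code base):

import jax
import torch

# JAX lists every CUDA device visible to the process; by default it allocates
# on the first ones. PyTorch reports its own device count separately.
print(jax.devices())
print(torch.cuda.device_count())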

egbertYeah commented 1 year ago

> Hi @egbertYeah,
>
> Have you tried running it with the default setup (two GPUs only)? It should get you a trained model in a few hours (less than a day).

Yes, I have trained on the COCO dataset with the default setup and it works correctly.

> Are you sure the error is related to adding more GPUs?
>
> I wouldn't advise changing that setting since it was tuned to handle the model computation on GPU 1, and the expensive loss on GPU 0. Any other setup might require a fair amount of change in the code. Also, if you have 4 GPUs in total, it might be better to give two to Jax, and two to PyTorch. The config above would then look like:
>
>   gpus:
>     - 2
>     - 3
>
> and Jax should automatically use the first two GPUs.

Thanks for your advice. Since the dataset is large, I want to use multi-GPU training. I will try your suggestion.

egbertYeah commented 1 year ago

> I wouldn't advise changing that setting since it was tuned to handle the model computation on GPU 1, and the expensive loss on GPU 0. Any other setup might require a fair amount of change in the code. Also, if you have 4 GPUs in total, it might be better to give two to Jax, and two to PyTorch. The config above would then look like:
>
>   gpus:
>     - 2
>     - 3
>
> and Jax should automatically use the first two GPUs.

I think there may be an error with "Jax should automatically use the first two GPUs". I changed the config like this:

  gpus:
    - 2
    - 3

but I get this error:

jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 7383586816 bytes.

I see Jax uses a fixed GPU with the config below. How can I change Jax to use the first and second GPUs?

  loss:
    _target_: silk.losses.info_nce.Loss
    block_size: 5400
    jax_device: "cuda:0"
    temperature: 0.1

gleize commented 1 year ago

> I see Jax uses a fixed GPU with the config below. How can I change Jax to use the first and second GPUs?
>
>   loss:
>     _target_: silk.losses.info_nce.Loss
>     block_size: 5400
>     jax_device: "cuda:0"
>     temperature: 0.1

Oh yes, I forgot about that part. Looking at the code again, the jax_device argument, given to the jax2torch wrapper layer, can only handle one device, so there is currently no simple way to run the loss on multiple devices. However, the error you got from Jax indicates that your GPU 0 doesn't have enough memory to compute the loss. You can solve that by lowering the block_size argument. That argument corresponds to the size of the blocks processed in memory when computing the large similarity matrix. A smaller block size will use less memory, but will take more time to compute. So you can try to find the highest value of that argument that doesn't trigger memory exhaustion.
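
To illustrate the trade-off, here is a minimal sketch of block-wise processing (not the actual SiLK loss; the function and argument names are made up). Only one block of rows of the similarity matrix is materialised at a time, so block_size bounds the peak memory:

import jax.numpy as jnp
from jax.scipy.special import logsumexp

def blockwise_row_logsumexp(desc_a, desc_b, block_size, temperature=0.1):
    """Row-wise logsumexp of the (N, M) similarity matrix, one block at a time.

    Only a (block_size, M) slice exists in memory at any point, so a smaller
    block_size lowers peak memory at the cost of more iterations.
    """
    rows = []
    for start in range(0, desc_a.shape[0], block_size):
        sim = desc_a[start:start + block_size] @ desc_b.T / temperature
        rows.append(logsumexp(sim, axis=1))
    return jnp.concatenate(rows)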

Juizai commented 10 months ago

Hi, I have 3 GPUs in total. How can I change Jax to use my second GPU, since my first GPU is not available? Thanks for your reply.

gleize commented 10 months ago

Hi @Juizai,

Setting CUDA_VISIBLE_DEVICES and launching the training like this should work.

CUDA_VISIBLE_DEVICES=1,2 python ...
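
One detail worth spelling out (standard CUDA behaviour, not specific to SiLK): with CUDA_VISIBLE_DEVICES=1,2 the process only sees those two GPUs and renumbers them, so physical GPU 1 becomes cuda:0 and physical GPU 2 becomes cuda:1. The default config (jax_device: "cuda:0", PyTorch on device 1) then runs Jax on your second physical GPU and PyTorch on your third. A quick check:

import torch

# Run with CUDA_VISIBLE_DEVICES=1,2 set in the environment.
print(torch.cuda.device_count())      # 2 visible devices
print(torch.cuda.get_device_name(0))  # this is physical GPU 1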
Juizai commented 10 months ago

Hi @gleize, it helps! Thank you very much!