Maram-Helmy opened 6 months ago

How can I run this code on multiple GPUs? Thank you for your work!

Hi! While the current implementation only uses 1 GPU by default, you can look into PyTorch Distributed Data Parallel (https://pytorch.org/docs/stable/notes/ddp.html) for multi-GPU support. You would also need to add some logic for checking the local_rank of the device and doing some operations in a non-parallel fashion (e.g., wandb logging or model checkpointing).

If folks have a strong interest in multi-GPU support, please feel free to comment / like this thread, and I can start a branch for this.
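A minimal sketch of what that could look like, assuming a `torchrun` launch (which sets `LOCAL_RANK`/`WORLD_SIZE`); the toy `Linear` model, the loss, and the checkpoint filename are placeholders, not part of this repo — only the rank-0 gating for logging/checkpointing is the point:

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def main():
    # torchrun sets these; the defaults allow a single-process CPU dry run
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))

    use_cuda = torch.cuda.is_available()
    dist.init_process_group("nccl" if use_cuda else "gloo",
                            rank=rank, world_size=world_size)
    device = torch.device(f"cuda:{rank}" if use_cuda else "cpu")

    # placeholder model/optimizer — substitute the repo's own
    model = torch.nn.Linear(10, 1).to(device)
    ddp_model = DDP(model, device_ids=[rank] if use_cuda else None)
    opt = torch.optim.SGD(ddp_model.parameters(), lr=0.01)

    for step in range(3):
        opt.zero_grad()
        loss = ddp_model(torch.randn(8, 10, device=device)).sum()
        loss.backward()  # DDP averages gradients across ranks here
        opt.step()
        if rank == 0:  # non-parallel ops (logging/wandb) on rank 0 only
            print(f"step {step} loss {loss.item():.4f}")

    if rank == 0:  # likewise, checkpoint from a single rank
        torch.save(model.state_dict(), "ckpt.pt")
    dist.destroy_process_group()

main()
```

Launch with e.g. `torchrun --nproc_per_node=2 train.py`; each GPU gets its own process, and the `rank == 0` guards keep logging and checkpointing from running once per GPU.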