UT-Austin-RPL / MUTEX


Questions about Multi-GPU training #7

Open lipeiyan-bd opened 1 week ago

lipeiyan-bd commented 1 week ago

I'm sorry to trouble you again, but could you explain how to implement multi-GPU training with this code? I noticed that the training speed with 8 GPUs seems to be the same as when using just one GPU. Also, is there an implementation available that utilizes PyTorch's DistributedDataParallel?

ShahRutav commented 1 week ago

No worries, happy to help!

> noticed that the training speed with 8 GPUs seems to be the same as when using just one GPU

Which GPU are you using? The speedup from using multiple GPUs may not be noticeable if each individual GPU is already powerful enough. I have never tried it with much newer GPUs such as the Hopper or Blackwell series.
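For reference, below is a minimal sketch of single-process data parallelism in PyTorch. Whether MUTEX splits batches this way is an assumption on my part, and the model and batch are placeholders; the point is that this pattern often shows little speedup when a single GPU already handles the batch comfortably.

```python
# Hedged illustration: single-process data parallelism with torch.nn.DataParallel.
# The linear "model" and the batch below are placeholders, not MUTEX code.
import torch
import torch.nn as nn

model = nn.Linear(128, 10)  # placeholder for the actual policy network
if torch.cuda.device_count() > 1:
    # DataParallel replicates the model on each visible GPU and splits every
    # batch across them; with a fast single GPU the overhead can cancel the gain.
    model = nn.DataParallel(model)
model = model.cuda()

x = torch.randn(256, 128).cuda()  # the batch is scattered across the visible GPUs
out = model(x)
print(out.shape, "GPUs visible:", torch.cuda.device_count())
```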

> Also, is there an implementation available that utilizes PyTorch's DistributedDataParallel?

I have not tried DistributedDataParallel in the MUTEX codebase. It might give some improvement.
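In case it helps, here is a minimal DistributedDataParallel sketch. It is not the MUTEX training loop; the model, dataset, and hyperparameters are placeholders, and it assumes a launch via torchrun (e.g. `torchrun --nproc_per_node=8 train_ddp.py`).

```python
# Hedged sketch of PyTorch DDP training; everything below is illustrative,
# not taken from the MUTEX codebase.
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset


def main():
    # torchrun sets LOCAL_RANK, RANK, and WORLD_SIZE for each spawned process.
    local_rank = int(os.environ["LOCAL_RANK"])
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(local_rank)

    # Placeholder model and dataset; swap in the MUTEX policy and demo data.
    model = nn.Linear(128, 10).cuda(local_rank)
    model = DDP(model, device_ids=[local_rank])

    dataset = TensorDataset(torch.randn(1024, 128), torch.randint(0, 10, (1024,)))
    sampler = DistributedSampler(dataset)  # gives each process a disjoint shard
    loader = DataLoader(dataset, batch_size=32, sampler=sampler)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    loss_fn = nn.CrossEntropyLoss()

    for epoch in range(2):
        sampler.set_epoch(epoch)  # reshuffle shards each epoch
        for x, y in loader:
            x, y = x.cuda(local_rank), y.cuda(local_rank)
            optimizer.zero_grad()
            loss = loss_fn(model(x), y)
            loss.backward()  # gradients are all-reduced across processes here
            optimizer.step()

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```

The main changes relative to single-GPU training are the process-group setup, wrapping the model in DDP, and using a DistributedSampler so each process sees a different shard of the data.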