NVlabs / edm

Elucidating the Design Space of Diffusion-Based Generative Models (EDM)

ERROR when using multi-gpu training #18

Closed tsWen0309 closed 1 year ago

tsWen0309 commented 1 year ago

Hi, thanks for sharing your work. I can't train your model on my GPUs (two 4090s). Is there a solution?

[image: screenshot of the error]

LYMDLUT commented 1 year ago

I ran into the same bug.

christopher-beckham commented 1 year ago

I had a similar issue. Make sure you're using PyTorch 1.12, as specified in the environment.yml file.

tsWen0309 commented 1 year ago

I solved the problem by using PyTorch 1.13.1 with CUDA 11.6 and cuDNN 8.3.2.0.

RachelTeamo commented 10 months ago

I tried this code. With --nproc_per_node=1 it works fine, but with --nproc_per_node>1 (e.g., --nproc_per_node=2) it fails and reports the same error as in the picture. Is there a solution for this? My torch version is 2.1 because I'm using H800 GPUs.

marcoamonteiro commented 9 months ago

@RachelTeamo to run on torch 2.1, replace line 89 in training_loop.py:

ddp = torch.nn.parallel.DistributedDataParallel(net, device_ids=[device], broadcast_buffers=False)

with:

ddp = torch.nn.parallel.DistributedDataParallel(net, device_ids=[dist.get_rank()], broadcast_buffers=False)

I couldn't find anything in the PyTorch docs warning about a change in the DDP API, but this solved the issue for me.
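
One caveat, offered as a sketch rather than a confirmed fix: dist.get_rank() returns the global rank, which matches the GPU index only when all processes run on a single node. Launchers such as torchrun export a LOCAL_RANK environment variable for each worker, so a variant that should also hold up across several nodes could read that instead. The helper name local_device_index below is hypothetical and not part of the EDM code; the DDP call in the comment is the same one quoted above with only device_ids changed.

```python
import os


def local_device_index(environ=None, default=0):
    """Return the GPU index this process should use on its own node.

    Reads LOCAL_RANK, which torchrun sets for every worker it launches.
    Falls back to `default` when the variable is absent, for example in
    a plain single-process run.
    """
    environ = os.environ if environ is None else environ
    return int(environ.get("LOCAL_RANK", default))


# Hypothetical use at the DDP call site (needs torch and an initialized
# process group, so it is left as a comment here):
#   ddp = torch.nn.parallel.DistributedDataParallel(
#       net, device_ids=[local_device_index()], broadcast_buffers=False)

# Simulated worker on a second node: global rank 9, but local rank 1.
fake_env = {"RANK": "9", "LOCAL_RANK": "1", "WORLD_SIZE": "16"}
print(local_device_index(fake_env))  # prints 1
```

On a single machine with two GPUs the global and local ranks coincide, so this gives the same device index as dist.get_rank().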

RachelTeamo commented 8 months ago

Thanks for the suggestion. I replaced the code as you suggested, but the issue still exists.

Shiien commented 8 months ago

I solved this by commenting out lines 79-84:

# if dist.get_rank() == 0:
#     with torch.no_grad():
#         images = torch.zeros([batch_gpu, net.img_channels, net.img_resolution, net.img_resolution], device=device)
#         sigma = torch.ones([batch_gpu], device=device)
#         labels = torch.zeros([batch_gpu, net.label_dim], device=device)
#         misc.print_module_summary(net, [images, sigma, labels], max_nesting=2)

And set the DDP wrapper to:

ddp = torch.nn.parallel.DistributedDataParallel(net, device_ids=[dist.get_rank()], broadcast_buffers=False)