zoubohao / DenoisingDiffusionProbabilityModel-ddpm-

This may be the simplest implementation of DDPM. You can run Main.py directly to train the UNet on the CIFAR-10 dataset and see the amazing denoising process.
MIT License
1.48k stars, 156 forks
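For readers new to DDPM, the "process of denoising" mentioned above can be illustrated with a minimal NumPy sketch of the forward noising formula and one ancestral reverse step from the DDPM paper. It assumes a linear beta schedule from 1e-4 to 0.02 over T steps and the variance choice sigma_t^2 = beta_t. The names `linear_schedule`, `q_sample`, `p_sample_step`, `sample` and `eps_model` are hypothetical and are not taken from Main.py; in the real code the noise predictor would be the trained UNet.

```python
import numpy as np

def linear_schedule(T=1000, beta_1=1e-4, beta_T=0.02):
    # Assumed linear schedule; alpha_bars[t] is the cumulative product of (1 - beta).
    betas = np.linspace(beta_1, beta_T, T)
    alphas = 1.0 - betas
    alpha_bars = np.cumprod(alphas)
    return betas, alphas, alpha_bars

def q_sample(x0, t, alpha_bars, noise):
    # Forward process: x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * eps
    return np.sqrt(alpha_bars[t]) * x0 + np.sqrt(1.0 - alpha_bars[t]) * noise

def p_sample_step(x_t, t, eps_pred, betas, alphas, alpha_bars, rng):
    # Reverse-process mean computed from the predicted noise eps_pred.
    mean = (x_t - betas[t] / np.sqrt(1.0 - alpha_bars[t]) * eps_pred) / np.sqrt(alphas[t])
    if t == 0:
        return mean  # no noise is added on the final step
    return mean + np.sqrt(betas[t]) * rng.standard_normal(x_t.shape)

def sample(eps_model, shape, T, rng):
    # eps_model(x, t) is a hypothetical stand-in for the trained UNet.
    betas, alphas, alpha_bars = linear_schedule(T)
    x = rng.standard_normal(shape)  # start from pure Gaussian noise x_T
    for t in reversed(range(T)):
        x = p_sample_step(x, t, eps_model(x, t), betas, alphas, alpha_bars, rng)
    return x
```

If the noise predictor outputs garbage (for example, because its weights were never loaded), every reverse step subtracts the wrong noise and the result stays noise-like. This is one reason a checkpoint-loading problem shows up as noisy samples rather than as an error.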

Visualization problem when training with multiple GPUs #33

Open · llbbcc opened this issue 8 months ago

llbbcc commented 8 months ago

Thanks for your code! When I train the model on multiple GPUs by wrapping it with net_model = torch.nn.DataParallel(net_model), the images obtained from sampling are pure noise. I load the model with model = torch.nn.DataParallel(model) followed by model.load_state_dict(ckpt). Has anyone encountered this problem?
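One common source of confusion with `torch.nn.DataParallel` is that it stores the wrapped network as `.module`, so every key in its `state_dict()` gains a `module.` prefix. The sketch below shows helpers for converting between the two key styles, plus a key comparison you can run before loading. It is not a confirmed diagnosis of this issue: with the default `strict=True`, `load_state_dict` raises on mismatched keys, so if the keys truly mismatched you would likely see an error rather than noise. The helper names are hypothetical, and the commented torch usage assumes a plain (unwrapped) model and a checkpoint path `ckpt.pt` for illustration.

```python
from collections import OrderedDict

PREFIX = "module."  # prefix nn.DataParallel adds to state_dict keys

def strip_module_prefix(state_dict):
    # Return a copy with a leading "module." removed from each key (if present).
    return OrderedDict(
        (k[len(PREFIX):] if k.startswith(PREFIX) else k, v)
        for k, v in state_dict.items()
    )

def add_module_prefix(state_dict):
    # Return a copy with "module." prepended to any key that lacks it.
    return OrderedDict(
        (k if k.startswith(PREFIX) else PREFIX + k, v)
        for k, v in state_dict.items()
    )

def key_mismatch(model_keys, ckpt_keys):
    # (keys the model expects but the checkpoint lacks, keys only in the checkpoint)
    model_keys, ckpt_keys = set(model_keys), set(ckpt_keys)
    return sorted(model_keys - ckpt_keys), sorted(ckpt_keys - model_keys)

# Hypothetical usage with PyTorch, loading into an unwrapped model for sampling:
#   ckpt = torch.load("ckpt.pt", map_location="cpu")
#   print(key_mismatch(model.state_dict().keys(), ckpt.keys()))
#   model.load_state_dict(strip_module_prefix(ckpt))  # strict=True by default
#   model.eval()
```

A portable convention is to save `net_model.module.state_dict()` when training under DataParallel, so the checkpoint loads into a plain model for sampling without any key rewriting.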