gfxdisp / mdf

Multi-scale discriminator feature-wise loss function
BSD 3-Clause "New" or "Revised" License

Load the weights file for multi-gpu training #4

Closed: rushi-the-neural-arch closed this issue 2 years ago

rushi-the-neural-arch commented 2 years ago

Hi! I liked your approach and wanted to try it for the SISR task. However, I am having trouble loading the weights file when I train on multiple GPUs. I suspect that calling torch.load() directly is causing the problem here: https://github.com/gfxdisp/mdf/blob/85c9ad36734f50e30d162e327194e8dbfa52f0cf/mdfloss.py#L10
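If the weights file stores whole pickled discriminator modules (an assumption; the thread does not show what the file contains), one common workaround is to load everything onto the CPU with `map_location` and only then move each module to a device. In this sketch, `make_disc()` and the temporary file are hypothetical stand-ins for the real discriminators and weights file.

```python
import os
import tempfile

import torch
import torch.nn as nn


def make_disc():
    # Hypothetical stand-in for one scale's discriminator.
    return nn.Sequential(
        nn.Conv2d(3, 8, 3, padding=1),
        nn.LeakyReLU(0.2),
        nn.Conv2d(8, 1, 3, padding=1),
    )


def load_discriminators(path):
    # map_location="cpu" remaps tensors that were saved on any GPU to the CPU,
    # so loading does not depend on which devices exist at load time.
    # weights_only=False lets recent PyTorch unpickle whole nn.Module objects;
    # only use it on files you trust.
    Ds = torch.load(path, map_location="cpu", weights_only=False)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Move each discriminator in the list to the target device.
    return [D.to(device) for D in Ds]


if __name__ == "__main__":
    path = os.path.join(tempfile.mkdtemp(), "Ds_demo.pth")  # hypothetical file
    torch.save([make_disc() for _ in range(3)], path)
    Ds = load_discriminators(path)
    print(len(Ds), next(Ds[0].parameters()).device)
```

Loading to the CPU first keeps the device choice in your hands, which matters before wrapping modules for multi-GPU use.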

Could you please provide it as a model instance, so that we can load it with model.load_state_dict(torch.load("..."))? That would help anyone who wants to train on multiple GPUs. Also, could you please share the opt argparse arguments referenced in the SinGAN/models.py file? https://github.com/gfxdisp/mdf/blob/85c9ad36734f50e30d162e327194e8dbfa52f0cf/SinGAN/models.py#L19
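The request above matches the usual PyTorch pattern: save only a `state_dict`, build the model, then call `load_state_dict`. The sketch assumes a hypothetical `TinyDisc` architecture, because the real discriminator class and its constructor arguments are not shown in this thread. It also strips the `module.` prefix that `nn.DataParallel` adds to parameter names, a frequent cause of `load_state_dict` key mismatches when switching between single- and multi-GPU runs.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyDisc(nn.Module):
    """Hypothetical stand-in for one scale's discriminator."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 1, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)


def strip_dataparallel_prefix(state_dict):
    # nn.DataParallel keeps the wrapped model under `.module`, so its keys look
    # like "module.conv.weight"; drop that prefix to match a plain model.
    prefix = "module."
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }


def load_into_instance(path):
    # Build the architecture first, then copy the saved tensors into it.
    model = TinyDisc()
    state_dict = torch.load(path, map_location="cpu")
    model.load_state_dict(strip_dataparallel_prefix(state_dict))
    return model


if __name__ == "__main__":
    path = os.path.join(tempfile.mkdtemp(), "disc_state.pth")  # hypothetical file
    wrapped = nn.DataParallel(TinyDisc())
    torch.save(wrapped.state_dict(), path)  # keys carry the "module." prefix
    model = load_into_instance(path)
    print(sorted(model.state_dict().keys()))
```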

Please let me know if you have a workaround for loading the weights file on multiple GPUs. It would also help if you could release the training code.

Thanks!

Karan-Choudhary commented 2 years ago

I think we should set the opt argparse arguments according to the algorithm selected with the application= option in the main configuration.
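A minimal sketch of building such an `opt` object with argparse follows. Every flag name and default here is an assumption borrowed from the original SinGAN configuration, not something this thread confirms; check them against what SinGAN/models.py actually reads, and adjust the values for the selected application.

```python
import argparse


def build_discriminator_opt(argv=None):
    """Hypothetical parser for the `opt` object SinGAN-style models read.

    Field names and defaults are assumptions modelled on the original SinGAN
    config; verify them against this repository before relying on them.
    """
    parser = argparse.ArgumentParser(description="Discriminator options (sketch)")
    parser.add_argument("--nc_im", type=int, default=3, help="image channels")
    parser.add_argument("--nfc", type=int, default=32, help="base feature channels")
    parser.add_argument("--min_nfc", type=int, default=32, help="minimum feature channels")
    parser.add_argument("--ker_size", type=int, default=3, help="convolution kernel size")
    parser.add_argument("--num_layer", type=int, default=5, help="convolution layers per scale")
    parser.add_argument("--padd_size", type=int, default=0, help="convolution padding")
    return parser.parse_args(argv)


if __name__ == "__main__":
    # Pass an explicit list so the sketch ignores sys.argv.
    opt = build_discriminator_opt(["--nc_im", "1"])
    print(opt.nc_im, opt.nfc, opt.ker_size)
```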

aamir-mustafa commented 2 years ago

Hi. The implementation lets you choose the number of discriminators (scales) used in the loss function, so the discriminators are loaded as a list. nn.DataParallel(Ds) will not work on a list of models, because nn.DataParallel wraps a single nn.Module. However, you can parallelise each discriminator individually inside 'mdfloss.py'. This should be straightforward:

  1. Load the discriminator for each scale with D = Ds[scale].
  2. Wrap it with D = nn.DataParallel(D, device_ids=[0, 1, 2]), adjusting device_ids to the GPUs you have.
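The two steps above can be sketched as follows. `TinyDisc` is a hypothetical stand-in for the real per-scale discriminators, and the sketch derives `device_ids` from `torch.cuda.device_count()` instead of hard-coding `[0, 1, 2]`. On a machine without CUDA, `nn.DataParallel` simply calls the wrapped module, so the same code also runs on the CPU.

```python
import torch
import torch.nn as nn


class TinyDisc(nn.Module):
    """Hypothetical stand-in for one scale's discriminator."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)


def parallelize_discriminators(Ds):
    """Wrap each discriminator in the list individually.

    nn.DataParallel needs a single nn.Module, so the list itself cannot be
    wrapped; each scale is wrapped on its own instead.
    """
    device_ids = list(range(torch.cuda.device_count())) or None
    wrapped = []
    for scale in range(len(Ds)):
        D = Ds[scale]
        if device_ids:
            # DataParallel expects parameters on the first listed device.
            D = D.to(f"cuda:{device_ids[0]}")
        D = nn.DataParallel(D, device_ids=device_ids)
        wrapped.append(D)
    return wrapped


if __name__ == "__main__":
    Ds = parallelize_discriminators([TinyDisc() for _ in range(3)])
    x = torch.randn(4, 3, 32, 32)
    if torch.cuda.is_available():
        x = x.cuda()
    outs = [D(x) for D in Ds]
    print([tuple(o.shape) for o in outs])
```

Wrapping per scale keeps the list structure that lets you choose the number of discriminators. For larger jobs, the PyTorch documentation recommends DistributedDataParallel over DataParallel.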

Hope that helps.