benemana closed this issue 6 months ago.
Given that the RivaGAN class integrates several sub-modules, a fairly comprehensive refactoring of the codebase might be needed to support the saving approach you describe. Unfortunately, I don't have time to help with this at the moment.
Hi, thanks for the amazing work.
I was wondering if you could help me modify the code so that the model is saved through its state_dict instead of saving the whole model, as the PyTorch documentation suggests (https://pytorch.org/tutorials/beginner/saving_loading_models.html).
Since the RivaGAN class does not directly extend nn.Module, I thought I could simply change
class RivaGAN(object):
to
class RivaGAN(nn.Module):
and then change from
torch.save(self, os.path.join(log_dir, "model.pt"))
to
torch.save(self.state_dict(), os.path.join(log_dir, "model.pt"))
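A minimal sketch of why the change can work, assuming the sub-modules are assigned as attributes in `__init__`: once a class extends `nn.Module`, every `nn.Module` stored as an attribute is registered, so `state_dict()` collects all of their weights under prefixed keys. The class `Watermarker` and its `encoder`/`decoder` attributes are hypothetical stand-ins, not RivaGAN's actual code.

```python
import io

import torch
from torch import nn

# Hypothetical stand-in for a class like RivaGAN that owns several sub-modules.
# The attribute names (encoder, decoder) are illustrative, not RivaGAN's real ones.
class Watermarker(nn.Module):
    def __init__(self, data_dim: int = 32):
        # super().__init__() must run before any sub-module is assigned,
        # otherwise nn.Module raises an AttributeError.
        super().__init__()
        self.data_dim = data_dim  # plain attribute: NOT part of state_dict
        self.encoder = nn.Conv2d(3 + data_dim, 3, kernel_size=3, padding=1)
        self.decoder = nn.Conv2d(3, data_dim, kernel_size=3, padding=1)

model = Watermarker()
print(sorted(model.state_dict().keys()))
# ['decoder.bias', 'decoder.weight', 'encoder.bias', 'encoder.weight']

# Save only the weights (an in-memory buffer stands in for "model.pt").
buf = io.BytesIO()
torch.save(model.state_dict(), buf)
buf.seek(0)

# To load, rebuild the object with the same constructor arguments first,
# then copy the saved weights into it.
restored = Watermarker(data_dim=32)
restored.load_state_dict(torch.load(buf))
restored.eval()  # propagates to every registered sub-module
```

Because `load_state_dict` is strict by default, a missing or renamed sub-module shows up as an error rather than silently producing a partially loaded model. One thing worth checking in the real class is whether it already defines methods whose names clash with `nn.Module` ones (for example `train`, `eval`, `to`, or `forward`): the existing definitions would override `nn.Module`'s and could break calls such as `model.eval()`, which internally calls `self.train(False)`.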
But maybe I'm missing something: I can't figure out whether this is enough for the model to work properly, or whether this change could alter the model's behavior when it is loaded (since the RivaGAN class is composed of multiple sub-modules).
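One case where the change would not capture everything, sketched with hypothetical classes: only attributes that are themselves `nn.Module` instances (including containers such as `nn.ModuleList` and `nn.ModuleDict`) are registered. Sub-modules kept in a plain Python list or dict are skipped without any warning, so it is worth checking how RivaGAN stores its sub-modules.

```python
import torch
from torch import nn

# Hypothetical example of how sub-modules can silently go missing from state_dict.
class PlainList(nn.Module):
    def __init__(self):
        super().__init__()
        # A plain Python list is invisible to nn.Module: these layers are
        # NOT registered and their weights are NOT saved.
        self.blocks = [nn.Linear(4, 4), nn.Linear(4, 4)]

class Registered(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.ModuleList registers each layer, so it appears in state_dict().
        self.blocks = nn.ModuleList([nn.Linear(4, 4), nn.Linear(4, 4)])

print(len(PlainList().state_dict()))   # 0
print(len(Registered().state_dict()))  # 4 (weight and bias for each layer)
```

Optimizer state is also outside the model's `state_dict()`; the PyTorch tutorial linked above recommends saving `optimizer.state_dict()` alongside it if training needs to be resumed.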
Do you have any suggestions? Thanks!
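If changing the base class turns out to be too invasive, a lighter-weight alternative (not tested against the real RivaGAN code) is to leave the class as it is and save each sub-module's state_dict, together with the constructor arguments, in a single checkpoint dictionary. `Wrapper`, `save_weights`, and `load_weights` are hypothetical names used only for illustration.

```python
import io

import torch
from torch import nn

# Hypothetical plain-object class that does NOT extend nn.Module but owns
# several sub-modules, loosely mimicking the situation described above.
class Wrapper(object):
    def __init__(self, data_dim: int = 32):
        self.data_dim = data_dim
        self.encoder = nn.Linear(data_dim, data_dim)
        self.decoder = nn.Linear(data_dim, data_dim)

    def save_weights(self, f) -> None:
        # Bundle each sub-module's state_dict plus the constructor arguments.
        torch.save({
            "data_dim": self.data_dim,
            "encoder": self.encoder.state_dict(),
            "decoder": self.decoder.state_dict(),
        }, f)

    @classmethod
    def load_weights(cls, f) -> "Wrapper":
        checkpoint = torch.load(f)
        # Rebuild the object from its saved arguments, then restore weights.
        obj = cls(data_dim=checkpoint["data_dim"])
        obj.encoder.load_state_dict(checkpoint["encoder"])
        obj.decoder.load_state_dict(checkpoint["decoder"])
        return obj

buf = io.BytesIO()
original = Wrapper()
original.save_weights(buf)
buf.seek(0)
restored = Wrapper.load_weights(buf)
```

The trade-off is that every new sub-module has to be added to both methods by hand, which `nn.Module` otherwise does automatically.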