Peachypie98 / RivaGAN

RivaGAN: Robust Invisible Video Watermarking with Attention

Saving the model with state_dict #3

Closed: benemana closed this issue 6 months ago

benemana commented 7 months ago

Hi, thanks for the amazing work.

I was wondering if you could help me modify the code so that the model is saved via its state_dict instead of saving the whole model, as the PyTorch documentation suggests (https://pytorch.org/tutorials/beginner/saving_loading_models.html).

Since the RivaGAN class does not directly extend nn.Module, I thought I would simply change it from class RivaGAN(object) to class RivaGAN(nn.Module)

and then change from

torch.save(self, os.path.join(log_dir, "model.pt"))

to

torch.save(self.state_dict(), os.path.join(log_dir, "model.pt"))
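A minimal sketch of the full save/load round trip this change implies, assuming RivaGAN becomes an nn.Module subclass. Two details matter: the subclass must call super().__init__() before assigning any sub-module (otherwise PyTorch raises an AttributeError), and loading requires rebuilding the model with the same constructor arguments before calling load_state_dict(), because a state_dict stores tensors, not the architecture. TinyWatermarker, save_checkpoint and load_checkpoint are hypothetical names, and the real RivaGAN constructor arguments would replace data_dim.

```python
import os
import tempfile

import torch
from torch import nn


class TinyWatermarker(nn.Module):
    """Hypothetical two-part model standing in for RivaGAN (names are illustrative)."""

    def __init__(self, data_dim: int = 4):
        super().__init__()  # must run before any sub-module is assigned
        self.encoder = nn.Linear(8 + data_dim, 8)
        self.decoder = nn.Linear(8, data_dim)


def save_checkpoint(model: nn.Module, log_dir: str) -> str:
    # Save only the tensors (parameters and buffers), not the pickled object.
    path = os.path.join(log_dir, "model.pt")
    torch.save(model.state_dict(), path)
    return path


def load_checkpoint(path: str, **init_kwargs) -> TinyWatermarker:
    # Rebuild the architecture first, with the SAME constructor arguments,
    # since the state_dict holds no information about the model's structure.
    model = TinyWatermarker(**init_kwargs)
    model.load_state_dict(torch.load(path, map_location="cpu"))
    model.eval()  # switch dropout/batch-norm layers to inference behaviour
    return model


with tempfile.TemporaryDirectory() as log_dir:
    original = TinyWatermarker(data_dim=4)
    path = save_checkpoint(original, log_dir)
    restored = load_checkpoint(path, data_dim=4)

same = all(
    torch.equal(a, b)
    for a, b in zip(original.state_dict().values(), restored.state_dict().values())
)
print(same)  # True
```

Because load_state_dict() defaults to strict=True, it raises an error if any key is missing or unexpected, which is a quick check that the rebuilt architecture matches the saved one.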

But maybe I'm missing something: I can't figure out whether this is enough for the model to work properly, or whether this change could alter the model's behavior when it is loaded (since the RivaGAN class is composed of multiple sub-modules).
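On the sub-module concern: nn.Module.state_dict() recurses into every registered sub-module and stores both parameters and buffers (such as BatchNorm running statistics) under dotted keys, so a composite model restored this way should produce the same outputs. A small check, using a hypothetical Composite model rather than RivaGAN itself:

```python
import io

import torch
from torch import nn


class Composite(nn.Module):
    """Hypothetical model built from nested sub-modules, loosely mimicking a multi-part GAN."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(6, 6), nn.BatchNorm1d(6), nn.ReLU())
        self.decoder = nn.Linear(6, 2)

    def forward(self, x):
        return self.decoder(self.encoder(x))


torch.manual_seed(0)
model = Composite()
model.train()
with torch.no_grad():
    model(torch.randn(16, 6))  # updates the BatchNorm running statistics (buffers)

buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)  # keys such as "encoder.1.running_mean"
buffer.seek(0)

clone = Composite()  # fresh random weights, overwritten by the load below
clone.load_state_dict(torch.load(buffer))

model.eval()
clone.eval()
x = torch.randn(4, 6)
print(torch.equal(model(x), clone(x)))  # True: parameters AND buffers were restored
```

What state_dict() does not capture is anything outside parameters and buffers: optimizer state, plain Python attributes such as hyperparameters, and modules kept in a plain list or dict instead of nn.ModuleList / nn.ModuleDict. Those would need to be saved separately or passed again to the constructor.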

Do you have any suggestions? Thanks!

Peachypie98 commented 7 months ago

Because the RivaGAN class combines several sub-modules, saving and loading it via state_dict would probably require a broader refactoring of the codebase. Unfortunately, I don't have time to help you at the moment.
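If refactoring RivaGAN into an nn.Module is not practical, one lighter-weight option (the "saving multiple models in one file" pattern from the PyTorch saving/loading tutorial) is to keep the plain class and save a dictionary holding each sub-module's state_dict. The sketch below assumes hypothetical encoder/decoder attributes and an Adam optimizer; the actual RivaGAN attributes and constructor would need to be substituted.

```python
import os
import tempfile

import torch
from torch import nn


class PlainWrapper(object):
    """Hypothetical plain-object container that does not subclass nn.Module."""

    def __init__(self):
        self.encoder = nn.Linear(4, 4)
        self.decoder = nn.Linear(4, 2)
        self.optimizer = torch.optim.Adam(
            list(self.encoder.parameters()) + list(self.decoder.parameters()), lr=1e-3
        )

    def save(self, path: str) -> None:
        # One dictionary entry per sub-module, plus the optimizer state.
        torch.save(
            {
                "encoder": self.encoder.state_dict(),
                "decoder": self.decoder.state_dict(),
                "optimizer": self.optimizer.state_dict(),
            },
            path,
        )

    @classmethod
    def load(cls, path: str) -> "PlainWrapper":
        checkpoint = torch.load(path, map_location="cpu")
        obj = cls()  # rebuild the architecture, then restore the tensors
        obj.encoder.load_state_dict(checkpoint["encoder"])
        obj.decoder.load_state_dict(checkpoint["decoder"])
        obj.optimizer.load_state_dict(checkpoint["optimizer"])
        return obj


with tempfile.TemporaryDirectory() as log_dir:
    path = os.path.join(log_dir, "model.pt")
    original = PlainWrapper()
    original.save(path)
    restored = PlainWrapper.load(path)
```

This keeps the existing class hierarchy untouched, at the cost of listing every sub-module explicitly in save() and load().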