ptran1203 / pytorch-animeGAN

Pytorch implementation of AnimeGAN for fast photo animation

Please create the full model (weights plus architecture) to torch.hub.load #18

Closed: spinoza1791 closed this issue 9 months ago

spinoza1791 commented 2 years ago

Thank you for your amazing work! If your team is able to upload the full model (weights plus architecture) so that it can be loaded with torch.hub.load, then the model could be converted to a mobile version. Or do you already have a file with the pretrained weights plus the architecture? I can see that your .pth contains only the OrderedDict of weights.

ptran1203 commented 2 years ago

Unfortunately, only the model weights are available at the moment. However, you can easily build the full model yourself with a few lines of Python (I have not checked the syntax):

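```python
import torch
from modeling.anime_gan import Generator

weight = "..."  # path to the downloaded weight file
model = Generator()
# Load on CPU so that no GPU is required
checkpoint = torch.load(weight, map_location='cpu')
# Weights are read from the 'model_state_dict' entry of the checkpoint
model.load_state_dict(checkpoint['model_state_dict'], strict=True)
# nn.Module has no .save() method; torch.save pickles the whole model
torch.save(model, "/path/to/full_model.pth")
```
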
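
A note on what that snippet produces: `torch.save(model, path)` pickles the module object, so loading the file later still requires `modeling.anime_gan` to be importable, and since PyTorch 2.6 `torch.load` defaults to `weights_only=True`, so reading back a pickled module also needs `weights_only=False`. For a single file that carries both the graph and the weights without needing the Python class, which is the usual starting point for mobile conversion, TorchScript is one option. The sketch below is hedged: `TinyGenerator` is a hypothetical stand-in for the repo's `Generator` (not its real architecture), `export_torchscript` is a hypothetical helper, the `"model_state_dict"` key is taken from the snippet above, and the 64x64 dummy input size is arbitrary.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyGenerator(nn.Module):
    """Hypothetical stand-in for modeling.anime_gan.Generator (NOT the real architecture)."""

    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(3, 3, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.tanh(self.conv(x))


def export_torchscript(weight_path: str, out_path: str) -> None:
    """Load a weights-only checkpoint and save graph + weights as one TorchScript file."""
    model = TinyGenerator()
    checkpoint = torch.load(weight_path, map_location="cpu")
    # Assumed checkpoint layout, as in the thread: {"model_state_dict": ...}
    model.load_state_dict(checkpoint["model_state_dict"], strict=True)
    model.eval()
    # Trace with a dummy image batch (N, C, H, W); the size is arbitrary.
    example = torch.rand(1, 3, 64, 64)
    with torch.no_grad():
        scripted = torch.jit.trace(model, example)
    scripted.save(out_path)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        weight_path = os.path.join(tmp, "weights.pth")
        out_path = os.path.join(tmp, "full_model.pt")
        # Simulate a weights-only checkpoint like the one released for this repo.
        torch.save({"model_state_dict": TinyGenerator().state_dict()}, weight_path)
        export_torchscript(weight_path, out_path)
        # torch.jit.load does not need the TinyGenerator class definition.
        loaded = torch.jit.load(out_path)
        print(loaded(torch.rand(1, 3, 64, 64)).shape)
```

With the real model you would import `Generator` from `modeling.anime_gan` instead of defining `TinyGenerator`. Tracing records only the operations executed for the example input, so a generator with data-dependent control flow would need `torch.jit.script` instead.
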
spinoza1791 commented 2 years ago

Understood, but could you give an example of what to enter for these two lines? I cannot figure out what to put in them. Thank you!

```python
weight = "..."

model.load_state_dict(checkpoint['model_state_dict'], strict=True)
```

ptran1203 commented 2 years ago

`weight = "..."` is the path to the model weight file on your machine (something like this weight). `model.load_state_dict(checkpoint['model_state_dict'], strict=True)` loads the state_dict into the model; just keep it as it is, there is no need to change it.
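
One caveat worth checking: the first comment describes the released .pth as a bare OrderedDict of weights, while the snippet indexes `checkpoint['model_state_dict']`. If the file really is a bare state_dict, that lookup raises `KeyError`. A minimal sketch of a tolerant lookup, assuming only these two layouts; `extract_state_dict` is a hypothetical helper, not part of the repo or of PyTorch:

```python
from collections import OrderedDict
from collections.abc import Mapping
from typing import Any


def extract_state_dict(checkpoint: Any, key: str = "model_state_dict") -> Mapping:
    """Return the weights mapping from an already-loaded checkpoint.

    Assumed layouts:
      * a wrapper dict such as {"model_state_dict": ..., "epoch": ...}
      * a bare state_dict (parameter name -> tensor)
    """
    if not isinstance(checkpoint, Mapping):
        raise TypeError(f"expected a dict-like checkpoint, got {type(checkpoint).__name__}")
    if key in checkpoint:
        return checkpoint[key]
    # No wrapper key: assume the checkpoint itself is the state_dict.
    return checkpoint


# Placeholder values stand in for tensors here.
bare = OrderedDict([("conv.weight", 1), ("conv.bias", 2)])
wrapped = {"model_state_dict": bare, "epoch": 3}
print(extract_state_dict(bare) is bare)     # True
print(extract_state_dict(wrapped) is bare)  # True
```

In the real script this would be used as `model.load_state_dict(extract_state_dict(checkpoint), strict=True)` after `torch.load(weight, map_location='cpu')`.
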

spinoza1791 commented 2 years ago

Thank you. However, I am still getting an error on this line: model.load_state_dict(checkpoint['model_state_dict'], strict=True)

Any ideas?

```python
import torch
from torch import Generator

weight = 'anime.pth'
model = Generator()
checkpoint = torch.load(weight, map_location=torch.device('cpu'))
model.load_state_dict(checkpoint['model_state_dict'], strict=True)
model.save(model, "anime_sd.pth")
```

```
----> 7 model.load_state_dict(checkpoint['model_state_dict'], strict=True)
      8 model.save(model, "anime_sd.pth")

AttributeError: 'torch._C.Generator' object has no attribute 'load_state_dict'
```
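
The traceback points at the import rather than the checkpoint: `from torch import Generator` binds PyTorch's own `torch.Generator`, the random-number-generator class (the `torch._C.Generator` in the message), not the AnimeGAN network, so it has no `load_state_dict`. The network class comes from `modeling.anime_gan`, as in the maintainer's snippet, and the last line would then fail too, because a plain `nn.Module` has no `.save` method; `torch.save(...)` is the function to use. A short check illustrating both points (`nn.Conv2d` is just an arbitrary stand-in for a network):

```python
import torch
import torch.nn as nn

# What the failing script imported: PyTorch's random-number generator.
rng = torch.Generator()
print(isinstance(rng, nn.Module))       # False
print(hasattr(rng, "load_state_dict"))  # False -> the AttributeError above

# A real network is an nn.Module: it can load weights, but it has no .save(),
# so the model (or its state_dict) must be written with torch.save(...).
net = nn.Conv2d(3, 3, kernel_size=3)
print(hasattr(net, "load_state_dict"))  # True
print(hasattr(net, "save"))             # False
```

So the two lines to change are presumably `from modeling.anime_gan import Generator` (run from a location where the repo's `modeling` package is importable) and `torch.save(model, "anime_sd.pth")`.
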