Closed: spinoza1791 closed this issue 9 months ago
Unfortunately, only the model weights are available at the moment. However, you can easily build the full model yourself with a few lines of Python (I have not checked the syntax):
import torch
from modeling.anime_gan import Generator
weight = "..."
model = Generator()
checkpoint = torch.load(weight, map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'], strict=True)
torch.save(model, "/path/to/full_model.pth")
Understood, but could you give an example of what to enter for these two suggested lines? I cannot figure out what they should contain. Thank you!
weight = "..."
model.load_state_dict(checkpoint['model_state_dict'], strict=True)
weight = "..."
is the path to the model weights file on your machine, something like this weight.
model.load_state_dict(checkpoint['model_state_dict'], strict=True)
This line loads the model's state_dict from the checkpoint. Keep it as it is; there is no need to change it.
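To make those two lines concrete, here is a minimal sketch. It uses a tiny stand-in network (`TinyGenerator`, a hypothetical class, not the repository's `Generator`) and a temporary file in place of the real weights path. It assumes the checkpoint is a dict with a `model_state_dict` key, as the snippet above implies.

```python
import os
import tempfile

import torch
from torch import nn


class TinyGenerator(nn.Module):
    """Hypothetical stand-in for the repository's Generator class."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)


# Create a checkpoint file shaped like the one the snippet expects:
# a dict whose 'model_state_dict' entry holds the weights.
tmp_dir = tempfile.mkdtemp()
weight = os.path.join(tmp_dir, "weights.pth")  # stands in for the real path
torch.save({"model_state_dict": TinyGenerator().state_dict()}, weight)

# The two lines in question: `weight` is just a file path, and the
# load_state_dict call is used unchanged.
model = TinyGenerator()
checkpoint = torch.load(weight, map_location="cpu")
result = model.load_state_dict(checkpoint["model_state_dict"], strict=True)
```

With `strict=True`, the call raises an error if the keys in the file do not match the model's parameters, so a clean return means every weight was loaded.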
Thank you. However, I am still getting an error on this line: model.load_state_dict(checkpoint['model_state_dict'], strict=True)
import torch
from torch import Generator

weight = 'anime.pth'
model = Generator()
checkpoint = torch.load(weight, map_location=torch.device('cpu'))
model.load_state_dict(checkpoint['model_state_dict'], strict=True)
model.save(model, "anime_sd.pth")

----> 7 model.load_state_dict(checkpoint['model_state_dict'], strict=True)
      8 model.save(model, "anime_sd.pth")

AttributeError: 'torch._C.Generator' object has no attribute 'load_state_dict'
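The traceback names `torch._C.Generator`, which is PyTorch's random number generator class, not a neural network. The suggested snippet imported `Generator` from `modeling.anime_gan` instead. A small sketch of the difference, using `nn.Linear` purely as an example of an `nn.Module`:

```python
import torch
from torch import nn

# torch.Generator manages random number generation state.
# It is not an nn.Module, so it has no load_state_dict method.
rng = torch.Generator()
rng_is_module = isinstance(rng, nn.Module)
rng_has_loader = hasattr(rng, "load_state_dict")

# Any nn.Module subclass (here nn.Linear, only as an illustration)
# does provide load_state_dict.
layer = nn.Linear(2, 2)
layer_has_loader = hasattr(layer, "load_state_dict")
```

Assuming the repository's `modeling` package is importable from the working directory, replacing `from torch import Generator` with `from modeling.anime_gan import Generator` should give a model class that has `load_state_dict`.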
Thank you for your amazing work! If your team is able to publish the full model (weights plus architecture) so that it can be loaded with torch.hub.load, then the model could be converted to a mobile version. Or do you already have a file with the pretrained weights plus architecture? I can see that your .pth contains only the weights as an OrderedDict.
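One common way to get weights and architecture into a single file is TorchScript tracing. The sketch below is an assumption about how that could look, not something the maintainers have confirmed: `TinyGenerator` is a hypothetical stand-in for the real network, the input shape is arbitrary, and any further mobile-specific optimization is not shown.

```python
import io

import torch
from torch import nn


class TinyGenerator(nn.Module):
    """Hypothetical stand-in for the real generator network."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)


model = TinyGenerator().eval()
example = torch.rand(1, 3, 32, 32)  # arbitrary example input for tracing

# Tracing records the operations run on the example input, producing a
# module that carries both the architecture and the weights.
traced = torch.jit.trace(model, example)

# Save to an in-memory buffer here; a file path works the same way.
buffer = io.BytesIO()
torch.jit.save(traced, buffer)
buffer.seek(0)

# The saved module can be reloaded without the Python class definition.
reloaded = torch.jit.load(buffer)
with torch.no_grad():
    same = torch.allclose(model(example), reloaded(example))
```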