Closed hno2 closed 3 years ago
Is it possible you could send the generator.pth so I can try your example code?
I just dug a little deeper into this "mismatch": I forgot to also pass the dropout parameter to `resnet_generator`. The PyTorch error message was not really clear to me, as it told me nothing about the layers added by dropout. Anyway, thanks!
Sorry, my bad. If you do things correctly, things actually work ;)
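For anyone who runs into the same error: the shifted key indices (`conv_block.5` vs. `conv_block.6`) are what you would expect when a parameter-free layer such as `nn.Dropout` sits in the middle of an `nn.Sequential`. The sketch below reproduces the shift with a toy block. `toy_conv_block` is a hypothetical stand-in written for illustration, not upit's actual implementation, and the layer order is an assumption chosen to match the indices in the error message.

```python
import torch.nn as nn


def toy_conv_block(ch: int, dropout: float = 0.0) -> nn.Sequential:
    """Hypothetical residual-style conv block (NOT upit's real code)."""
    layers = [
        nn.ReflectionPad2d(1),   # index 0, no parameters
        nn.Conv2d(ch, ch, 3),    # index 1, has weight and bias
        nn.InstanceNorm2d(ch),   # index 2, no parameters with default affine=False
        nn.ReLU(True),           # index 3, no parameters
    ]
    if dropout > 0:
        # Dropout has no parameters, but it takes up an index in the Sequential.
        layers.append(nn.Dropout(dropout))
    layers += [
        nn.ReflectionPad2d(1),
        nn.Conv2d(ch, ch, 3),    # index 5 without dropout, index 6 with dropout
        nn.InstanceNorm2d(ch),
    ]
    return nn.Sequential(*layers)


# Keys the freshly built model expects vs. keys a dropout-trained checkpoint holds.
plain_keys = set(toy_conv_block(8).state_dict())
dropout_keys = set(toy_conv_block(8, dropout=0.5).state_dict())

missing = sorted(plain_keys - dropout_keys)
unexpected = sorted(dropout_keys - plain_keys)
print("missing:", missing)
print("unexpected:", unexpected)
```

The second convolution moves from index 5 to index 6 once dropout is inserted, so a model built without dropout reports its own `5.*` keys as missing and the checkpoint's `6.*` keys as unexpected.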
Hey,
me again. Sorry to bother you again!
So I am trying to do inference on a trained model (with default values). I exported the generator with the `export_generator` function. Now I am trying to load my generator as shown in the Web App Example, but I get errors when loading the state_dict. The state dict seems to have extra keys for nine extra layers, if I understand the error message correctly:
Error Message
```
    model.load_state_dict(torch.load("generator.pth", map_location=device))
  File "/Applications/Utilities/miniconda3/envs/ml/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for Sequential:
    Missing key(s) in state_dict: "10.conv_block.5.weight", "10.conv_block.5.bias", "11.conv_block.5.weight", "11.conv_block.5.bias", "12.conv_block.5.weight", "12.conv_block.5.bias", "13.conv_block.5.weight", "13.conv_block.5.bias", "14.conv_block.5.weight", "14.conv_block.5.bias", "15.conv_block.5.weight", "15.conv_block.5.bias", "16.conv_block.5.weight", "16.conv_block.5.bias", "17.conv_block.5.weight", "17.conv_block.5.bias", "18.conv_block.5.weight", "18.conv_block.5.bias".
    Unexpected key(s) in state_dict: "10.conv_block.6.weight", "10.conv_block.6.bias", "11.conv_block.6.weight", "11.conv_block.6.bias", "12.conv_block.6.weight", "12.conv_block.6.bias", "13.conv_block.6.weight", "13.conv_block.6.bias", "14.conv_block.6.weight", "14.conv_block.6.bias", "15.conv_block.6.weight", "15.conv_block.6.bias", "16.conv_block.6.weight", "16.conv_block.6.bias", "17.conv_block.6.weight", "17.conv_block.6.bias", "18.conv_block.6.weight", "18.conv_block.6.bias".
```
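An error like this can be inspected without raising by passing `strict=False` to `load_state_dict`, which returns the missing and unexpected keys as lists. The following is a minimal sketch using two toy `nn.Sequential` models (`trained` and `rebuilt` are hypothetical names, unrelated to the real generator) rather than the actual checkpoint.

```python
import torch.nn as nn

# Toy stand-ins: "trained" contains an extra Dropout layer, "rebuilt" does not.
trained = nn.Sequential(nn.Conv2d(3, 3, 1), nn.Dropout(0.5), nn.Conv2d(3, 3, 1))
rebuilt = nn.Sequential(nn.Conv2d(3, 3, 1), nn.Conv2d(3, 3, 1))

# strict=False loads the keys that match and reports the rest instead of raising.
result = rebuilt.load_state_dict(trained.state_dict(), strict=False)

print("missing:", result.missing_keys)        # keys the model has but the checkpoint lacks
print("unexpected:", result.unexpected_keys)  # keys the checkpoint has but the model lacks
```

This is only a diagnostic: any layer listed under the missing keys keeps its random initialization, so the model should still be rebuilt with the same arguments used for training before running inference.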
Minimal Working Example
```python
import torch
from upit.models.cyclegan import resnet_generator
import torchvision.transforms
from PIL import Image

device = torch.device("cpu")
model = resnet_generator(ch_in=3, ch_out=3)
model.load_state_dict(torch.load("generator.pth", map_location=device))
model.eval()

totensor = torchvision.transforms.ToTensor()
normalize_fn = torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
topilimage = torchvision.transforms.ToPILImage()


def predict(input):
    im = normalize_fn(totensor(input))
    print(im.shape)
    preds = model(im.unsqueeze(0)) / 2 + 0.5
    print(preds.shape)
    return topilimage(preds.squeeze(0).detach().cpu())


im = predict(Image.open("test.jpg"))
im.save("out.jpg")
```
Thanks again for your support!