tmabraham / UPIT

A fastai/PyTorch package for unpaired image-to-image translation.
https://tmabraham.github.io/UPIT
Apache License 2.0

Inference - Can not load state_dict #20

Closed hno2 closed 3 years ago

hno2 commented 3 years ago

Hey,

Me again, sorry to bother you!

I am trying to run inference on a trained model (with default values). I exported the generator with the `export_generator` function, and now I am trying to load it as shown in the Web App Example.

But I get errors when loading the state_dict. If I understand the error message correctly, the state dict seems to have extra keys for nine extra layers:

Error Message

```
model.load_state_dict(torch.load("generator.pth", map_location=device))
  File "/Applications/Utilities/miniconda3/envs/ml/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for Sequential:
    Missing key(s) in state_dict: "10.conv_block.5.weight", "10.conv_block.5.bias", "11.conv_block.5.weight", "11.conv_block.5.bias", "12.conv_block.5.weight", "12.conv_block.5.bias", "13.conv_block.5.weight", "13.conv_block.5.bias", "14.conv_block.5.weight", "14.conv_block.5.bias", "15.conv_block.5.weight", "15.conv_block.5.bias", "16.conv_block.5.weight", "16.conv_block.5.bias", "17.conv_block.5.weight", "17.conv_block.5.bias", "18.conv_block.5.weight", "18.conv_block.5.bias".
    Unexpected key(s) in state_dict: "10.conv_block.6.weight", "10.conv_block.6.bias", "11.conv_block.6.weight", "11.conv_block.6.bias", "12.conv_block.6.weight", "12.conv_block.6.bias", "13.conv_block.6.weight", "13.conv_block.6.bias", "14.conv_block.6.weight", "14.conv_block.6.bias", "15.conv_block.6.weight", "15.conv_block.6.bias", "16.conv_block.6.weight", "16.conv_block.6.bias", "17.conv_block.6.weight", "17.conv_block.6.bias", "18.conv_block.6.weight", "18.conv_block.6.bias".
```
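Note that the missing and unexpected keys differ only in the sub-layer index inside each `conv_block` (5 vs. 6), which hints that the checkpoint's generator has one extra layer per residual block. A generic sketch of how to spot such a systematic shift by diffing the two key sets (the keys below are taken from the error above, not from the live model):

```python
# Diff two sets of state_dict-style keys to spot a systematic index shift.
# These are the keys reported in the RuntimeError: blocks 10..18, weight + bias.
missing = {f"{i}.conv_block.5.{p}" for i in range(10, 19) for p in ("weight", "bias")}
unexpected = {f"{i}.conv_block.6.{p}" for i in range(10, 19) for p in ("weight", "bias")}

# Every missing key maps onto an unexpected key by shifting the index 5 -> 6,
# i.e. the checkpoint has one additional module in each conv_block.
shifted = {k.replace(".5.", ".6.") for k in missing}
print(shifted == unexpected)  # True
```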

Minimal Working Example

```python
import torch
import torchvision.transforms
from PIL import Image

from upit.models.cyclegan import resnet_generator

device = torch.device("cpu")
model = resnet_generator(ch_in=3, ch_out=3)
model.load_state_dict(torch.load("generator.pth", map_location=device))
model.eval()

totensor = torchvision.transforms.ToTensor()
normalize_fn = torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
topilimage = torchvision.transforms.ToPILImage()


def predict(input):
    im = normalize_fn(totensor(input))
    print(im.shape)
    preds = model(im.unsqueeze(0)) / 2 + 0.5
    print(preds.shape)
    return topilimage(preds.squeeze(0).detach().cpu())


im = predict(Image.open("test.jpg"))
im.save("out.jpg")
```

Thanks again for your support!

tmabraham commented 3 years ago

Is it possible you could send the generator.pth so I can try your example code?

hno2 commented 3 years ago

I dug a little deeper into this "mismatch". I had forgotten to pass the dropout parameter to `resnet_generator` as well. The PyTorch error message wasn't really clear to me, since it says nothing about the layers added by dropout. Anyway, thanks!
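For context, enabling dropout inserts an extra dropout module into each residual block's sequential container, which renumbers every module after it, so the final conv's parameters move from `conv_block.5` to `conv_block.6`. A minimal sketch of that renumbering, using plain Python lists to stand in for `nn.Sequential` (the exact layer order inside UPIT's residual block is an assumption here):

```python
# state_dict key prefixes come from a module's position in the Sequential
# container, so inserting dropout shifts every later index by one.
without_dropout = ["pad", "conv", "norm", "relu", "pad", "conv"]              # final conv at index 5
with_dropout = ["pad", "conv", "norm", "relu", "dropout", "pad", "conv"]      # final conv at index 6

# Index of the second conv (search starting after the first one):
print(without_dropout.index("conv", 2))  # 5 -> key "conv_block.5.weight"
print(with_dropout.index("conv", 2))     # 6 -> key "conv_block.6.weight"
```

This is why a checkpoint trained with dropout only loads into a generator that was also constructed with dropout enabled.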

Sorry, my bad. If you do things correctly, things actually work ;)