knazeri / edge-connect

EdgeConnect: Structure Guided Image Inpainting using Edge Prediction, ICCV 2019 https://arxiv.org/abs/1901.00212
http://openaccess.thecvf.com/content_ICCVW_2019/html/AIM/Nazeri_EdgeConnect_Structure_Guided_Image_Inpainting_using_Edge_Prediction_ICCVW_2019_paper.html

Missing keys in state_dict? #121

Open laolongboy opened 4 years ago

laolongboy commented 4 years ago

RuntimeError: Error(s) in loading state_dict for EdgeGenerator: Missing key(s) in state_dict: "encoder.1.weight", "encoder.4.weight", "encoder.7.weight", "middle.0.conv_block.1.weight", "middle.0.conv_block.5.weight", "middle.1.conv_block.1.weight", "middle.1.conv_block.5.weight", "middle.2.conv_block.1.weight", "middle.2.conv_block.5.weight", "middle.3.conv_block.1.weight", "middle.3.conv_block.5.weight", "middle.4.conv_block.1.weight", "middle.4.conv_block.5.weight", "middle.5.conv_block.1.weight", "middle.5.conv_block.5.weight", "middle.6.conv_block.1.weight", "middle.6.conv_block.5.weight", "middle.7.conv_block.1.weight", "middle.7.conv_block.5.weight", "decoder.0.weight", "decoder.3.weight". Unexpected key(s) in state_dict: "encoder.1.weight_v", "encoder.4.weight_v", "encoder.7.weight_v", "middle.0.conv_block.1.weight_v", "middle.0.conv_block.5.weight_v", "middle.1.conv_block.1.weight_v", "middle.1.conv_block.5.weight_v", "middle.2.conv_block.1.weight_v", "middle.2.conv_block.5.weight_v", "middle.3.conv_block.1.weight_v", "middle.3.conv_block.5.weight_v", "middle.4.conv_block.1.weight_v", "middle.4.conv_block.5.weight_v", "middle.5.conv_block.1.weight_v", "middle.5.conv_block.5.weight_v", "middle.6.conv_block.1.weight_v", "middle.6.conv_block.5.weight_v", "middle.7.conv_block.1.weight_v", "middle.7.conv_block.5.weight_v", "decoder.0.weight_v", "decoder.3.weight_v".

Hi, does anyone know how to solve this problem? Thanks!

AlvinWen428 commented 4 years ago

I solved this problem by updating my PyTorch to 1.3.0. Hope this helps.

goldservice2017 commented 4 years ago

Hi @AlvinWen428, I get a similar error with PyTorch 1.3.0. I trained custom edge/inpaint models on the places2 dataset. When I try to use my trained model via test.py, it raises this error. Please help me.

python test.py --checkpoints ./checkpoints/recent --input ./examples/places2/images --mask ./examples/places2/masks --output ./results
/Volumes/Work/2019_Work/Inpainting/origin_edge/edge-connect-origin/src/config.py:8: YAMLLoadWarning: calling yaml.load() without Loader=... is deprecated, as the default Loader is unsafe. Please read https://msg.pyyaml.org/load for full details.
  self._dict = yaml.load(self._yaml)
Loading EdgeModel generator...
Traceback (most recent call last):
  File "test.py", line 2, in <module>
    main(mode=2)
  File "/Volumes/Work/2019_Work/Inpainting/origin_edge/edge-connect-origin/main.py", line 49, in main
    model.load()
  File "/Volumes/Work/2019_Work/Inpainting/origin_edge/edge-connect-origin/src/edge_connect.py", line 61, in load
    self.edge_model.load()
  File "/Volumes/Work/2019_Work/Inpainting/origin_edge/edge-connect-origin/src/models.py", line 31, in load
    self.generator.load_state_dict(data['generator'])
  File "/Users/mac/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 839, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for EdgeGenerator:
        Missing key(s) in state_dict: "encoder.1.weight_orig", "encoder.1.weight", "encoder.1.weight_u", "encoder.1.bias", "encoder.1.weight_orig", "encoder.1.weight_u", "encoder.1.weight_v", "encoder.4.weight_orig", "encoder.4.weight", "encoder.4.weight_u", "encoder.4.bias", "encoder.4.weight_orig", "encoder.4.weight_u", "encoder.4.weight_v", "encoder.7.weight_orig", "encoder.7.weight", "encoder.7.weight_u", "encoder.7.bias", "encoder.7.weight_orig", "encoder.7.weight_u", "encoder.7.weight_v", "middle.0.conv_block.1.weight_orig", "middle.0.conv_block.1.weight", "middle.0.conv_block.1.weight_u", "middle.0.conv_block.1.weight_orig", "middle.0.conv_block.1.weight_u", "middle.0.conv_block.1.weight_v", "middle.0.conv_block.5.weight_orig", "middle.0.conv_block.5.weight", "middle.0.conv_block.5.weight_u", "middle.0.conv_block.5.weight_orig", "middle.0.conv_block.5.weight_u", "middle.0.conv_block.5.weight_v", "middle.1.conv_block.1.weight_orig", "middle.1.conv_block.1.weight", "middle.1.conv_block.1.weight_u", "middle.1.conv_block.1.weight_orig", "middle.1.conv_block.1.weight_u", "middle.1.conv_block.1.weight_v", "middle.1.conv_block.5.weight_orig", "middle.1.conv_block.5.weight", "middle.1.conv_block.5.weight_u", "middle.1.conv_block.5.weight_orig", "middle.1.conv_block.5.weight_u", "middle.1.conv_block.5.weight_v", "middle.2.conv_block.1.weight_orig", "middle.2.conv_block.1.weight", "middle.2.conv_block.1.weight_u", "middle.2.conv_block.1.weight_orig", "middle.2.conv_block.1.weight_u", "middle.2.conv_block.1.weight_v", "middle.2.conv_block.5.weight_orig", "middle.2.conv_block.5.weight", "middle.2.conv_block.5.weight_u", "middle.2.conv_block.5.weight_orig", "middle.2.conv_block.5.weight_u", "middle.2.conv_block.5.weight_v", "middle.3.conv_block.1.weight_orig", "middle.3.conv_block.1.weight", "middle.3.conv_block.1.weight_u", "middle.3.conv_block.1.weight_orig", "middle.3.conv_block.1.weight_u", "middle.3.conv_block.1.weight_v", "middle.3.conv_block.5.weight_orig", 
"middle.3.conv_block.5.weight", "middle.3.conv_block.5.weight_u", "middle.3.conv_block.5.weight_orig", "middle.3.conv_block.5.weight_u", "middle.3.conv_block.5.weight_v", "middle.4.conv_block.1.weight_orig", "middle.4.conv_block.1.weight", "middle.4.conv_block.1.weight_u", "middle.4.conv_block.1.weight_orig", "middle.4.conv_block.1.weight_u", "middle.4.conv_block.1.weight_v", "middle.4.conv_block.5.weight_orig", "middle.4.conv_block.5.weight", "middle.4.conv_block.5.weight_u", "middle.4.conv_block.5.weight_orig", "middle.4.conv_block.5.weight_u", "middle.4.conv_block.5.weight_v", "middle.5.conv_block.1.weight_orig", "middle.5.conv_block.1.weight", "middle.5.conv_block.1.weight_u", "middle.5.conv_block.1.weight_orig", "middle.5.conv_block.1.weight_u", "middle.5.conv_block.1.weight_v", "middle.5.conv_block.5.weight_orig", "middle.5.conv_block.5.weight", "middle.5.conv_block.5.weight_u", "middle.5.conv_block.5.weight_orig", "middle.5.conv_block.5.weight_u", "middle.5.conv_block.5.weight_v", "middle.6.conv_block.1.weight_orig", "middle.6.conv_block.1.weight", "middle.6.conv_block.1.weight_u", "middle.6.conv_block.1.weight_orig", "middle.6.conv_block.1.weight_u", "middle.6.conv_block.1.weight_v", "middle.6.conv_block.5.weight_orig", "middle.6.conv_block.5.weight", "middle.6.conv_block.5.weight_u", "middle.6.conv_block.5.weight_orig", "middle.6.conv_block.5.weight_u", "middle.6.conv_block.5.weight_v", "middle.7.conv_block.1.weight_orig", "middle.7.conv_block.1.weight", "middle.7.conv_block.1.weight_u", "middle.7.conv_block.1.weight_orig", "middle.7.conv_block.1.weight_u", "middle.7.conv_block.1.weight_v", "middle.7.conv_block.5.weight_orig", "middle.7.conv_block.5.weight", "middle.7.conv_block.5.weight_u", "middle.7.conv_block.5.weight_orig", "middle.7.conv_block.5.weight_u", "middle.7.conv_block.5.weight_v", "decoder.0.weight_orig", "decoder.0.weight", "decoder.0.weight_u", "decoder.0.bias", "decoder.0.weight_orig", "decoder.0.weight_u", "decoder.0.weight_v", 
"decoder.3.weight_orig", "decoder.3.weight", "decoder.3.weight_u", "decoder.3.bias", "decoder.3.weight_orig", "decoder.3.weight_u", "decoder.3.weight_v", "decoder.7.weight", "decoder.7.bias". 
        Unexpected key(s) in state_dict: "module.encoder.1.bias", "module.encoder.1.weight_orig", "module.encoder.1.weight_u", "module.encoder.1.weight_v", "module.encoder.4.bias", "module.encoder.4.weight_orig", "module.encoder.4.weight_u", "module.encoder.4.weight_v", "module.encoder.7.bias", "module.encoder.7.weight_orig", "module.encoder.7.weight_u", "module.encoder.7.weight_v", "module.middle.0.conv_block.1.weight_orig", "module.middle.0.conv_block.1.weight_u", "module.middle.0.conv_block.1.weight_v", "module.middle.0.conv_block.5.weight_orig", "module.middle.0.conv_block.5.weight_u", "module.middle.0.conv_block.5.weight_v", "module.middle.1.conv_block.1.weight_orig", "module.middle.1.conv_block.1.weight_u", "module.middle.1.conv_block.1.weight_v", "module.middle.1.conv_block.5.weight_orig", "module.middle.1.conv_block.5.weight_u", "module.middle.1.conv_block.5.weight_v", "module.middle.2.conv_block.1.weight_orig", "module.middle.2.conv_block.1.weight_u", "module.middle.2.conv_block.1.weight_v", "module.middle.2.conv_block.5.weight_orig", "module.middle.2.conv_block.5.weight_u", "module.middle.2.conv_block.5.weight_v", "module.middle.3.conv_block.1.weight_orig", "module.middle.3.conv_block.1.weight_u", "module.middle.3.conv_block.1.weight_v", "module.middle.3.conv_block.5.weight_orig", "module.middle.3.conv_block.5.weight_u", "module.middle.3.conv_block.5.weight_v", "module.middle.4.conv_block.1.weight_orig", "module.middle.4.conv_block.1.weight_u", "module.middle.4.conv_block.1.weight_v", "module.middle.4.conv_block.5.weight_orig", "module.middle.4.conv_block.5.weight_u", "module.middle.4.conv_block.5.weight_v", "module.middle.5.conv_block.1.weight_orig", "module.middle.5.conv_block.1.weight_u", "module.middle.5.conv_block.1.weight_v", "module.middle.5.conv_block.5.weight_orig", "module.middle.5.conv_block.5.weight_u", "module.middle.5.conv_block.5.weight_v", "module.middle.6.conv_block.1.weight_orig", "module.middle.6.conv_block.1.weight_u", 
"module.middle.6.conv_block.1.weight_v", "module.middle.6.conv_block.5.weight_orig", "module.middle.6.conv_block.5.weight_u", "module.middle.6.conv_block.5.weight_v", "module.middle.7.conv_block.1.weight_orig", "module.middle.7.conv_block.1.weight_u", "module.middle.7.conv_block.1.weight_v", "module.middle.7.conv_block.5.weight_orig", "module.middle.7.conv_block.5.weight_u", "module.middle.7.conv_block.5.weight_v", "module.decoder.0.bias", "module.decoder.0.weight_orig", "module.decoder.0.weight_u", "module.decoder.0.weight_v", "module.decoder.3.bias", "module.decoder.3.weight_orig", "module.decoder.3.weight_u", "module.decoder.3.weight_v", "module.decoder.7.weight", "module.decoder.7.bias". 

Thank you

AlvinWen428 commented 4 years ago

In my experience, this error is not caused by the PyTorch version but by multi-GPU training. I suggest checking whether you use torch.nn.DataParallel() in the pre-training code or in the model-loading code.
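`nn.DataParallel` stores the wrapped network as a submodule named `module`, so every key in its `state_dict()` gains a `module.` prefix. A minimal pure-Python sketch for checking whether that prefix alone explains a failed load; `prefix_explains_mismatch` is a hypothetical helper, and in this repo its inputs would presumably be `data['generator'].keys()` and `self.generator.state_dict().keys()`:

```python
def prefix_explains_mismatch(ckpt_keys, model_keys, prefix="module."):
    """Return True if stripping `prefix` from checkpoint keys makes them match the model's keys."""
    ckpt, model = set(ckpt_keys), set(model_keys)
    if ckpt == model:
        return False  # keys already match; nothing to explain
    stripped = {k[len(prefix):] if k.startswith(prefix) else k for k in ckpt}
    return stripped == model


# Keys as saved from a DataParallel-wrapped model vs. keys of the plain model.
saved = ["module.decoder.7.weight", "module.decoder.7.bias"]
expected = ["decoder.7.weight", "decoder.7.bias"]
print(prefix_explains_mismatch(saved, expected))
```

If it returns False, the mismatch has another cause as well, such as a different spectral-norm key layout (`weight` vs. `weight_v`).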

AlvinWen428 commented 4 years ago

@goldservice2017 An easy way to solve this issue is to manually remove "module." from the keys of the state_dict. Hope this helps.

goldservice2017 commented 4 years ago

Thank you @AlvinWen428. Do you mean I can resolve this problem easily as follows?
1. Skip load_state_dict for the keys listed under Missing key(s).
2. Manually delete the unexpected key(s).

In my case, data['generator'] contains 70 keys, and all of them are listed as unexpected keys.

Also, data['generator'] does not contain any of the keys in the missing-key list.

Could you let me know how I can fix this?

I trained this model with PyTorch 1.3.1 on 4 GPUs, but I need to use it on a CPU.

Thank you.

AlvinWen428 commented 4 years ago

@goldservice2017 Not exactly. I mean you can manually rename the unexpected keys by deleting the 'module.' prefix. Here is an example:

new_state_dict = {}
for k, v in state_dict['state_dict'].items():
    new_key = k[7:]  # drop the 7-character "module." prefix
    new_state_dict[new_key] = v

There may be more elegant ways, but you can try this first :)
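`k[7:]` assumes every key starts with `module.`; a key without the prefix would lose its first seven characters. A slightly safer sketch strips the prefix only where it is present, so a checkpoint saved without DataParallel passes through unchanged. `strip_prefix` is a hypothetical helper name:

```python
from collections import OrderedDict


def strip_prefix(state_dict, prefix="module."):
    """Return a new OrderedDict with `prefix` removed from every key that starts with it."""
    return OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in state_dict.items()
    )


# Usage in this repo might look like:
#     self.generator.load_state_dict(strip_prefix(data['generator']))
```

Like the loop above, it builds a new dict, so any `_metadata` attribute on the loaded checkpoint is not carried over.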

goldservice2017 commented 4 years ago

@AlvinWen428, thank you for your reply. I get key errors after removing 'module.'. The error is raised from torch/nn/utils/spectral_norm.py, line 167.

Here is my code with new_state_dict:

new_state_dict = {}
for k, v in data['generator'].items():
    new_key = k[7:]
    new_state_dict[new_key] = v

self.generator.load_state_dict(new_state_dict)

And here are the lines that cause the error:

# This is a top level class because Py2 pickle doesn't like inner class nor an
# instancemethod.
class SpectralNormLoadStateDictPreHook(object):
    # See docstring of SpectralNorm._version on the changes to spectral_norm.
    def __init__(self, fn):
        self.fn = fn

    # For state_dict with version None, (assuming that it has gone through at
    # least one training forward), we have
    #
    #    u = normalize(W_orig @ v)
    #    W = W_orig / sigma, where sigma = u @ W_orig @ v
    #
    # To compute `v`, we solve `W_orig @ x = u`, and let
    #    v = x / (u @ W_orig @ x) * (W / W_orig).
    def __call__(self, state_dict, prefix, local_metadata, strict,
                 missing_keys, unexpected_keys, error_msgs):
        fn = self.fn
        version = local_metadata.get('spectral_norm', {}).get(fn.name + '.version', None)
        if version is None or version < 1:
            with torch.no_grad():
                weight_orig = state_dict[prefix + fn.name + '_orig']
                weight = state_dict.pop(prefix + fn.name)
                sigma = (weight_orig / weight).mean()
                weight_mat = fn.reshape_weight_to_matrix(weight_orig)
                u = state_dict[prefix + fn.name + '_u']
                v = fn._solve_v_and_rescale(weight_mat, u, sigma)
                state_dict[prefix + fn.name + '_v'] = v

Here, weight = state_dict.pop(prefix + fn.name) fails because there is no such key in state_dict.
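Judging from the hook quoted above, a likely cause is that the version it checks comes from `local_metadata`, which `load_state_dict` reads from the checkpoint's `_metadata` attribute. A freshly built `{}` has no `_metadata`, so the version falls back to `None` and the hook tries to pop the old-style `weight` key, which a checkpoint containing `weight_v` does not have. A sketch that renames both the keys and the `_metadata` entries, assuming the DataParallel layout where the wrapped model's metadata sits under `"module"` and its children under `"module.<name>"`; `strip_prefix_keep_metadata` is a hypothetical helper:

```python
from collections import OrderedDict


def strip_prefix_keep_metadata(state_dict, prefix="module."):
    """Copy `state_dict`, removing `prefix` from its keys and from its `_metadata` entries."""
    out = OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in state_dict.items()
    )
    metadata = getattr(state_dict, "_metadata", None)
    if metadata is not None:
        bare = prefix.rstrip(".")  # "module": the wrapped model itself
        new_meta = OrderedDict()
        for k, v in metadata.items():
            if k == bare:
                new_meta[""] = v  # the wrapped model becomes the root module
            elif k.startswith(prefix):
                new_meta[k[len(prefix):]] = v  # e.g. "module.encoder.1" -> "encoder.1"
            else:
                new_meta[k] = v
        out._metadata = new_meta  # load_state_dict looks this attribute up
    return out


# Usage in this repo might look like:
#     self.generator.load_state_dict(strip_prefix_keep_metadata(data['generator']))
```

Newer PyTorch releases (well after the 1.3 discussed here) include `torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict, prefix)`, which performs the same renaming of keys and `_metadata` in place.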

Here are my new state_dict keys:

dict_keys(['encoder.1.bias', 'encoder.1.weight_orig', 'encoder.1.weight_u', 'encoder.1.weight_v', 
'encoder.4.bias', 'encoder.4.weight_orig', 'encoder.4.weight_u', 'encoder.4.weight_v', 'encoder.7.bias', 
'encoder.7.weight_orig', 'encoder.7.weight_u', 'encoder.7.weight_v', 'middle.0.conv_block.1.weight_orig', 
'middle.0.conv_block.1.weight_u', 'middle.0.conv_block.1.weight_v', 'middle.0.conv_block.5.weight_orig', 
'middle.0.conv_block.5.weight_u', 'middle.0.conv_block.5.weight_v', 'middle.1.conv_block.1.weight_orig', 
'middle.1.conv_block.1.weight_u', 'middle.1.conv_block.1.weight_v', 'middle.1.conv_block.5.weight_orig', 
'middle.1.conv_block.5.weight_u', 'middle.1.conv_block.5.weight_v', 'middle.2.conv_block.1.weight_orig', 
'middle.2.conv_block.1.weight_u', 'middle.2.conv_block.1.weight_v', 'middle.2.conv_block.5.weight_orig', 
'middle.2.conv_block.5.weight_u', 'middle.2.conv_block.5.weight_v', 'middle.3.conv_block.1.weight_orig', 
'middle.3.conv_block.1.weight_u', 'middle.3.conv_block.1.weight_v', 'middle.3.conv_block.5.weight_orig', 
'middle.3.conv_block.5.weight_u', 'middle.3.conv_block.5.weight_v', 'middle.4.conv_block.1.weight_orig', 
'middle.4.conv_block.1.weight_u', 'middle.4.conv_block.1.weight_v', 'middle.4.conv_block.5.weight_orig', 
'middle.4.conv_block.5.weight_u', 'middle.4.conv_block.5.weight_v', 'middle.5.conv_block.1.weight_orig', 
'middle.5.conv_block.1.weight_u', 'middle.5.conv_block.1.weight_v', 'middle.5.conv_block.5.weight_orig', 
'middle.5.conv_block.5.weight_u', 'middle.5.conv_block.5.weight_v', 'middle.6.conv_block.1.weight_orig', 
'middle.6.conv_block.1.weight_u', 'middle.6.conv_block.1.weight_v', 'middle.6.conv_block.5.weight_orig', 
'middle.6.conv_block.5.weight_u', 'middle.6.conv_block.5.weight_v', 'middle.7.conv_block.1.weight_orig', 
'middle.7.conv_block.1.weight_u', 'middle.7.conv_block.1.weight_v', 'middle.7.conv_block.5.weight_orig', 
'middle.7.conv_block.5.weight_u', 'middle.7.conv_block.5.weight_v', 'decoder.0.bias', 
'decoder.0.weight_orig', 'decoder.0.weight_u', 'decoder.0.weight_v', 'decoder.3.bias', 
'decoder.3.weight_orig', 'decoder.3.weight_u', 'decoder.3.weight_v', 'decoder.7.weight', 'decoder.7.bias'])

And here are the keys that raise errors:

Key: 'encoder.1.weight'
Key: 'encoder.4.weight'
Key: 'encoder.7.weight'
Key: 'middle.0.conv_block.1.weight'
Key: 'middle.0.conv_block.5.weight'
Key: 'middle.1.conv_block.1.weight'
Key: 'middle.1.conv_block.5.weight'
Key: 'middle.2.conv_block.1.weight'
Key: 'middle.2.conv_block.5.weight'
Key: 'middle.3.conv_block.1.weight'
Key: 'middle.3.conv_block.5.weight'
Key: 'middle.4.conv_block.1.weight'
Key: 'middle.4.conv_block.5.weight'
Key: 'middle.5.conv_block.1.weight'
Key: 'middle.5.conv_block.5.weight'
Key: 'middle.6.conv_block.1.weight'
Key: 'middle.6.conv_block.5.weight'
Key: 'middle.7.conv_block.1.weight'
Key: 'middle.7.conv_block.5.weight'
Key: 'decoder.0.weight'
Key: 'decoder.3.weight'

How can I fix this? Could I use the _u or _v keys to fix it?
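In principle, yes. The comments in the hook above give the relationship, assuming `weight_orig` is reshaped to a matrix $W_{\text{mat}}$ with one row per output channel (for a `Conv2d` layer, shape $C_{\text{out}} \times C_{\text{in}} k_h k_w$) and $u$, $v$ are the stored `weight_u` and `weight_v`:

$$
\sigma = u^{\top} W_{\text{mat}}\, v, \qquad W = \frac{W_{\text{orig}}}{\sigma}
$$

For `ConvTranspose2d` layers, `spectral_norm` uses `dim=1` by default, so the matrix is formed along the second dimension instead. In versions whose checkpoints store `weight_v`, the spectral-norm wrapper performs this same division in eval mode using the stored $u$ and $v$, which is presumably why `weight` itself is no longer saved.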

Thank you

AlvinWen428 commented 4 years ago

@goldservice2017 I have also encountered this problem before, and I fixed it by updating my PyTorch version. You can check which versions of Python and PyTorch you are using.

goldservice2017 commented 4 years ago

@AlvinWen428 Thank you for your update. I trained the model with PyTorch 1.0.1 and was testing it with 1.3.0. After updating to PyTorch 1.3.1 / Torchvision 0.4.2, the issue was resolved. Thank you.

anshen666 commented 4 years ago

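Hello, I have also been running this code recently. Could I add you as a contact to discuss it? My WeChat: loveanshen; my QQ: 519838354; my email: 519838354@qq.com. I look forward to your reply, busy as you are.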

cats-food commented 4 years ago

@goldservice2017 I am confused about 'weight_u' and 'weight_v' in the conv layers. Could you please tell me what '_u' and '_v' mean? Thank you!
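`weight_u` and `weight_v` are the running estimates of the leading left and right singular vectors of the reshaped weight matrix, which `torch.nn.utils.spectral_norm` refines by power iteration to estimate the largest singular value $\sigma$; the layer then uses `weight_orig / sigma`. A pure-Python sketch of that iteration on a small matrix (the update order, $v$ from $W^\top u$ and then $u$ from $W v$, mirrors PyTorch's, but this is an illustration, not its implementation):

```python
import math


def mat_vec(W, x):
    """W @ x for a matrix given as a list of rows."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def mat_t_vec(W, y):
    """W^T @ y for a matrix given as a list of rows."""
    return [sum(W[i][j] * y[i] for i in range(len(W))) for j in range(len(W[0]))]


def normalize(x, eps=1e-12):
    norm = math.sqrt(sum(t * t for t in x))
    return [t / max(norm, eps) for t in x]


def power_iteration(W, u, n_iters=50):
    """Estimate the top singular value of W; u plays the role of weight_u, v of weight_v.

    n_iters must be >= 1.
    """
    for _ in range(n_iters):
        v = normalize(mat_t_vec(W, u))  # right singular vector estimate
        u = normalize(mat_vec(W, v))    # left singular vector estimate
    sigma = sum(ui * wv for ui, wv in zip(u, mat_vec(W, v)))  # sigma = u^T W v
    return u, v, sigma


u, v, sigma = power_iteration([[3.0, 0.0], [0.0, 1.0]], [1.0, 1.0])
```

Here `sigma` converges to 3.0, the largest singular value, and dividing the matrix by it gives a spectral norm of 1. During training, PyTorch runs only a few iterations per forward pass (one by default) and keeps `u` and `v` between steps, which is why they are saved in the state_dict.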

ghost commented 3 years ago

Change the load function in src/models.py as follows to fix this problem:

    def load(self):
        if os.path.exists(self.gen_weights_path):
            print('Loading %s generator...' % self.name)

            if torch.cuda.is_available():
                data = torch.load(self.gen_weights_path)
            else:
                data = torch.load(self.gen_weights_path, map_location=lambda storage, loc: storage)

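            # ======== modified areas =======================
            # keep only the checkpoint entries whose keys the current generator expects
            model_keys = self.generator.state_dict().keys()
            new_dict = {k: v for k, v in data['generator'].items() if k in model_keys}
            # also store each "<name>_orig" tensor under "<name>" (k[:-5] drops "_orig");
            # note the value is weight_orig itself, not weight_orig / sigma
            diff_dict = {k[:-5]: v for k, v in data['generator'].items() if k.endswith('_orig')}
            new_dict.update(diff_dict)
            self.generator.load_state_dict(new_dict)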

            # self.generator.load_state_dict(data['generator'])
            self.iteration = data['iteration']

        # load discriminator only when training
        if self.config.MODE == 1 and os.path.exists(self.dis_weights_path):
            print('Loading %s discriminator...' % self.name)

            if torch.cuda.is_available():
                data = torch.load(self.dis_weights_path)
            else:
                data = torch.load(self.dis_weights_path, map_location=lambda storage, loc: storage)

            self.discriminator.load_state_dict(data['discriminator'])

MBaltz commented 2 years ago

To solve this:


~I just passed strict=False to the load_state_dict() function.~ ~Like this:~

self.TEDNet.load_state_dict(torch.load(f), strict=False)

~I think it's caused by the difference in torch versions (between the saved weights and the actual program [loader]).~


Correction!

You just need to rename each key in the OrderedDict (remove the "module." prefix). Example:

import torch
from collections import OrderedDict

state_dict = torch.load("/state_dict/path")
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    new_state_dict[k.replace("module.", "")] = v  # drop the DataParallel prefix

YourModel.load_state_dict(new_state_dict)

In my case, the problem occurred because the trained weights were saved from a DataParallel-wrapped model and I tried to load them into a model without DataParallel.

If you pass strict=False to load_state_dict(), mismatched keys are silently skipped, so your model may make wrong predictions.
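With `strict=False`, every key that does not match is skipped, so it is worth knowing beforehand how much would be skipped (newer PyTorch versions also return the skipped `missing_keys` and `unexpected_keys` from `load_state_dict`). A minimal pure-Python sketch comparing the two key sets; `key_report` is a hypothetical helper:

```python
def key_report(ckpt_keys, model_keys):
    """Return (missing, unexpected) key lists, the two sets load_state_dict complains about."""
    ckpt, model = set(ckpt_keys), set(model_keys)
    missing = sorted(model - ckpt)     # expected by the model, absent from the checkpoint
    unexpected = sorted(ckpt - model)  # present in the checkpoint, unknown to the model
    return missing, unexpected


# An un-stripped DataParallel checkpoint against the plain model: nothing matches,
# so strict=False would leave every parameter at its initial value.
missing, unexpected = key_report(
    ["module.decoder.7.weight", "module.decoder.7.bias"],
    ["decoder.7.weight", "decoder.7.bias"],
)
```

If both lists are empty, strict loading succeeds and `strict=False` is unnecessary.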