xinntao / Real-ESRGAN

Real-ESRGAN aims at developing Practical Algorithms for General Image/Video Restoration.
BSD 3-Clause "New" or "Revised" License

Trying to run x2 model with 5ca1078 version causes RuntimeError in torch "size mismatch for conv_first.weight" #540

Open brunoais opened 1 year ago

brunoais commented 1 year ago

I'm trying to run this version: https://github.com/xinntao/Real-ESRGAN/commit/5ca1078535923d485892caee7d7804380bfc87fd, with these dependency versions: torch==1.13.1, torchvision==0.14.1

I'm using these arguments:

env/bin/python \
./inference_realesrgan.py \
--input input.png \
--output output.png \
--suffix "suffix" \
--model_path "experiments/pretrained_models/RealESRGAN_x2plus.pth" \
--outscale 1 \
--tile 1000 \
--face_enhance

I got the model (RealESRGAN_x2plus.pth) from https://github.com/xinntao/Real-ESRGAN/releases/tag/v0.2.1, which is the latest release of an x2 model.

I get this error (same as #60):

RuntimeError: Error(s) in loading state_dict for RRDBNet:
    size mismatch for conv_first.weight: copying a param with shape torch.Size([64, 12, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 3, 3, 3]).

I made this issue separate from #60 because I'm just trying to run Real-ESRGAN, not modifying the model.

What's wrong?
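The two shapes in the error above follow from the channel arithmetic in basicsr's RRDBNet: with scale=2 the network pixel-unshuffles its RGB input by a factor of 2 before conv_first, so that convolution sees 3 × 2² = 12 input channels, while the default scale=4 feeds the 3 RGB channels straight in. The sketch below restates that rule in plain Python so the checkpoint shape can be read back as a scale. The function names are hypothetical (not part of basicsr), and it assumes num_feat=64 and a 3×3 kernel, as in the released models.

```python
# Hypothetical helpers restating how basicsr's RRDBNet sizes conv_first.
# Assumption: scale-2 and scale-1 models pixel-unshuffle the input
# (by 2 and 4 respectively) before conv_first; scale-4 models do not.

UNSHUFFLE_FACTOR = {4: 1, 2: 2, 1: 4}  # network scale -> pixel-unshuffle factor


def conv_first_shape(scale, num_in_ch=3, num_feat=64, kernel=3):
    """Expected shape of conv_first.weight for an RRDBNet built with this scale."""
    if scale not in UNSHUFFLE_FACTOR:
        raise ValueError(f"unsupported scale: {scale}")
    factor = UNSHUFFLE_FACTOR[scale]
    in_ch = num_in_ch * factor * factor  # pixel-unshuffle multiplies channels by factor**2
    return (num_feat, in_ch, kernel, kernel)


def scale_from_conv_first(shape, num_in_ch=3):
    """Infer which scale a checkpoint was trained for from its conv_first.weight shape."""
    in_ch = shape[1]
    for scale, factor in UNSHUFFLE_FACTOR.items():
        if in_ch == num_in_ch * factor * factor:
            return scale
    raise ValueError(f"no known scale gives {in_ch} input channels")


checkpoint_shape = (64, 12, 3, 3)  # "from checkpoint" in the error message
model_shape = (64, 3, 3, 3)        # "in current model" in the error message
print(scale_from_conv_first(checkpoint_shape))  # 2
print(scale_from_conv_first(model_shape))       # 4
```

Read this way, the error says an x2 checkpoint is being loaded into a network that was built as x4.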

tumuyan commented 1 year ago

I get the same error when fine-tuning the x2 model. Have you solved the problem?

pth files:
https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth
https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.3/RealESRGAN_x2plus_netD.pth

config file:

scale: 2

path:
  # use the pre-trained Real-ESRNet model
  pretrain_network_g: experiments/pretrained_models/RealESRGAN_x2plus.pth
  param_key_g: params_ema
  strict_load_g: true
  pretrain_network_d: experiments/pretrained_models/RealESRGAN_x2plus_netD.pth
  param_key_d: params
  strict_load_d: true
  resume_state: ~

log:

2023-05-27 13:18:46,411 INFO: Loading RRDBNet model from experiments/pretrained_models/RealESRGAN_x2plus.pth, with param key: [params_ema].
Traceback (most recent call last):
  File "/content/Real-ESRGAN/realesrgan/train.py", line 11, in <module>
    train_pipeline(root_path)
  File "/usr/local/lib/python3.10/dist-packages/basicsr/train.py", line 124, in train_pipeline
    model = build_model(opt)
  File "/usr/local/lib/python3.10/dist-packages/basicsr/models/__init__.py", line 26, in build_model
    model = MODEL_REGISTRY.get(opt['model_type'])(opt)
  File "/content/Real-ESRGAN/realesrgan/models/realesrgan_model.py", line 24, in __init__
    super(RealESRGANModel, self).__init__(opt)
  File "/usr/local/lib/python3.10/dist-packages/basicsr/models/sr_model.py", line 30, in __init__
    self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key)
  File "/usr/local/lib/python3.10/dist-packages/basicsr/models/base_model.py", line 303, in load_network
    net.load_state_dict(load_net, strict=strict)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 2041, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for RRDBNet:
    size mismatch for conv_first.weight: copying a param with shape torch.Size([64, 12, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 3, 3, 3]).
brunoais commented 1 year ago

Apparently, this is caused by a mismatched torch/torchvision pair and/or a torch version that is too old or too new. I had to test torch and torchvision versions carefully to find a pair that works well enough.

You must use torch 1.X and a torchvision compatible with it.

I ended up with torch==1.10.0 and torchvision==0.10.1

See if that helps you.

tumuyan commented 1 year ago

Thank you very much. I'm using torch 1.x, but I'm having trouble switching to 1.10. A friend told me that training and using the x2 model wouldn't actually save much time, so I've given up on training the x2 model.

brunoais commented 1 year ago

Fine-tuning the 2x model doesn't speed things up much, but it helps if you don't have much VRAM: you can tile less and get crisper or more accurate images. Whatever works for you is the best option ^^

tumuyan commented 1 year ago

I think there is another problem. Running pip install torch==1.10.0 and pip install torchvision==0.10.1 completes without errors, but torch 1.9.1+cu102 and torchvision 0.10.1+cu102 end up installed.

Training with finetune_realesrgan_x2 still gives the same error:

 Loading RRDBNet model from experiments/pretrained_models/RealESRGAN_x2plus.pth, with param key: [params_ema].
Traceback (most recent call last):
  File "realesrgan/train.py", line 11, in <module>
    train_pipeline(root_path)
  File "/opt/conda/lib/python3.7/site-packages/basicsr/train.py", line 124, in train_pipeline
    model = build_model(opt)
  File "/opt/conda/lib/python3.7/site-packages/basicsr/models/__init__.py", line 26, in build_model
    model = MODEL_REGISTRY.get(opt['model_type'])(opt)
  File "/root/Real-ESRGAN/realesrgan/models/realesrgan_model.py", line 24, in __init__
    super(RealESRGANModel, self).__init__(opt)
  File "/opt/conda/lib/python3.7/site-packages/basicsr/models/sr_model.py", line 30, in __init__
    self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key)
  File "/opt/conda/lib/python3.7/site-packages/basicsr/models/base_model.py", line 303, in load_network
    net.load_state_dict(load_net, strict=strict)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1407, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for RRDBNet:
        size mismatch for conv_first.weight: copying a param with shape torch.Size([64, 12, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 3, 3, 3]).
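The downgrade reported in this comment is what pip's resolver would be expected to do: each torchvision release is pinned to one torch minor version (torchvision 0.10.x to torch 1.9.x, 0.11.x to torch 1.10.x, 0.14.x to torch 1.13.x), so installing torchvision 0.10.1 pulls torch back to 1.9.1. The sketch below encodes that pairing for torch 1.4 through 1.13 only, where torchvision's minor version is torch's minor version plus one. It is a hypothetical helper, not a replacement for the official compatibility table.

```python
# Hypothetical check of the torch 1.x / torchvision version pairing.
# Assumption: for torch 1.4 through 1.13, torchvision minor == torch minor + 1.


def parse_minor(version):
    """'1.9.1+cu102' -> (1, 9); the local build tag after '+' is ignored."""
    release = version.split("+")[0]
    major, minor = release.split(".")[:2]
    return int(major), int(minor)


def versions_match(torch_version, torchvision_version):
    """True if the torchvision version is the one built for this torch version."""
    t_major, t_minor = parse_minor(torch_version)
    v_major, v_minor = parse_minor(torchvision_version)
    if t_major != 1 or not 4 <= t_minor <= 13:
        raise ValueError(f"rule only covers torch 1.4-1.13, got {torch_version}")
    return v_major == 0 and v_minor == t_minor + 1


print(versions_match("1.13.1", "0.14.1"))             # True  (pair in the original report)
print(versions_match("1.10.0", "0.10.1"))             # False (the pins suggested above)
print(versions_match("1.9.1+cu102", "0.10.1+cu102"))  # True  (what pip actually installed)
```

By this rule the original report's torch 1.13.1 / torchvision 0.14.1 was already a matching pair, which is consistent with the error persisting after the version change.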
tumuyan commented 1 year ago

From #60:

I was using finetune_realesrgan_x4plus_pairdata.yml for fine-tuning and changed scale to 2. But there is a second scale setting under network_g that the file doesn't include. Once I added it and set it to 2, it worked:

network_g:
  type: RRDBNet
  num_in_ch: 3
  num_out_ch: 3
  num_feat: 64
  num_block: 23
  num_grow_ch: 32
  scale: 2
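The fix works because the top-level scale and network_g.scale are separate keys, and basicsr's RRDBNet falls back to scale=4 when network_g omits it, which builds the 3-channel conv_first seen in the error. A small pre-flight check on the parsed options can catch this before any weights are loaded. The sketch below is a hypothetical helper operating on a plain dict (for example, the result of yaml.safe_load on the option file); it is not a basicsr API.

```python
# Hypothetical pre-flight check for a basicsr-style options dict.
# Assumption: RRDBNet uses scale=4 when network_g does not set "scale".
RRDBNET_DEFAULT_SCALE = 4


def check_scale_consistency(opt):
    """Return a list of problems; an empty list means the two scales agree."""
    problems = []
    top_scale = opt.get("scale")
    net_g = opt.get("network_g", {})
    net_scale = net_g.get("scale", RRDBNET_DEFAULT_SCALE)
    if "scale" not in net_g:
        problems.append(f"network_g.scale is not set; RRDBNet will use {RRDBNET_DEFAULT_SCALE}")
    if top_scale != net_scale:
        problems.append(f"scale ({top_scale}) != network_g.scale ({net_scale})")
    return problems


# The situation described in this thread: scale 2 at the top level, none under network_g.
opt = {
    "scale": 2,
    "network_g": {"type": "RRDBNet", "num_in_ch": 3, "num_out_ch": 3},
}
for problem in check_scale_consistency(opt):
    print(problem)
```

With scale: 2 added under network_g, as in the config above, the check returns an empty list.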