Fried-Rice-Lab / FriedRiceLab

Official repository of the Fried Rice Lab, including code resources for our works: ESWT [arXiv], etc. This repository also implements many useful features and out-of-the-box image restoration models.
MIT License

Error using denoising task #3

Closed camilo1704 closed 1 year ago

camilo1704 commented 1 year ago

Hi, first of all, great work guys! I'm trying to use the denoising task, and running the following command:
python infer.py -expe_opt options/repr/ESWT/ESWT-12-12_LSR.yml -task_opt options/task/Denoising.yml
gives the following output:

Traceback (most recent call last):
  File "infer.py", line 61, in <module>
    infer_pipeline(root_path)
  File "infer.py", line 47, in infer_pipeline
    model = build_model(opt)
  File "/home/cuchuflito/.local/lib/python3.8/site-packages/basicsr/models/__init__.py", line 26, in build_model
    model = MODEL_REGISTRY.get(opt['model_type'])(opt)
  File "/home/cuchuflito/Documents/adaviv/notebook/image-denoising/FriedRiceLab/models/ir_model.py", line 26, in __init__
    super(IRModel, self).__init__(opt)
  File "/home/cuchuflito/.local/lib/python3.8/site-packages/basicsr/models/sr_model.py", line 30, in __init__
    self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key)
  File "/home/cuchuflito/.local/lib/python3.8/site-packages/basicsr/models/base_model.py", line 303, in load_network
    net.load_state_dict(load_net, strict=strict)
  File "/home/cuchuflito/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1482, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for ESWT:
    size mismatch for tail.0.weight: copying a param with shape torch.Size([48, 60, 3, 3]) from checkpoint, the shape in current model is torch.Size([3, 60, 3, 3]).
    size mismatch for tail.0.bias: copying a param with shape torch.Size([48]) from checkpoint, the shape in current model is torch.Size([3]).

Do you have any advice? Thank you in advance.

jnpngshiii commented 1 year ago

Hey! This happens because self.tail differs between the denoising model and the super-resolution model, so the tail weights in the checkpoint do not fit the current model. Just set strict_load_g to false. Try the following command instead:

python infer.py -expe_opt options/repr/ESWT/ESWT-12-12_LSR.yml -task_opt options/task/Denoising.yml --force_yml path:strict_load_g=false

Note that this leaves self.tail as randomly initialized, untrained layers. To get better denoising results, fine-tune ESWT on the SIDD dataset or on your own dataset.
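
For anyone curious what skipping the mismatched weights amounts to, here is a minimal sketch of the idea, not the actual BasicSR code. It keeps only checkpoint entries whose key and shape match the current model. The helper name `drop_mismatched`, the `Param` stand-in, and the `head.0.weight` entry are hypothetical and only there for illustration; the two `tail.0` shapes are the ones from the traceback.

```python
from collections import namedtuple


def drop_mismatched(checkpoint, model_state):
    """Split checkpoint entries into loadable ones and skipped keys.

    An entry is kept only if the model has the same key with the same shape.
    Values only need a `.shape` attribute, so torch tensors would work as-is.
    """
    kept, skipped = {}, []
    for key, value in checkpoint.items():
        target = model_state.get(key)
        if target is not None and tuple(target.shape) == tuple(value.shape):
            kept[key] = value
        else:
            skipped.append(key)
    return kept, skipped


# Stand-in for a tensor so the sketch runs without torch installed.
Param = namedtuple("Param", "shape")

# tail.0 shapes come from the traceback; head.0.weight is a hypothetical
# entry that matches on both sides.
checkpoint = {
    "head.0.weight": Param((60, 3, 3, 3)),
    "tail.0.weight": Param((48, 60, 3, 3)),
    "tail.0.bias": Param((48,)),
}
model_state = {
    "head.0.weight": Param((60, 3, 3, 3)),
    "tail.0.weight": Param((3, 60, 3, 3)),
    "tail.0.bias": Param((3,)),
}

kept, skipped = drop_mismatched(checkpoint, model_state)
print(sorted(kept))     # ['head.0.weight']
print(sorted(skipped))  # ['tail.0.bias', 'tail.0.weight']
```

With a real PyTorch model, the kept entries could then be passed to model.load_state_dict(kept, strict=False), and every key in the skipped list stays at its random initialization, which is why fine-tuning is needed afterwards.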