I get the following shape mismatch error whenever I try to run the demo:
Disable distributed.
load net keys <built-in method keys of collections.OrderedDict object at 0x7f04c6da2540>
Traceback (most recent call last):
  File "basicsr/demo.py", line 46, in <module>
    main()
  File "basicsr/demo.py", line 40, in main
    model = create_model(opt)
  File "/media/spoodermun/YoData/FISH/noise_mitigation/HINet/basicsr/models/__init__.py", line 44, in create_model
    model = model_cls(opt)
  File "/media/spoodermun/YoData/FISH/noise_mitigation/HINet/basicsr/models/image_restoration_model.py", line 36, in __init__
    self.load_network(self.net_g, load_path,
  File "/media/spoodermun/YoData/FISH/noise_mitigation/HINet/basicsr/models/base_model.py", line 287, in load_network
    net.load_state_dict(load_net, strict=strict)
  File "/media/spoodermun/YoData/morph/morphology_detection/YOLO/pytorch-env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1223, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for HINet:
    size mismatch for down_path_1.0.identity.weight: copying a param with shape torch.Size([64, 64, 1, 1]) from checkpoint, the shape in current model is torch.Size([32, 32, 1, 1]).
    size mismatch for down_path_1.0.identity.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([32]).
    size mismatch for down_path_1.0.conv_1.weight: copying a param with shape torch.Size([64, 64, 3, 3]) from checkpoint, the shape in current model is torch.Size([32, 32, 3, 3]).
I think the problem is probably caused by an inconsistency between the network structure and the pre-trained model.
If you want to try the pre-trained HINet-SIDD-1x, please refer here for its network structure :)
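To confirm this kind of inconsistency directly, one can compare parameter shapes between the checkpoint and the freshly built network. The sketch below uses hypothetical helpers (`unwrap_checkpoint`, `shapes_of`, `find_shape_mismatches`; none of them are part of HINet or BasicSR) and assumes the checkpoint nests its weights under a `params` key, as BasicSR-style checkpoints usually do, falling back to the raw dict otherwise. Note that `strict_load_g: false` would not avoid this error: PyTorch's `load_state_dict` raises on size mismatches even with `strict=False`.

```python
from typing import Any, Dict, List, Tuple

Shape = Tuple[int, ...]


def unwrap_checkpoint(ckpt: Dict[str, Any], param_key: str = "params") -> Dict[str, Any]:
    """Return the flat state dict stored in a loaded checkpoint (hypothetical helper).

    Assumes BasicSR-style nesting under `param_key`; if that key is absent the
    checkpoint is treated as a plain state dict. A DataParallel "module."
    prefix is stripped from every key.
    """
    state = ckpt.get(param_key, ckpt)
    prefix = "module."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state.items()
    }


def shapes_of(state_dict: Dict[str, Any]) -> Dict[str, Shape]:
    """Map each key to a plain tuple shape (works for anything with `.shape`, e.g. torch tensors)."""
    return {key: tuple(value.shape) for key, value in state_dict.items()}


def find_shape_mismatches(
    ckpt_shapes: Dict[str, Shape], model_shapes: Dict[str, Shape]
) -> List[Tuple[str, Shape, Shape]]:
    """List (key, checkpoint_shape, model_shape) for keys in both dicts whose shapes differ.

    Keys present in only one dict (missing/unexpected keys) are ignored here.
    """
    return [
        (key, tuple(shape), tuple(model_shapes[key]))
        for key, shape in ckpt_shapes.items()
        if key in model_shapes and tuple(shape) != tuple(model_shapes[key])
    ]
```

With PyTorch installed, a typical call would be `find_shape_mismatches(shapes_of(unwrap_checkpoint(torch.load(path, map_location="cpu"))), shapes_of(net.state_dict()))`, where `net` is a HINet built from the same `network_g` options as the config. Every entry returned is a parameter whose width disagrees between the two.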
The demo config is as follows:
# ------------------------------------------------------------------------
# Copyright (c) 2021 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
# Copyright 2018-2020 BasicSR Authors
# ------------------------------------------------------------------------
# general settings
name: demo
model_type: ImageRestorationModel
scale: 1
num_gpu: 1 # set num_gpu: 0 for cpu mode
manual_seed: 10

# single image inference and save image
img_path:
  input_img: ./noise_mitigation/test_images/enhanced/1p36_0_0.png
  output_img: ./demo/demo1.png

# network structures
network_g:
  type: HINet
  wf: 32
  hin_position_left: 0
  hin_position_right: 4

# path
path:
  pretrain_network_g: ./noise_mitigation/HINet/HINet-SIDD-1x.pth
  strict_load_g: true
  resume_state: ~

# validation settings
val:
  grids: true
  crop_size: 256

# dist training settings
dist_params:
  backend: nccl
  port: 29500
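The `network_g` section sets `wf: 32`, while the traceback shows the checkpoint's `down_path_1.0` weights at 64 channels where this config builds 32. Assuming, as those shapes suggest, that `down_path_1.0.conv_1.weight` has shape `(wf, wf, 3, 3)`, the width the checkpoint was trained with can be read off directly. The sketch below is hypothetical (`infer_wf` and `patch_network_opts` are not HINet or BasicSR functions), takes `ckpt_shapes` as a mapping from parameter name to shape tuple (e.g. `{k: tuple(v.shape) for k, v in state_dict.items()}`), and only checks `wf`; other `network_g` options may also need to match the checkpoint.

```python
import copy
from typing import Any, Dict, Tuple

# Assumption inferred from the traceback: the first encoder block maps wf -> wf
# channels, so this weight has shape (wf, wf, 3, 3).
FIRST_BLOCK_KEY = "down_path_1.0.conv_1.weight"


def infer_wf(ckpt_shapes: Dict[str, Tuple[int, ...]]) -> int:
    """Read the base width `wf` off the first block's conv_1 weight shape."""
    shape = ckpt_shapes[FIRST_BLOCK_KEY]
    out_ch, in_ch, _, _ = shape
    if out_ch != in_ch:
        raise ValueError(f"unexpected shape for {FIRST_BLOCK_KEY}: {shape}")
    return out_ch


def patch_network_opts(
    opt: Dict[str, Any], ckpt_shapes: Dict[str, Tuple[int, ...]]
) -> Dict[str, Any]:
    """Return a deep copy of the parsed options with network_g.wf set to match the checkpoint."""
    patched = copy.deepcopy(opt)
    patched["network_g"]["wf"] = infer_wf(ckpt_shapes)
    return patched
```

For the shapes in the traceback, `infer_wf` returns 64, so the first thing to try is simply setting `wf: 64` under `network_g` in the YAML rather than patching the options dict before `create_model(opt)`.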