HanzhouLiu / DeblurDiNAT

Official implementation of the paper "DeblurDiNAT: A Lightweight and Effective Transformer for Image Deblurring".
https://arxiv.org/abs/2403.13163
Other
22 stars, 1 fork

Unexpected key(s) in state_dict #4

Closed: Ramesik closed this issue 1 week ago

Ramesik commented 2 months ago

I used the pretrained DeblurDiNATL.pth downloaded from the link in the README. I tried to run the script on a Mac and changed line 34 to load the checkpoint on the CPU:

ck = torch.load(weights_path, map_location=torch.device('cpu'))

predict_GoPro_test_results.py --weight_name=DeblurDiNATL.pth --blur_path=./inputData/DeblurDiNAT

Traceback (most recent call last):
  File "predict_GoPro_test_results.py", line 38, in <module>
    model.load_state_dict(ck)
  File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 2153, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for DataParallel:
    Unexpected key(s) in state_dict:
    "module.decoder.de_trans_level3.0.na2d.rpb", "module.decoder.de_trans_level3.1.na2d.rpb", "module.decoder.de_trans_level3.2.na2d.rpb", "module.decoder.de_trans_level3.3.na2d.rpb", "module.decoder.de_trans_level3.4.na2d.rpb", "module.decoder.de_trans_level3.5.na2d.rpb", "module.decoder.de_trans_level3.6.na2d.rpb", "module.decoder.de_trans_level3.7.na2d.rpb", "module.decoder.de_trans_level3.8.na2d.rpb", "module.decoder.de_trans_level3.9.na2d.rpb", "module.decoder.de_trans_level3.10.na2d.rpb", "module.decoder.de_trans_level3.11.na2d.rpb", "module.decoder.de_trans_level3.12.na2d.rpb", "module.decoder.de_trans_level3.13.na2d.rpb", "module.decoder.de_trans_level3.14.na2d.rpb", "module.decoder.de_trans_level3.15.na2d.rpb", "module.decoder.de_trans_level3.16.na2d.rpb", "module.decoder.de_trans_level3.17.na2d.rpb",
    "module.decoder.de_trans_level2.0.na2d.rpb", "module.decoder.de_trans_level2.1.na2d.rpb", "module.decoder.de_trans_level2.2.na2d.rpb", "module.decoder.de_trans_level2.3.na2d.rpb", "module.decoder.de_trans_level2.4.na2d.rpb", "module.decoder.de_trans_level2.5.na2d.rpb", "module.decoder.de_trans_level2.6.na2d.rpb", "module.decoder.de_trans_level2.7.na2d.rpb", "module.decoder.de_trans_level2.8.na2d.rpb", "module.decoder.de_trans_level2.9.na2d.rpb", "module.decoder.de_trans_level2.10.na2d.rpb", "module.decoder.de_trans_level2.11.na2d.rpb",
    "module.decoder.de_trans_level1.0.na2d.rpb", "module.decoder.de_trans_level1.1.na2d.rpb", "module.decoder.de_trans_level1.2.na2d.rpb", "module.decoder.de_trans_level1.3.na2d.rpb", "module.decoder.de_trans_level1.4.na2d.rpb", "module.decoder.de_trans_level1.5.na2d.rpb",
    "module.decoder.refinement.0.na2d.rpb", "module.decoder.refinement.1.na2d.rpb", "module.decoder.refinement.2.na2d.rpb", "module.decoder.refinement.3.na2d.rpb", "module.decoder.refinement.4.na2d.rpb", "module.decoder.refinement.5.na2d.rpb".
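Every unexpected key in the traceback above ends in `.rpb`, and PyTorch reports no missing keys, which suggests a mismatch between the checkpoint and the model as built by the installed code rather than a corrupted file. One plausible but unconfirmed explanation is a NATTEN version difference: some NATTEN releases create a relative positional bias parameter named `rpb` inside `NeighborhoodAttention2D` by default, and others do not. The pure-Python sketch below groups mismatched keys by their last name component so such a pattern is easy to spot; the helper name `summarize_key_mismatch` is hypothetical, and in practice you would pass it `ck.keys()` and `model.state_dict().keys()`.

```python
from collections import Counter
from typing import Iterable, Tuple


def summarize_key_mismatch(
    checkpoint_keys: Iterable[str], model_keys: Iterable[str]
) -> Tuple[Counter, Counter]:
    """Count unexpected and missing state_dict keys by their final component.

    Hypothetical diagnostic helper. Typical use (requires PyTorch and the repo):
        ck = torch.load(weights_path, map_location="cpu")
        unexpected, missing = summarize_key_mismatch(ck.keys(), model.state_dict().keys())
    """
    ckpt = set(checkpoint_keys)
    model = set(model_keys)
    # Keys in the checkpoint but not in the model (PyTorch's "unexpected" keys).
    unexpected = Counter(k.rsplit(".", 1)[-1] for k in ckpt - model)
    # Keys the model expects but the checkpoint lacks (PyTorch's "missing" keys).
    missing = Counter(k.rsplit(".", 1)[-1] for k in model - ckpt)
    return unexpected, missing


if __name__ == "__main__":
    checkpoint = [
        "module.decoder.refinement.0.na2d.qkv.weight",
        "module.decoder.refinement.0.na2d.rpb",
        "module.decoder.refinement.1.na2d.rpb",
    ]
    model = ["module.decoder.refinement.0.na2d.qkv.weight"]
    print(summarize_key_mismatch(checkpoint, model))
```

If only `rpb` shows up, `model.load_state_dict(ck, strict=False)` would make the error go away, but it would also silently discard the trained bias tensors, so the outputs may not match the released results. Matching the NATTEN version the checkpoint was trained with is likely the safer route.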

HanzhouLiu commented 2 months ago

I have not tested it on a Mac yet. I ran the test on a 3090 GPU and had no issues. I guess it is a CUDA/DataParallel issue, so I suggest running it on a GPU rather than a CPU. Thank you!
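A DataParallel problem usually shows up as a `module.` prefix mismatch: a checkpoint saved from a wrapped model has keys like `module.decoder...`, while an unwrapped model expects `decoder...` (or the reverse). The sketch below, with hypothetical helper names, checks whether two key sets differ only by that prefix. In the traceback in this issue, both the checkpoint keys and the `DataParallel` model already carry the `module.` prefix and only `.rpb` keys differ, so this check would return False there.

```python
from typing import Iterable, Set


def strip_prefix(keys: Iterable[str], prefix: str = "module.") -> Set[str]:
    """Remove a leading prefix (DataParallel adds "module.") from each key."""
    return {k[len(prefix):] if k.startswith(prefix) else k for k in keys}


def is_prefix_only_mismatch(
    checkpoint_keys: Iterable[str], model_keys: Iterable[str], prefix: str = "module."
) -> bool:
    """True if the key sets differ but become identical once the prefix is stripped."""
    ckpt, model = set(checkpoint_keys), set(model_keys)
    return ckpt != model and strip_prefix(ckpt, prefix) == strip_prefix(model, prefix)


if __name__ == "__main__":
    # Wrapped checkpoint vs. unwrapped model: a pure prefix mismatch.
    print(is_prefix_only_mismatch(["module.head.weight"], ["head.weight"]))  # True
    # Same prefix on both sides plus an extra ".rpb" key: not a prefix problem.
    print(is_prefix_only_mismatch(
        ["module.head.weight", "module.head.rpb"], ["module.head.weight"]))  # False
```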

Ramesik commented 1 month ago

I tried on a GPU, installing all components with CUDA support, but I get the same error. It seems that the .pth file is wrong, or maybe the repository does not contain the latest version of the code?

Ramesik commented 1 month ago

If I use a Python 3.8 conda environment, I get a different error:

Traceback (most recent call last):
  File "predict_GoPro_test_results.py", line 8, in <module>
    from models.networks import get_generator
  File "/workspace/DeblurDiNAT/models/networks.py", line 3, in <module>
    from models.DeblurDiNATL import NADeblurL
  File "/workspace/DeblurDiNAT/models/DeblurDiNATL.py", line 7, in <module>
    from natten import NeighborhoodAttention1D, NeighborhoodAttention2D
  File "/opt/conda/envs/DeblurDiNAT/lib/python3.8/site-packages/natten/__init__.py", line 24, in <module>
    from .context import (
  File "/opt/conda/envs/DeblurDiNAT/lib/python3.8/site-packages/natten/context.py", line 29, in <module>
    from .utils import log
  File "/opt/conda/envs/DeblurDiNAT/lib/python3.8/site-packages/natten/utils/__init__.py", line 24, in <module>
    from .checks import (
  File "/opt/conda/envs/DeblurDiNAT/lib/python3.8/site-packages/natten/utils/checks.py", line 27, in <module>
    from ..types import (
  File "/opt/conda/envs/DeblurDiNAT/lib/python3.8/site-packages/natten/types.py", line 36, in <module>
    DimensionType = Dimension1DType | Dimension2DType | Dimension3DType
TypeError: unsupported operand type(s) for |: '_GenericAlias' and '_GenericAlias'
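The failing line in `natten/types.py` combines typing aliases with the `|` operator (PEP 604 union syntax), which typing generics only support from Python 3.10 onward; on Python 3.8 the same expression raises this `TypeError`. That suggests the NATTEN release installed in this environment expects a newer Python than 3.8, so presumably either Python 3.10+ or a NATTEN release that still supports 3.8 would be needed. The sketch below reproduces the check in isolation, using `Tuple[int]` and `Tuple[int, int]` as stand-ins for NATTEN's dimension types (an assumption about how those aliases are defined).

```python
import sys
from typing import Tuple


def typing_union_operator_supported() -> bool:
    """Return True if `|` works on typing generics (PEP 604, Python 3.10+)."""
    try:
        # Same pattern as natten/types.py: Alias1 | Alias2 on typing generics.
        Tuple[int] | Tuple[int, int]
    except TypeError:
        # Python 3.8/3.9: "unsupported operand type(s) for |: '_GenericAlias' ..."
        return False
    return True


if __name__ == "__main__":
    print(sys.version.split()[0], typing_union_operator_supported())
```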

HanzhouLiu commented 1 month ago

Sorry, I have never encountered this issue, and I am not sure what the root cause is. It might be that natten was not installed correctly.
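Since neither error could be reproduced on the maintainer's side, it may help to post the exact versions involved when reporting. Below is a small, hypothetical stdlib-only helper for that, using `importlib.metadata` (available since Python 3.8); it assumes the installed distribution names are `torch` and `natten`.

```python
import platform
from importlib import metadata
from typing import Dict, Iterable, Optional


def environment_report(
    packages: Iterable[str] = ("torch", "natten"),
) -> Dict[str, Optional[str]]:
    """Collect Python, OS, and package versions; None means 'not installed'."""
    report: Dict[str, Optional[str]] = {
        "python": platform.python_version(),
        "platform": platform.platform(),
    }
    for name in packages:
        try:
            report[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = None
    return report


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key}: {value}")
```

If PyTorch imports successfully, adding `torch.version.cuda` and `torch.cuda.is_available()` to the report would cover the CUDA side as well.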