weecology / DeepForest-pytorch

PyTorch implementation of the DeepForest model for tree crown detection in RGB imagery.
MIT License
17 stars, 9 forks

Loading a saved multi-class model from a checkpoint fails #75

Closed: bw4sz closed this issue 3 years ago

bw4sz commented 3 years ago
>>> saved_model = main.deepforest.load_from_checkpoint("/orange/ewhite/b.weinstein/AerialDetection/snapshots/20210530_233702/DOTA.pl")
saved_model
/home/b.weinstein/miniconda3/envs/Zooniverse_pytorch/lib/python3.8/site-packages/torch/cuda/__init__.py:52: UserWarning: CUDA initialization: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx (Triggered internally at  /opt/conda/conda-bld/pytorch_1603729062494/work/c10/cuda/CUDAFunctions.cpp:100.)
  return torch._C._cuda_getDeviceCount() > 0
Reading config file: deepforest_config.yml
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/b.weinstein/miniconda3/envs/Zooniverse_pytorch/lib/python3.8/site-packages/pytorch_lightning/core/saving.py", line 157, in load_from_checkpoint
    model = cls._load_model_state(checkpoint, strict=strict, **kwargs)
  File "/home/b.weinstein/miniconda3/envs/Zooniverse_pytorch/lib/python3.8/site-packages/pytorch_lightning/core/saving.py", line 205, in _load_model_state
    model.load_state_dict(checkpoint['state_dict'], strict=strict)
  File "/home/b.weinstein/miniconda3/envs/Zooniverse_pytorch/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for deepforest:
    size mismatch for model.head.classification_head.cls_logits.weight: copying a param with shape torch.Size([135, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([9, 256, 3, 3]).
    size mismatch for model.head.classification_head.cls_logits.bias: copying a param with shape torch.Size([135]) from checkpoint, the shape in current model is torch.Size([9]).
>>> saved_model
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
NameError: name 'saved_model' is not defined
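The two shapes in the error differ only in the output channels of `cls_logits`. In torchvision's RetinaNet, which this model appears to build on given the `model.head.classification_head` parameter names, that layer has `num_anchors * num_classes` outputs, with 9 anchors per location by default (3 sizes × 3 aspect ratios). Under that assumption, 9 corresponds to the default single-class model and 135 to a checkpoint trained with 15 classes. A minimal sketch that recovers the class count from parameter shapes; `infer_num_classes` is a hypothetical helper, not part of DeepForest:

```python
# Hypothetical helper (not part of DeepForest): recover the number of classes
# a RetinaNet-style checkpoint was trained with from its classification-head bias.
from typing import Mapping, Sequence

# Parameter name copied from the size-mismatch error above.
CLS_BIAS_KEY = "model.head.classification_head.cls_logits.bias"


def infer_num_classes(shapes: Mapping[str, Sequence[int]], num_anchors: int = 9) -> int:
    """Return num_classes from a {param_name: shape} mapping.

    Assumes cls_logits has num_anchors * num_classes output channels, as in
    torchvision's RetinaNet classification head (9 anchors per location by default).
    """
    out_channels = shapes[CLS_BIAS_KEY][0]
    num_classes, remainder = divmod(out_channels, num_anchors)
    if remainder:
        raise ValueError(
            f"{out_channels} output channels is not a multiple of {num_anchors} anchors"
        )
    return num_classes


# Shapes taken from the error message above.
checkpoint_shapes = {CLS_BIAS_KEY: (135,)}
current_shapes = {CLS_BIAS_KEY: (9,)}
print(infer_num_classes(checkpoint_shapes))  # 15
print(infer_num_classes(current_shapes))     # 1
```

With a real checkpoint, the shape mapping could be built with `{k: tuple(v.shape) for k, v in torch.load(path, map_location="cpu")["state_dict"].items()}`.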
bw4sz commented 3 years ago

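Covered in the FAQ, and there is a simple workaround. The checkpoint holds weights for a multi-class classification head, but `load_from_checkpoint` rebuilds the model with the default single-class head, so the model has to be created with the correct number of classes before the saved weights are loaded.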

https://deepforest-pytorch.readthedocs.io/en/latest/FAQ.html#i-cannot-reload-a-saved-multi-class-model-using-a-checkpoint
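A minimal sketch of that kind of workaround, assuming the `main.deepforest` constructor of this era accepts `num_classes` and `label_dict` keyword arguments (check the linked FAQ and your installed version for the exact recommended code). It builds the model with a correctly sized head first and then loads the checkpoint's `state_dict`, which is the same `load_state_dict` call that fails inside PyTorch Lightning in the traceback above. The helpers `make_label_dict` and `load_multiclass_checkpoint` are hypothetical names used only for illustration.

```python
# Sketch only: assumes deepforest.main.deepforest(num_classes=..., label_dict=...)
# exists with this signature; verify against the FAQ and your installed version.
from typing import Dict, Sequence


def make_label_dict(class_names: Sequence[str]) -> Dict[str, int]:
    """Map each class name to a contiguous integer id, in the given order.

    The order must match the ids used when the checkpoint was trained.
    """
    if len(set(class_names)) != len(class_names):
        raise ValueError("class names must be unique")
    return {name: idx for idx, name in enumerate(class_names)}


def load_multiclass_checkpoint(checkpoint_path: str, class_names: Sequence[str]):
    """Rebuild a multi-class deepforest model, then load the saved weights."""
    # Imported lazily so make_label_dict can be used without these packages.
    import torch
    from deepforest import main

    label_dict = make_label_dict(class_names)
    # Build the model with a classification head sized for all classes,
    # instead of the single-class default that load_from_checkpoint created.
    model = main.deepforest(num_classes=len(label_dict), label_dict=label_dict)
    # Lightning checkpoints store the module weights under "state_dict".
    # On PyTorch releases where torch.load defaults to weights_only=True,
    # a trusted Lightning checkpoint may need weights_only=False here.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"])
    model.eval()
    return model
```

The class names passed in are placeholders you would replace with the labels used at training time, listed in their original id order. Since the traceback shows Lightning forwarding `**kwargs` from `load_from_checkpoint`, passing `num_classes` and `label_dict` directly to `main.deepforest.load_from_checkpoint(...)` may also work, again assuming the constructor accepts them.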