tlpss / keypoint-detection

2D keypoint detection with PyTorch Lightning and wandb
MIT License

Creating `KeypointDetector` with old checkpoint fails with timm==0.9.0 #28

Open Victorlouisdg opened 1 year ago

Victorlouisdg commented 1 year ago

Downgrading timm to 0.6.13 solves the issue.

Code:

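from pathlib import Path
import wandb
from keypoint_detection.models.backbones.maxvit_unet import MaxVitUnet  # assumed module path
from keypoint_detection.models.detector import KeypointDetector  # assumed module path

checkpoint_reference = "airo-box-manipulation/iros2022/model-14zb70au:v1"

# Download the checkpoint artifact from wandb and locate model.ckpt inside it.
run = wandb.init(project="inference", entity="airo-box-manipulation")
artifact = run.use_artifact(checkpoint_reference, type="model")
artifact_dir = artifact.download()
model_path = Path(artifact_dir) / "model.ckpt"
model = KeypointDetector.load_from_checkpoint(model_path, backbone=MaxVitUnet(), max_keypoints=4)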

Full error with timm 0.9.0:

Traceback (most recent call last):
  File "/home/victor/cloth-folding-iros-23/scripts/keypoints/01_live_detection.py", line 21, in <module>
    model = KeypointDetector.load_from_checkpoint(model_path, backbone=MaxVitUnet(), max_keypoints=4)
  File "/home/victor/anaconda3/envs/cloth-folding-iros-23/lib/python3.10/site-packages/pytorch_lightning/core/saving.py", line 139, in load_from_checkpoint
    return _load_from_checkpoint(
  File "/home/victor/anaconda3/envs/cloth-folding-iros-23/lib/python3.10/site-packages/pytorch_lightning/core/saving.py", line 188, in _load_from_checkpoint
    return _load_state(cls, checkpoint, strict=strict, **kwargs)
  File "/home/victor/anaconda3/envs/cloth-folding-iros-23/lib/python3.10/site-packages/pytorch_lightning/core/saving.py", line 247, in _load_state
    keys = obj.load_state_dict(checkpoint["state_dict"], strict=strict)
  File "/home/victor/anaconda3/envs/cloth-folding-iros-23/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2041, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for KeypointDetector:
        Missing key(s) in state_dict: "unnormalized_model.0.feature_extractor.stages.0.blocks.0.attn_block.attn.rel_pos.relative_position_index", "unnormalized_model.0.feature_extractor.stages.0.blocks.0.attn_grid.attn.rel_pos.relative_position_index", "unnormalized_model.0.feature_extractor.stages.1.blocks.0.attn_block.attn.rel_pos.relative_position_index", "unnormalized_model.0.feature_extractor.stages.1.blocks.0.attn_grid.attn.rel_pos.relative_position_index", "unnormalized_model.0.feature_extractor.stages.1.blocks.1.attn_block.attn.rel_pos.relative_position_index", "unnormalized_model.0.feature_extractor.stages.1.blocks.1.attn_grid.attn.rel_pos.relative_position_index", "unnormalized_model.0.feature_extractor.stages.2.blocks.0.attn_block.attn.rel_pos.relative_position_index", "unnormalized_model.0.feature_extractor.stages.2.blocks.0.attn_grid.attn.rel_pos.relative_position_index", "unnormalized_model.0.feature_extractor.stages.2.blocks.1.attn_block.attn.rel_pos.relative_position_index", "unnormalized_model.0.feature_extractor.stages.2.blocks.1.attn_grid.attn.rel_pos.relative_position_index", "unnormalized_model.0.feature_extractor.stages.2.blocks.2.attn_block.attn.rel_pos.relative_position_index", "unnormalized_model.0.feature_extractor.stages.2.blocks.2.attn_grid.attn.rel_pos.relative_position_index", "unnormalized_model.0.feature_extractor.stages.3.blocks.0.attn_block.attn.rel_pos.relative_position_index", "unnormalized_model.0.feature_extractor.stages.3.blocks.0.attn_grid.attn.rel_pos.relative_position_index". 
        Unexpected key(s) in state_dict: "unnormalized_model.0.feature_extractor._tensor_constant0", "unnormalized_model.0.feature_extractor._tensor_constant1", "unnormalized_model.0.feature_extractor._tensor_constant2", "unnormalized_model.0.feature_extractor._tensor_constant3", "unnormalized_model.0.feature_extractor._tensor_constant4", "unnormalized_model.0.feature_extractor._tensor_constant5", "unnormalized_model.0.feature_extractor._tensor_constant6", "unnormalized_model.0.feature_extractor._tensor_constant7", "unnormalized_model.0.feature_extractor._tensor_constant8", "unnormalized_model.0.feature_extractor._tensor_constant9", "unnormalized_model.0.feature_extractor._tensor_constant10", "unnormalized_model.0.feature_extractor._tensor_constant11", "unnormalized_model.0.feature_extractor._tensor_constant12", "unnormalized_model.0.feature_extractor._tensor_constant13", "unnormalized_model.0.feature_extractor._tensor_constant14", "unnormalized_model.0.feature_extractor._tensor_constant15", "unnormalized_model.0.feature_extractor._tensor_constant16", "unnormalized_model.0.feature_extractor._tensor_constant17", "unnormalized_model.0.feature_extractor._tensor_constant18", "unnormalized_model.0.feature_extractor._tensor_constant19", "unnormalized_model.0.feature_extractor._tensor_constant20", "unnormalized_model.0.feature_extractor._tensor_constant21", "unnormalized_model.0.feature_extractor._tensor_constant22", "unnormalized_model.0.feature_extractor._tensor_constant23", "unnormalized_model.0.feature_extractor._tensor_constant24", "unnormalized_model.0.feature_extractor._tensor_constant25", "unnormalized_model.0.feature_extractor._tensor_constant26", "unnormalized_model.0.feature_extractor._tensor_constant27".
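The RuntimeError above is PyTorch's strict `load_state_dict` key check: "missing" keys are registered on the freshly built model but absent from the checkpoint, and "unexpected" keys are in the checkpoint but unknown to the model. Here the model built with timm 0.9.0 expects 14 `relative_position_index` buffers, while the checkpoint instead holds 28 `_tensor_constantN` entries. A minimal sketch for inspecting such mismatches outside of loading, assuming you already have both key lists as plain strings; the helper names `diff_state_dict_keys` and `summarize_keys` are hypothetical, not part of keypoint-detection or PyTorch:

```python
import re
from collections import Counter
from typing import Iterable


def diff_state_dict_keys(
    checkpoint_keys: Iterable[str], model_keys: Iterable[str]
) -> tuple[list[str], list[str]]:
    """Split keys the way a strict load_state_dict reports them.

    missing:    expected by the model, absent from the checkpoint
    unexpected: present in the checkpoint, unknown to the model
    """
    ckpt, model = set(checkpoint_keys), set(model_keys)
    return sorted(model - ckpt), sorted(ckpt - model)


def summarize_keys(keys: Iterable[str]) -> Counter:
    """Collapse every run of digits to 'N' so long key lists group into patterns."""
    return Counter(re.sub(r"\d+", "N", key) for key in keys)


if __name__ == "__main__":
    # Toy keys shaped like the ones in the traceback above (not real checkpoint data).
    checkpoint_keys = ["head.weight", "fx._tensor_constant0", "fx._tensor_constant1"]
    model_keys = [
        "head.weight",
        "stages.0.blocks.0.attn_block.attn.rel_pos.relative_position_index",
        "stages.1.blocks.0.attn_block.attn.rel_pos.relative_position_index",
    ]
    missing, unexpected = diff_state_dict_keys(checkpoint_keys, model_keys)
    print("missing:", summarize_keys(missing))
    print("unexpected:", summarize_keys(unexpected))
```

On the real checkpoint you would pass the keys of `checkpoint["state_dict"]` (the same dict that appears in the traceback) and `model.state_dict().keys()` from a constructed model. Whether the mismatched entries are safe to skip, for example by passing `strict=False` to `load_from_checkpoint`, is something to verify for this backbone rather than assume.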
tlpss commented 12 months ago

Need to pin timm to >=0.9.X to avoid these issues.
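Since the same checkpoint loads with timm 0.6.13 but not with 0.9.0, a loading script can fail early with a readable message instead of the long key dump. A minimal sketch, assuming the state_dict layout change sits at the 0.9 boundary (only 0.6.13 and 0.9.0 are reported in this thread, so the exact boundary is a guess) and that you know which timm version a checkpoint was trained with; `trained_with`, `LAYOUT_BOUNDARY` and the helper functions are hypothetical names, and that version is a value you would have to record yourself:

```python
import re
from importlib.metadata import PackageNotFoundError, version

# Assumption: timm's MaxViT state_dict layout changed at 0.9. Only 0.6.13 (works)
# and 0.9.0 (fails) are reported in this issue, so the exact boundary is a guess.
LAYOUT_BOUNDARY = (0, 9)


def major_minor(version_string: str) -> tuple[int, int]:
    """Parse the leading 'MAJOR.MINOR' of a version string, e.g. '0.6.13' -> (0, 6)."""
    match = re.match(r"(\d+)\.(\d+)", version_string)
    if match is None:
        raise ValueError(f"unrecognized version string: {version_string!r}")
    return int(match.group(1)), int(match.group(2))


def same_timm_layout(trained_with: str, installed: str) -> bool:
    """True if both versions fall on the same side of LAYOUT_BOUNDARY."""
    return (major_minor(trained_with) >= LAYOUT_BOUNDARY) == (
        major_minor(installed) >= LAYOUT_BOUNDARY
    )


def check_timm_for_checkpoint(trained_with: str) -> None:
    """Raise a readable error before loading if the installed timm likely mismatches."""
    try:
        installed = version("timm")
    except PackageNotFoundError as err:
        raise RuntimeError("timm is not installed") from err
    if not same_timm_layout(trained_with, installed):
        raise RuntimeError(
            f"checkpoint was trained with timm {trained_with}, but timm {installed} "
            f"is installed; use a timm version on the same side of "
            f"{LAYOUT_BOUNDARY[0]}.{LAYOUT_BOUNDARY[1]} before loading"
        )
```

Calling `check_timm_for_checkpoint("0.6.13")` right before `KeypointDetector.load_from_checkpoint(...)` would then raise with timm 0.9.0 installed and pass with 0.6.13, matching the two outcomes reported in this issue.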