Victorlouisdg opened this issue 1 year ago
Loading the checkpoint below fails with timm 0.9.0. Downgrading timm to 0.6.13 solves the issue.
Code:
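```python
from pathlib import Path

import wandb

# KeypointDetector and MaxVitUnet come from the project's keypoint detection code;
# their imports were not shown in the original report.

# Download the trained checkpoint artifact from Weights & Biases.
checkpoint_reference = "airo-box-manipulation/iros2022/model-14zb70au:v1"
run = wandb.init(project="inference", entity="airo-box-manipulation")
artifact = run.use_artifact(checkpoint_reference, type="model")
artifact_dir = artifact.download()
model_path = Path(artifact_dir) / "model.ckpt"

# Rebuild the detector from the Lightning checkpoint (this call fails on timm 0.9.0).
model = KeypointDetector.load_from_checkpoint(model_path, backbone=MaxVitUnet(), max_keypoints=4)
```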
Full Error with timm 0.9.0:
```
Traceback (most recent call last):
  File "/home/victor/cloth-folding-iros-23/scripts/keypoints/01_live_detection.py", line 21, in <module>
    model = KeypointDetector.load_from_checkpoint(model_path, backbone=MaxVitUnet(), max_keypoints=4)
  File "/home/victor/anaconda3/envs/cloth-folding-iros-23/lib/python3.10/site-packages/pytorch_lightning/core/saving.py", line 139, in load_from_checkpoint
    return _load_from_checkpoint(
  File "/home/victor/anaconda3/envs/cloth-folding-iros-23/lib/python3.10/site-packages/pytorch_lightning/core/saving.py", line 188, in _load_from_checkpoint
    return _load_state(cls, checkpoint, strict=strict, **kwargs)
  File "/home/victor/anaconda3/envs/cloth-folding-iros-23/lib/python3.10/site-packages/pytorch_lightning/core/saving.py", line 247, in _load_state
    keys = obj.load_state_dict(checkpoint["state_dict"], strict=strict)
  File "/home/victor/anaconda3/envs/cloth-folding-iros-23/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2041, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for KeypointDetector:
    Missing key(s) in state_dict: "unnormalized_model.0.feature_extractor.stages.0.blocks.0.attn_block.attn.rel_pos.relative_position_index", "unnormalized_model.0.feature_extractor.stages.0.blocks.0.attn_grid.attn.rel_pos.relative_position_index", "unnormalized_model.0.feature_extractor.stages.1.blocks.0.attn_block.attn.rel_pos.relative_position_index", "unnormalized_model.0.feature_extractor.stages.1.blocks.0.attn_grid.attn.rel_pos.relative_position_index", "unnormalized_model.0.feature_extractor.stages.1.blocks.1.attn_block.attn.rel_pos.relative_position_index", "unnormalized_model.0.feature_extractor.stages.1.blocks.1.attn_grid.attn.rel_pos.relative_position_index", "unnormalized_model.0.feature_extractor.stages.2.blocks.0.attn_block.attn.rel_pos.relative_position_index", "unnormalized_model.0.feature_extractor.stages.2.blocks.0.attn_grid.attn.rel_pos.relative_position_index", "unnormalized_model.0.feature_extractor.stages.2.blocks.1.attn_block.attn.rel_pos.relative_position_index", "unnormalized_model.0.feature_extractor.stages.2.blocks.1.attn_grid.attn.rel_pos.relative_position_index", "unnormalized_model.0.feature_extractor.stages.2.blocks.2.attn_block.attn.rel_pos.relative_position_index", "unnormalized_model.0.feature_extractor.stages.2.blocks.2.attn_grid.attn.rel_pos.relative_position_index", "unnormalized_model.0.feature_extractor.stages.3.blocks.0.attn_block.attn.rel_pos.relative_position_index", "unnormalized_model.0.feature_extractor.stages.3.blocks.0.attn_grid.attn.rel_pos.relative_position_index".
    Unexpected key(s) in state_dict: "unnormalized_model.0.feature_extractor._tensor_constant0", "unnormalized_model.0.feature_extractor._tensor_constant1", "unnormalized_model.0.feature_extractor._tensor_constant2", "unnormalized_model.0.feature_extractor._tensor_constant3", "unnormalized_model.0.feature_extractor._tensor_constant4", "unnormalized_model.0.feature_extractor._tensor_constant5", "unnormalized_model.0.feature_extractor._tensor_constant6", "unnormalized_model.0.feature_extractor._tensor_constant7", "unnormalized_model.0.feature_extractor._tensor_constant8", "unnormalized_model.0.feature_extractor._tensor_constant9", "unnormalized_model.0.feature_extractor._tensor_constant10", "unnormalized_model.0.feature_extractor._tensor_constant11", "unnormalized_model.0.feature_extractor._tensor_constant12", "unnormalized_model.0.feature_extractor._tensor_constant13", "unnormalized_model.0.feature_extractor._tensor_constant14", "unnormalized_model.0.feature_extractor._tensor_constant15", "unnormalized_model.0.feature_extractor._tensor_constant16", "unnormalized_model.0.feature_extractor._tensor_constant17", "unnormalized_model.0.feature_extractor._tensor_constant18", "unnormalized_model.0.feature_extractor._tensor_constant19", "unnormalized_model.0.feature_extractor._tensor_constant20", "unnormalized_model.0.feature_extractor._tensor_constant21", "unnormalized_model.0.feature_extractor._tensor_constant22", "unnormalized_model.0.feature_extractor._tensor_constant23", "unnormalized_model.0.feature_extractor._tensor_constant24", "unnormalized_model.0.feature_extractor._tensor_constant25", "unnormalized_model.0.feature_extractor._tensor_constant26", "unnormalized_model.0.feature_extractor._tensor_constant27".
```
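One way to confirm that the only mismatch is the `rel_pos.relative_position_index` buffers on one side and the `_tensor_constant<N>` entries on the other is to diff the key sets before loading. The helpers below (`diff_keys`, `drop_tensor_constants`, `only_rel_pos_index_missing`) are hypothetical, not part of timm or Lightning, and work on plain key names so they run without torch. This is an untested sketch of a workaround, assuming the `relative_position_index` buffers are deterministic (generated from the window size when `MaxVitUnet()` is constructed) and that timm did not change how they are generated between 0.6.13 and 0.9.0.

```python
import re
from typing import Iterable, List, Mapping, Tuple

# Checkpoint entries such as "...feature_extractor._tensor_constant12" (unexpected under timm 0.9.0).
_TENSOR_CONSTANT = re.compile(r"(^|\.)_tensor_constant\d+$")
# Buffer that the timm 0.9.0 model expects but the checkpoint lacks.
_REL_POS_INDEX_SUFFIX = ".rel_pos.relative_position_index"


def diff_keys(model_keys: Iterable[str], ckpt_keys: Iterable[str]) -> Tuple[List[str], List[str]]:
    """Return (missing, unexpected) key lists, mirroring what a strict load_state_dict reports."""
    model_set, ckpt_set = set(model_keys), set(ckpt_keys)
    missing = sorted(model_set - ckpt_set)
    unexpected = sorted(ckpt_set - model_set)
    return missing, unexpected


def drop_tensor_constants(state_dict: Mapping[str, object]) -> dict:
    """Return a copy of the state dict without the `_tensor_constant<N>` entries."""
    return {k: v for k, v in state_dict.items() if not _TENSOR_CONSTANT.search(k)}


def only_rel_pos_index_missing(missing: Iterable[str]) -> bool:
    """True if every missing key is a relative_position_index buffer (vacuously True if none)."""
    return all(k.endswith(_REL_POS_INDEX_SUFFIX) for k in missing)


if __name__ == "__main__":
    prefix = "unnormalized_model.0.feature_extractor."
    # "head.weight" is a hypothetical key standing in for the real learned weights.
    model_keys = [
        prefix + "stages.0.blocks.0.attn_block.attn.rel_pos.relative_position_index",
        prefix + "head.weight",
    ]
    checkpoint = {prefix + "_tensor_constant0": None, prefix + "head.weight": None}
    missing, unexpected = diff_keys(model_keys, drop_tensor_constants(checkpoint))
    print(missing, unexpected, only_rel_pos_index_missing(missing))
```

The key lists can be taken from the checkpoint (`torch.load(model_path, map_location="cpu")["state_dict"]`) and from the model's `state_dict()`. If only those buffers are missing and nothing unexpected remains after filtering, `KeypointDetector.load_from_checkpoint(model_path, strict=False, backbone=MaxVitUnet(), max_keypoints=4)` would skip exactly those keys, since `strict` is passed through to `load_state_dict` as the traceback shows. Its predictions should be compared against a timm 0.6.13 run before relying on it.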
Need to pin timm to >=0.9.X to avoid these issues.
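Besides pinning the dependency, the requirement could be enforced at runtime so that a mismatched environment fails with a clear message before `load_from_checkpoint` runs. A minimal sketch: `parse_version` and `require_timm_at_least` are hypothetical helpers, the `"0.9"` default simply mirrors the comment above, and pre-release suffixes such as `rc1` are ignored.

```python
import re
from typing import Tuple


def parse_version(version: str) -> Tuple[int, ...]:
    """Leading numeric release components, e.g. '0.9.0rc1' -> (0, 9, 0)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def require_timm_at_least(installed: str, minimum: str = "0.9") -> None:
    """Raise early if the installed timm is older than `minimum`."""
    # Tuple comparison: (0, 6, 13) < (0, 9) is True, (0, 9, 0) < (0, 9) is False.
    if parse_version(installed) < parse_version(minimum):
        raise RuntimeError(f"timm {installed} is installed, but timm>={minimum} is expected")
```

In the script it would be called as `require_timm_at_least(importlib.metadata.version("timm"))` before loading the checkpoint. If the `packaging` package is available, `packaging.version.Version` gives a full PEP 440 comparison instead of this simplified parser.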