lzccccc / SMOKE

SMOKE: Single-Stage Monocular 3D Object Detection via Keypoint Estimation
MIT License

Loading the pretrained weights into PyTorch #67

Status: Open. josh-wende opened this issue 2 years ago.

josh-wende commented 2 years ago

Hi,

I'm trying to use the pretrained weights provided in the README, but I'm having some trouble. I loaded model_final.pth as a state_dict and then tried to load it into a new KeypointDetector instance, but this raised an error because the keys the KeypointDetector state_dict expects don't match the keys in the loaded dictionary. What am I doing wrong?

Thanks, Josh
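To see exactly which names disagree, it can help to diff the two key sets before calling `load_state_dict`. The sketch below uses plain Python sets; the function name `report_key_mismatch` and the example key names are hypothetical and not taken from SMOKE.

```python
def report_key_mismatch(model_keys, checkpoint_keys):
    """Return (missing, unexpected): keys the model expects but the checkpoint
    lacks, and keys the checkpoint has that the model does not expect."""
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)
    unexpected = sorted(checkpoint_keys - model_keys)
    return missing, unexpected


# Hypothetical key names, chosen only to illustrate a prefix mismatch.
model_keys = ["backbone.conv1.weight", "heads.cls.bias"]
checkpoint_keys = ["module.backbone.conv1.weight", "module.heads.cls.bias"]

missing, unexpected = report_key_mismatch(model_keys, checkpoint_keys)
print("missing:", missing)
print("unexpected:", unexpected)
```

With a real model you would pass `model.state_dict().keys()` and the checkpoint's keys. If every unexpected key is just a missing key plus a common prefix, stripping that prefix is usually enough.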

gch commented 2 years ago

It seems the codebase may have changed since the pretrained weights were saved. I found that manually remapping the key names as follows gets things to work:

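    import torch
    # Weights are stored under 'model', with keys prefixed 'module.' (typically from (Distributed)DataParallel)
    pretrained = torch.load('/mnt/drive/MyDrive/Datasets/SMOKE/model_final.pth', map_location='cpu')['model']
    # Strip the prefix so the names match KeypointDetector's own state_dict keys
    state_dict = {k.replace('module.', '', 1): v for k, v in pretrained.items()}
    assert set(state_dict) == set(model.state_dict())
    model.load_state_dict(state_dict)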
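The same remapping can be written as a small helper that strips the prefix only when every key carries it, so it is a no-op on checkpoints saved from an unwrapped model. This is a sketch: the name `strip_prefix_if_all` is hypothetical, and plain integers stand in for the tensors a real state dict would hold.

```python
from collections import OrderedDict


def strip_prefix_if_all(state_dict, prefix="module."):
    """Remove `prefix` from the start of every key, but only if all keys have it."""
    if not all(key.startswith(prefix) for key in state_dict):
        return state_dict  # Mixed or unprefixed keys: leave them untouched.
    return OrderedDict(
        (key[len(prefix):], value) for key, value in state_dict.items()
    )


# Stand-in values; with PyTorch these would be tensors from torch.load(...)['model'].
saved = {"module.backbone.conv1.weight": 1, "module.heads.cls.bias": 2}
print(dict(strip_prefix_if_all(saved)))
```

Slicing off `len(prefix)` characters removes only the leading prefix, so a `module.` that happened to appear later in a name would be left alone. The result can then be passed to `model.load_state_dict(...)`.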

EDIT: Reading the code further, it's just that the training script saves some extra information (training debug data, etc.) when it serializes the model. It looks like https://github.com/lzccccc/SMOKE/blob/bc5d2bba66e2d66fa56b7b599d55457cb1a05b33/smoke/utils/model_serialization.py#L69 already provides a helper function that loads from this slightly customized serialization format.
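The linked helper is not reproduced here, and its exact behavior is not confirmed. As an illustration only, one general way a loader can tolerate a `module.` prefix on either side, and ignore extra checkpoint entries, is to pair keys that match up to a dotted prefix. `align_by_suffix` is a hypothetical name; this is not SMOKE's code.

```python
def align_by_suffix(model_keys, loaded_state_dict):
    """Return {model_key: value}, pairing each model key with a loaded key that
    matches it up to a dotted prefix on either side (e.g. 'module.')."""

    def matches(a, b):
        return a == b or a.endswith("." + b) or b.endswith("." + a)

    aligned = {}
    for model_key in model_keys:
        candidates = [k for k in loaded_state_dict if matches(k, model_key)]
        if candidates:
            # Prefer the candidate whose length is closest to the model key.
            best = min(candidates, key=lambda k: abs(len(k) - len(model_key)))
            aligned[model_key] = loaded_state_dict[best]
    return aligned


# Hypothetical keys: the checkpoint has a 'module.' prefix and one extra entry.
model_keys = ["backbone.conv1.weight", "heads.cls.bias"]
loaded = {
    "module.backbone.conv1.weight": 1,
    "module.heads.cls.bias": 2,
    "module.extra": 3,
}
print(align_by_suffix(model_keys, loaded))
```

Matching on a `.` boundary avoids pairing unrelated names such as `heads.cls.bias` and `xheads.cls.bias`. Model keys with no match are simply omitted, so a strict `load_state_dict` would still report them.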