danbider / lightning-pose

Accelerated pose estimation and tracking using semi-supervised convolutional networks.

Error loading checkpoint for vit_b_sam backbone #134

Closed themattinthehatt closed 3 months ago

themattinthehatt commented 3 months ago

When loading weights from a fine-tuned vit_b_sam backbone, if the fine-tuning frame size was not 1024x1024, the following error is raised:

RuntimeError: Error(s) in loading state_dict for HeatmapTracker:
    size mismatch for backbone.pos_embed: copying a param with shape torch.Size([1, 16, 16, 768]) from checkpoint, the shape in current model is torch.Size([1, 64, 64, 768]).

The problem:

  1. During training, the regular vit_b_sam backbone is constructed, which assumes an image shape of 1024x1024.
  2. If the image size we are fine-tuning on is not 1024x1024, the position embedding is automatically resized during training and the new weights are stored (and eventually saved).
  3. When loading the weights into a new model, a position embedding parameter sized for 1024x1024 images is constructed, but the saved parameter assumes a different image size, producing the error above (the shape arithmetic is illustrated below).
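
For concreteness, both shapes in the error fall out of the patch-grid arithmetic. A minimal sketch, assuming the standard SAM ViT-B patch size of 16 and embedding dimension 768 (the 256x256 fine-tune size is inferred from the checkpoint shape, not stated in the report):

    # pos_embed grid size = image size / patch size (SAM ViT-B: 16x16 patches, 768-dim)
    patch_size, embed_dim = 16, 768
    for image_size in (1024, 256):  # default SAM size vs. inferred fine-tune size
        grid = image_size // patch_size
        print(f"{image_size}x{image_size} -> pos_embed (1, {grid}, {grid}, {embed_dim})")
    # 1024x1024 -> pos_embed (1, 64, 64, 768)   <- current model
    # 256x256   -> pos_embed (1, 16, 16, 768)   <- checkpoint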

The solution: Instead of loading the state dict directly into the model with Model.load_from_checkpoint, the loading needs to be broken into three steps (see the sketch after this list):

  1. Initialize the model (which includes loading the pretrained SAM weights); this sets the position embedding parameter to a shape that assumes 1024x1024 images.
  2. Manually resize the position embedding parameter to match the desired fine-tune image size.
  3. Load the weights from the checkpoint.
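
A minimal sketch of these three steps, assuming a Lightning-style checkpoint (weights under a "state_dict" key) and a SAM-style backbone.pos_embed of shape (1, H, W, C) as in the error above. The function name, the bicubic interpolation mode, and the checkpoint path are hypothetical, not lightning-pose API:

    import torch
    import torch.nn.functional as F

    def load_finetuned_vit_checkpoint(model, ckpt_path, image_size, patch_size=16):
        # Step 1 has already happened: `model` was constructed as during
        # training, so the SAM weights are loaded and backbone.pos_embed is
        # sized for 1024x1024 images, i.e. (1, 64, 64, 768).

        # Step 2: resize pos_embed to the grid implied by the fine-tune image
        # size. Only the *shape* matters here -- the interpolated values are
        # overwritten by the checkpoint in step 3.
        grid = image_size // patch_size
        with torch.no_grad():
            pe = model.backbone.pos_embed.permute(0, 3, 1, 2)  # (1, C, H, W) for interpolate
            pe = F.interpolate(pe, size=(grid, grid), mode="bicubic", align_corners=False)
        model.backbone.pos_embed = torch.nn.Parameter(pe.permute(0, 2, 3, 1))

        # Step 3: shapes now match, so the checkpoint loads cleanly.
        state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
        model.load_state_dict(state_dict)
        return model

    # e.g. model = load_finetuned_vit_checkpoint(model, "path/to/model.ckpt", image_size=256)

Resizing the existing embedding (rather than re-initializing it) keeps the function shape-agnostic; since step 3 overwrites the values anyway, any resize that yields a (1, grid, grid, 768) parameter would unblock the load.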