vishalned / MMEarth-train

This repository contains code to reproduce the experiments in the preprint "MMEarth: Exploring Multi-Modal Pretext Tasks For Geospatial Representation Learning".
https://vishalned.github.io/mmearth/

Trouble loading a checkpoint #6

Closed: antofuller closed this issue 2 weeks ago

antofuller commented 2 weeks ago

Hi, great work with MMEarth!

I am trying to load an MMEarth pretrained state dictionary but am having some issues, including when using the provided example:

I'm using the weights stored in pt-all_mod_atto_1M_128_uncertainty_112-16. The first issue, which throws an error even with strict=False, is a set of shape mismatches between the state dictionary and the initialized model. I think these can be fixed via:

model_state_dict = {}
for key, value in checkpoint["model"].items():
    # GRN gamma/beta: checkpoint (1, C) -> model (1, 1, 1, C)
    if ("grn" in key) and ("encoder" in key):
        value = value[None, None, ...]
    # biases: checkpoint (1, C) -> model (C,)
    if ("bias" in key) and ("encoder" in key):
        value = value.squeeze(0)
    model_state_dict[key] = value
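The reshaping above can also be written as a rule on shapes alone, which makes it easy to check against the size mismatches reported in this thread. This is a minimal sketch, assuming the only differences are the two reported patterns (GRN gamma/beta saved as (1, C) but expected as (1, 1, 1, C), and biases saved as (1, C) but expected as (C,)); the helper name `dense_shape` is hypothetical and not part of MMEarth-train.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def dense_shape(key: str, shape: Shape) -> Shape:
    """Return the shape a dense (non-sparse) encoder expects for a checkpoint tensor.

    Assumed rules, inferred from the size-mismatch errors in this thread:
    - GRN gamma/beta saved as (1, C) become (1, 1, 1, C);
    - biases saved as (1, C) become (C,);
    - everything else is left unchanged.
    """
    if not key.startswith("encoder."):
        return shape
    if ".grn." in key and len(shape) == 2:
        return (1, 1) + shape  # (1, C) -> (1, 1, 1, C)
    if key.endswith(".bias") and len(shape) == 2 and shape[0] == 1:
        return shape[1:]  # (1, C) -> (C,)
    return shape
```

With a PyTorch state dict, each tensor could then be adjusted with `value.reshape(dense_shape(key, tuple(value.shape)))`. Unlike `squeeze(0)`, this rule only touches biases that are actually 2-D with a leading dimension of 1.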

With these reshaped weights, model.load_state_dict(model_state_dict, strict=False) now reports many missing and unexpected keys:

_IncompatibleKeys(missing_keys=['encoder.downsample_layers.0.0.weight', 'encoder.downsample_layers.0.0.bias', 'encoder.downsample_layers.0.1.weight', 'encoder.downsample_layers.1.0.weight', 'encoder.downsample_layers.1.0.bias', 'encoder.downsample_layers.1.1.weight', 'encoder.downsample_layers.2.0.weight', 'encoder.downsample_layers.2.0.bias', 'encoder.downsample_layers.2.1.weight', 'encoder.initial_conv.0.weight', 'encoder.initial_conv.1.weight', 'encoder.initial_conv.1.bias', 'encoder.stem.0.weight', 'encoder.stem.1.weight', 'encoder.stem.1.bias', 'encoder.stages.0.0.dwconv.weight', 'encoder.stages.0.0.norm.weight', 'encoder.stages.0.0.norm.bias', 'encoder.stages.0.0.pwconv1.weight', 'encoder.stages.0.0.pwconv1.bias', 'encoder.stages.0.0.pwconv2.weight', 'encoder.stages.0.0.pwconv2.bias', 'encoder.stages.0.1.dwconv.weight', 'encoder.stages.0.1.norm.weight', 'encoder.stages.0.1.norm.bias', 'encoder.stages.0.1.pwconv1.weight', 'encoder.stages.0.1.pwconv1.bias', 'encoder.stages.0.1.pwconv2.weight', 'encoder.stages.0.1.pwconv2.bias', 'encoder.stages.1.0.dwconv.weight', 'encoder.stages.1.0.norm.weight', 'encoder.stages.1.0.norm.bias', 'encoder.stages.1.0.pwconv1.weight', 'encoder.stages.1.0.pwconv1.bias', 'encoder.stages.1.0.pwconv2.weight', 'encoder.stages.1.0.pwconv2.bias', 'encoder.stages.1.1.dwconv.weight', 'encoder.stages.1.1.norm.weight', 'encoder.stages.1.1.norm.bias', 'encoder.stages.1.1.pwconv1.weight', 'encoder.stages.1.1.pwconv1.bias', 'encoder.stages.1.1.pwconv2.weight', 'encoder.stages.1.1.pwconv2.bias', 'encoder.stages.2.0.dwconv.weight', 'encoder.stages.2.0.norm.weight', 'encoder.stages.2.0.norm.bias', 'encoder.stages.2.0.pwconv1.weight', 'encoder.stages.2.0.pwconv1.bias', 'encoder.stages.2.0.pwconv2.weight', 'encoder.stages.2.0.pwconv2.bias', 'encoder.stages.2.1.dwconv.weight', 'encoder.stages.2.1.norm.weight', 'encoder.stages.2.1.norm.bias', 'encoder.stages.2.1.pwconv1.weight', 'encoder.stages.2.1.pwconv1.bias', 'encoder.stages.2.1.pwconv2.weight', 
'encoder.stages.2.1.pwconv2.bias', 'encoder.stages.2.2.dwconv.weight', 'encoder.stages.2.2.norm.weight', 'encoder.stages.2.2.norm.bias', 'encoder.stages.2.2.pwconv1.weight', 'encoder.stages.2.2.pwconv1.bias', 'encoder.stages.2.2.pwconv2.weight', 'encoder.stages.2.2.pwconv2.bias', 'encoder.stages.2.3.dwconv.weight', 'encoder.stages.2.3.norm.weight', 'encoder.stages.2.3.norm.bias', 'encoder.stages.2.3.pwconv1.weight', 'encoder.stages.2.3.pwconv1.bias', 'encoder.stages.2.3.pwconv2.weight', 'encoder.stages.2.3.pwconv2.bias', 'encoder.stages.2.4.dwconv.weight', 'encoder.stages.2.4.norm.weight', 'encoder.stages.2.4.norm.bias', 'encoder.stages.2.4.pwconv1.weight', 'encoder.stages.2.4.pwconv1.bias', 'encoder.stages.2.4.pwconv2.weight', 'encoder.stages.2.4.pwconv2.bias', 'encoder.stages.2.5.dwconv.weight', 'encoder.stages.2.5.norm.weight', 'encoder.stages.2.5.norm.bias', 'encoder.stages.2.5.pwconv1.weight', 'encoder.stages.2.5.pwconv1.bias', 'encoder.stages.2.5.pwconv2.weight', 'encoder.stages.2.5.pwconv2.bias', 'encoder.stages.3.0.dwconv.weight', 'encoder.stages.3.0.norm.weight', 'encoder.stages.3.0.norm.bias', 'encoder.stages.3.0.pwconv1.weight', 'encoder.stages.3.0.pwconv1.bias', 'encoder.stages.3.0.pwconv2.weight', 'encoder.stages.3.0.pwconv2.bias', 'encoder.stages.3.1.dwconv.weight', 'encoder.stages.3.1.norm.weight', 'encoder.stages.3.1.norm.bias', 'encoder.stages.3.1.pwconv1.weight', 'encoder.stages.3.1.pwconv1.bias', 'encoder.stages.3.1.pwconv2.weight', 'encoder.stages.3.1.pwconv2.bias', 'encoder.norm.weight', 'encoder.norm.bias', 'encoder.head.weight', 'encoder.head.bias'], unexpected_keys=['loss_fn.log_vars', 'encoder.downsample_layers.0.0.ln.weight', 'encoder.downsample_layers.0.0.ln.bias', 'encoder.downsample_layers.0.1.kernel', 'encoder.downsample_layers.1.0.ln.weight', 'encoder.downsample_layers.1.0.ln.bias', 'encoder.downsample_layers.1.1.kernel', 'encoder.downsample_layers.2.0.ln.weight', 'encoder.downsample_layers.2.0.ln.bias', 
'encoder.downsample_layers.2.1.kernel', 'encoder.initial_conv.0.kernel', 'encoder.initial_conv.1.ln.weight', 'encoder.initial_conv.1.ln.bias', 'encoder.stem.0.kernel', 'encoder.stem.1.ln.weight', 'encoder.stem.1.ln.bias', 'encoder.stages.0.0.dwconv.kernel', 'encoder.stages.0.0.norm.ln.weight', 'encoder.stages.0.0.norm.ln.bias', 'encoder.stages.0.0.pwconv1.linear.weight', 'encoder.stages.0.0.pwconv1.linear.bias', 'encoder.stages.0.0.pwconv2.linear.weight', 'encoder.stages.0.0.pwconv2.linear.bias', 'encoder.stages.0.1.dwconv.kernel', 'encoder.stages.0.1.norm.ln.weight', 'encoder.stages.0.1.norm.ln.bias', 'encoder.stages.0.1.pwconv1.linear.weight', 'encoder.stages.0.1.pwconv1.linear.bias', 'encoder.stages.0.1.pwconv2.linear.weight', 'encoder.stages.0.1.pwconv2.linear.bias', 'encoder.stages.1.0.dwconv.kernel', 'encoder.stages.1.0.norm.ln.weight', 'encoder.stages.1.0.norm.ln.bias', 'encoder.stages.1.0.pwconv1.linear.weight', 'encoder.stages.1.0.pwconv1.linear.bias', 'encoder.stages.1.0.pwconv2.linear.weight', 'encoder.stages.1.0.pwconv2.linear.bias', 'encoder.stages.1.1.dwconv.kernel', 'encoder.stages.1.1.norm.ln.weight', 'encoder.stages.1.1.norm.ln.bias', 'encoder.stages.1.1.pwconv1.linear.weight', 'encoder.stages.1.1.pwconv1.linear.bias', 'encoder.stages.1.1.pwconv2.linear.weight', 'encoder.stages.1.1.pwconv2.linear.bias', 'encoder.stages.2.0.dwconv.kernel', 'encoder.stages.2.0.norm.ln.weight', 'encoder.stages.2.0.norm.ln.bias', 'encoder.stages.2.0.pwconv1.linear.weight', 'encoder.stages.2.0.pwconv1.linear.bias', 'encoder.stages.2.0.pwconv2.linear.weight', 'encoder.stages.2.0.pwconv2.linear.bias', 'encoder.stages.2.1.dwconv.kernel', 'encoder.stages.2.1.norm.ln.weight', 'encoder.stages.2.1.norm.ln.bias', 'encoder.stages.2.1.pwconv1.linear.weight', 'encoder.stages.2.1.pwconv1.linear.bias', 'encoder.stages.2.1.pwconv2.linear.weight', 'encoder.stages.2.1.pwconv2.linear.bias', 'encoder.stages.2.2.dwconv.kernel', 'encoder.stages.2.2.norm.ln.weight', 
'encoder.stages.2.2.norm.ln.bias', 'encoder.stages.2.2.pwconv1.linear.weight', 'encoder.stages.2.2.pwconv1.linear.bias', 'encoder.stages.2.2.pwconv2.linear.weight', 'encoder.stages.2.2.pwconv2.linear.bias', 'encoder.stages.2.3.dwconv.kernel', 'encoder.stages.2.3.norm.ln.weight', 'encoder.stages.2.3.norm.ln.bias', 'encoder.stages.2.3.pwconv1.linear.weight', 'encoder.stages.2.3.pwconv1.linear.bias', 'encoder.stages.2.3.pwconv2.linear.weight', 'encoder.stages.2.3.pwconv2.linear.bias', 'encoder.stages.2.4.dwconv.kernel', 'encoder.stages.2.4.norm.ln.weight', 'encoder.stages.2.4.norm.ln.bias', 'encoder.stages.2.4.pwconv1.linear.weight', 'encoder.stages.2.4.pwconv1.linear.bias', 'encoder.stages.2.4.pwconv2.linear.weight', 'encoder.stages.2.4.pwconv2.linear.bias', 'encoder.stages.2.5.dwconv.kernel', 'encoder.stages.2.5.norm.ln.weight', 'encoder.stages.2.5.norm.ln.bias', 'encoder.stages.2.5.pwconv1.linear.weight', 'encoder.stages.2.5.pwconv1.linear.bias', 'encoder.stages.2.5.pwconv2.linear.weight', 'encoder.stages.2.5.pwconv2.linear.bias', 'encoder.stages.3.0.dwconv.kernel', 'encoder.stages.3.0.norm.ln.weight', 'encoder.stages.3.0.norm.ln.bias', 'encoder.stages.3.0.pwconv1.linear.weight', 'encoder.stages.3.0.pwconv1.linear.bias', 'encoder.stages.3.0.pwconv2.linear.weight', 'encoder.stages.3.0.pwconv2.linear.bias', 'encoder.stages.3.1.dwconv.kernel', 'encoder.stages.3.1.norm.ln.weight', 'encoder.stages.3.1.norm.ln.bias', 'encoder.stages.3.1.pwconv1.linear.weight', 'encoder.stages.3.1.pwconv1.linear.bias', 'encoder.stages.3.1.pwconv2.linear.weight', 'encoder.stages.3.1.pwconv2.linear.bias'])
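Each missing key above pairs with an unexpected key that differs only in naming: a sparse convolution's `kernel` corresponds to a dense `weight`, and an extra `ln` or `linear` segment appears in the normalization and pointwise-linear layers (for example `encoder.stages.0.0.norm.ln.weight` versus `encoder.stages.0.0.norm.weight`). The sketch below renames keys under that assumption and drops `loss_fn.log_vars`, which is not an encoder parameter. It is hypothetical (`dense_key` and `rename_keys` are illustrative names) and only renames: the `kernel` tensors are presumably stored in a different layout from `nn.Conv2d` weights and would still need reshaping, and `encoder.norm.*` and `encoder.head.*` have no counterpart in the checkpoint, so they would stay missing.

```python
from typing import Dict, Optional, TypeVar

T = TypeVar("T")

# Wrapper sub-module names that appear in the sparse checkpoint keys but not
# in the dense model's keys (assumed from the key lists above).
WRAPPER_SEGMENTS = {"ln", "linear"}


def dense_key(key: str) -> Optional[str]:
    """Map a sparse-model parameter name to its assumed dense counterpart.

    Returns None for entries that are not encoder weights (e.g. 'loss_fn.log_vars').
    """
    if key.startswith("loss_fn."):
        return None
    parts = key.split(".")
    if parts[-1] == "kernel":
        parts[-1] = "weight"  # sparse conv 'kernel' -> dense conv 'weight'
    elif len(parts) >= 2 and parts[-2] in WRAPPER_SEGMENTS:
        del parts[-2]  # '...norm.ln.weight' -> '...norm.weight'
    return ".".join(parts)


def rename_keys(state: Dict[str, T]) -> Dict[str, T]:
    """Rename every key and drop non-encoder entries; values pass through untouched."""
    renamed: Dict[str, T] = {}
    for key, value in state.items():
        new_key = dense_key(key)
        if new_key is not None:
            renamed[new_key] = value
    return renamed
```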

Any help would be much appreciated! And I apologize if I'm doing something dumb :)

Thanks again, Anthony

vishalned commented 2 weeks ago

Hey, I don't think the squeeze is required. Could you share the error you got before you added your "fix"?

antofuller commented 2 weeks ago

Thanks for the reply!

Here is the error without my "fix":

RuntimeError: Error(s) in loading state_dict for FCMAE:
    size mismatch for encoder.downsample_layers.0.1.bias: copying a param with shape torch.Size([1, 80]) from checkpoint, the shape in current model is torch.Size([80]).
    size mismatch for encoder.downsample_layers.1.1.bias: copying a param with shape torch.Size([1, 160]) from checkpoint, the shape in current model is torch.Size([160]).
    size mismatch for encoder.downsample_layers.2.1.bias: copying a param with shape torch.Size([1, 320]) from checkpoint, the shape in current model is torch.Size([320]).
    size mismatch for encoder.initial_conv.0.bias: copying a param with shape torch.Size([1, 40]) from checkpoint, the shape in current model is torch.Size([40]).
    size mismatch for encoder.stem.0.bias: copying a param with shape torch.Size([1, 40]) from checkpoint, the shape in current model is torch.Size([40]).
    size mismatch for encoder.stages.0.0.dwconv.bias: copying a param with shape torch.Size([1, 40]) from checkpoint, the shape in current model is torch.Size([40]).
    size mismatch for encoder.stages.0.0.grn.gamma: copying a param with shape torch.Size([1, 160]) from checkpoint, the shape in current model is torch.Size([1, 1, 1, 160]).
    size mismatch for encoder.stages.0.0.grn.beta: copying a param with shape torch.Size([1, 160]) from checkpoint, the shape in current model is torch.Size([1, 1, 1, 160]).
    size mismatch for encoder.stages.0.1.dwconv.bias: copying a param with shape torch.Size([1, 40]) from checkpoint, the shape in current model is torch.Size([40]).
    size mismatch for encoder.stages.0.1.grn.gamma: copying a param with shape torch.Size([1, 160]) from checkpoint, the shape in current model is torch.Size([1, 1, 1, 160]).
    size mismatch for encoder.stages.0.1.grn.beta: copying a param with shape torch.Size([1, 160]) from checkpoint, the shape in current model is torch.Size([1, 1, 1, 160]).
    size mismatch for encoder.stages.1.0.dwconv.bias: copying a param with shape torch.Size([1, 80]) from checkpoint, the shape in current model is torch.Size([80]).
    size mismatch for encoder.stages.1.0.grn.gamma: copying a param with shape torch.Size([1, 320]) from checkpoint, the shape in current model is torch.Size([1, 1, 1, 320]).
    size mismatch for encoder.stages.1.0.grn.beta: copying a param with shape torch.Size([1, 320]) from checkpoint, the shape in current model is torch.Size([1, 1, 1, 320]).
    size mismatch for encoder.stages.1.1.dwconv.bias: copying a param with shape torch.Size([1, 80]) from checkpoint, the shape in current model is torch.Size([80]).
    size mismatch for encoder.stages.1.1.grn.gamma: copying a param with shape torch.Size([1, 320]) from checkpoint, the shape in current model is torch.Size([1, 1, 1, 320]).
    size mismatch for encoder.stages.1.1.grn.beta: copying a param with shape torch.Size([1, 320]) from checkpoint, the shape in current model is torch.Size([1, 1, 1, 320]).
    size mismatch for encoder.stages.2.0.dwconv.bias: copying a param with shape torch.Size([1, 160]) from checkpoint, the shape in current model is torch.Size([160]).
    size mismatch for encoder.stages.2.0.grn.gamma: copying a param with shape torch.Size([1, 640]) from checkpoint, the shape in current model is torch.Size([1, 1, 1, 640]).
    size mismatch for encoder.stages.2.0.grn.beta: copying a param with shape torch.Size([1, 640]) from checkpoint, the shape in current model is torch.Size([1, 1, 1, 640]).
    size mismatch for encoder.stages.2.1.dwconv.bias: copying a param with shape torch.Size([1, 160]) from checkpoint, the shape in current model is torch.Size([160]).
    size mismatch for encoder.stages.2.1.grn.gamma: copying a param with shape torch.Size([1, 640]) from checkpoint, the shape in current model is torch.Size([1, 1, 1, 640]).
    size mismatch for encoder.stages.2.1.grn.beta: copying a param with shape torch.Size([1, 640]) from checkpoint, the shape in current model is torch.Size([1, 1, 1, 640]).
    size mismatch for encoder.stages.2.2.dwconv.bias: copying a param with shape torch.Size([1, 160]) from checkpoint, the shape in current model is torch.Size([160]).
    size mismatch for encoder.stages.2.2.grn.gamma: copying a param with shape torch.Size([1, 640]) from checkpoint, the shape in current model is torch.Size([1, 1, 1, 640]).
    size mismatch for encoder.stages.2.2.grn.beta: copying a param with shape torch.Size([1, 640]) from checkpoint, the shape in current model is torch.Size([1, 1, 1, 640]).
    size mismatch for encoder.stages.2.3.dwconv.bias: copying a param with shape torch.Size([1, 160]) from checkpoint, the shape in current model is torch.Size([160]).
    size mismatch for encoder.stages.2.3.grn.gamma: copying a param with shape torch.Size([1, 640]) from checkpoint, the shape in current model is torch.Size([1, 1, 1, 640]).
    size mismatch for encoder.stages.2.3.grn.beta: copying a param with shape torch.Size([1, 640]) from checkpoint, the shape in current model is torch.Size([1, 1, 1, 640]).
    size mismatch for encoder.stages.2.4.dwconv.bias: copying a param with shape torch.Size([1, 160]) from checkpoint, the shape in current model is torch.Size([160]).
    size mismatch for encoder.stages.2.4.grn.gamma: copying a param with shape torch.Size([1, 640]) from checkpoint, the shape in current model is torch.Size([1, 1, 1, 640]).
    size mismatch for encoder.stages.2.4.grn.beta: copying a param with shape torch.Size([1, 640]) from checkpoint, the shape in current model is torch.Size([1, 1, 1, 640]).
    size mismatch for encoder.stages.2.5.dwconv.bias: copying a param with shape torch.Size([1, 160]) from checkpoint, the shape in current model is torch.Size([160]).
    size mismatch for encoder.stages.2.5.grn.gamma: copying a param with shape torch.Size([1, 640]) from checkpoint, the shape in current model is torch.Size([1, 1, 1, 640]).
    size mismatch for encoder.stages.2.5.grn.beta: copying a param with shape torch.Size([1, 640]) from checkpoint, the shape in current model is torch.Size([1, 1, 1, 640]).
    size mismatch for encoder.stages.3.0.dwconv.bias: copying a param with shape torch.Size([1, 320]) from checkpoint, the shape in current model is torch.Size([320]).
    size mismatch for encoder.stages.3.0.grn.gamma: copying a param with shape torch.Size([1, 1280]) from checkpoint, the shape in current model is torch.Size([1, 1, 1, 1280]).
    size mismatch for encoder.stages.3.0.grn.beta: copying a param with shape torch.Size([1, 1280]) from checkpoint, the shape in current model is torch.Size([1, 1, 1, 1280]).
    size mismatch for encoder.stages.3.1.dwconv.bias: copying a param with shape torch.Size([1, 320]) from checkpoint, the shape in current model is torch.Size([320]).
    size mismatch for encoder.stages.3.1.grn.gamma: copying a param with shape torch.Size([1, 1280]) from checkpoint, the shape in current model is torch.Size([1, 1, 1, 1280]).
    size mismatch for encoder.stages.3.1.grn.beta: copying a param with shape torch.Size([1, 1280]) from checkpoint, the shape in current model is torch.Size([1, 1, 1, 1280]).
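Instead of reading the RuntimeError line by line, the missing keys, unexpected keys, and size mismatches can be listed together by diffing parameter shapes before calling load_state_dict. A minimal sketch, using a hypothetical helper `compare_shapes` that works on plain name-to-shape dictionaries:

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def compare_shapes(
    checkpoint: Dict[str, Shape], model: Dict[str, Shape]
) -> Tuple[List[str], List[str], List[Tuple[str, Shape, Shape]]]:
    """Return (missing, unexpected, mismatched) from two name -> shape maps."""
    # In the model but not in the checkpoint.
    missing = sorted(k for k in model if k not in checkpoint)
    # In the checkpoint but not in the model.
    unexpected = sorted(k for k in checkpoint if k not in model)
    # Present in both, with different shapes: (name, checkpoint shape, model shape).
    mismatched = sorted(
        (k, checkpoint[k], model[k])
        for k in checkpoint
        if k in model and checkpoint[k] != model[k]
    )
    return missing, unexpected, mismatched
```

With PyTorch, the inputs could be built as `{k: tuple(v.shape) for k, v in checkpoint["model"].items()}` and `{k: tuple(v.shape) for k, v in model.state_dict().items()}`.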

And here is a Google Colab notebook to reproduce it:

I commented out all sparse functionality in models/fcmae.py for this to run in Colab.

Thanks again for the support :)

antofuller commented 2 weeks ago

I was able to fix the issue by using your remap_checkpoint_keys function to make the weights compatible with the non-sparse version of ConvNeXt V2.