microsoft / aurora

Implementation of the Aurora model for atmospheric forecasting
https://microsoft.github.io/aurora
Other

`AuroraHighRes` cannot be imported from `aurora` #18

Closed: b8raoult closed this issue 2 months ago

b8raoult commented 2 months ago

This means that the example at https://microsoft.github.io/aurora/models.html#aurora-0-1-fine-tuned does not work. I can import it from `aurora.model.aurora`, though.
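Until the top-level export is available, a version-tolerant import is one way to keep code working whichever path exposes the class. Below is a minimal, generic sketch. `import_first` is a hypothetical helper, not part of the `aurora` package, and the commented-out usage line assumes only the two module paths named in this issue.

```python
import importlib
from typing import Any, Iterable


def import_first(name: str, modules: Iterable[str]) -> Any:
    """Return attribute `name` from the first module in `modules` that defines it.

    Hypothetical helper: tries each module path in order and skips modules
    that fail to import or that do not export `name`.
    """
    errors = []
    for module_name in modules:
        try:
            module = importlib.import_module(module_name)
        except ImportError as exc:  # includes ModuleNotFoundError
            errors.append(f"{module_name}: {exc}")
            continue
        if hasattr(module, name):
            return getattr(module, name)
        # The module imported fine but does not export `name` (the case in this issue).
        errors.append(f"{module_name}: no attribute {name!r}")
    raise ImportError(f"could not import {name!r}; tried: " + "; ".join(errors))


# Hypothetical usage for this issue: prefer the documented top-level import,
# then fall back to the submodule path reported above.
# AuroraHighRes = import_first("AuroraHighRes", ["aurora", "aurora.model.aurora"])
```

Listing the documented path first means the fallback stops being used automatically once a fixed release exports the class at the top level.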

sbetageri commented 2 months ago

How did you get the model weights to load? I got this error when running the snippet below.

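```python
# Import from the submodule, since `from aurora import AuroraHighRes` fails (see the issue title).
from aurora.model.aurora import AuroraHighRes

model = AuroraHighRes()
# Load the 0.1-degree fine-tuned checkpoint from the `microsoft/aurora` repo.
model.load_checkpoint("microsoft/aurora", "aurora-0.1-finetuned.ckpt")
```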

The error that I get is below.

```
--------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[22], line 1
----> 1 model.load_checkpoint("microsoft/aurora", "aurora-0.1-finetuned.ckpt")

File ~/code/aurora/aurora/model/aurora.py:240, in Aurora.load_checkpoint(self, repo, name, strict)
    237         del d[k]
    238         d[k[4:]] = v
--> 240 self.load_state_dict(d, strict=strict)

File ~/miniconda3/envs/pai311/lib/python3.11/site-packages/torch/nn/modules/module.py:2189, in Module.load_state_dict(self, state_dict, strict, assign)
   2184         error_msgs.insert(
   2185             0, 'Missing key(s) in state_dict: {}. '.format(
   2186                 ', '.join(f'"{k}"' for k in missing_keys)))
   2188 if len(error_msgs) > 0:
-> 2189     raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   2190                        self.__class__.__name__, "\n\t".join(error_msgs)))
   2191 return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for Aurora:
    size mismatch for encoder.surf_token_embeds.weight: copying a param with shape torch.Size([512, 7, 2, 10, 10]) from checkpoint, the shape in current model is torch.Size([512, 7, 2, 4, 4]).
    size mismatch for encoder.atmos_token_embeds.weight: copying a param with shape torch.Size([512, 5, 2, 10, 10]) from checkpoint, the shape in current model is torch.Size([512, 5, 2, 4, 4]).
    size mismatch for decoder.surf_head.weight: copying a param with shape torch.Size([400, 1024]) from checkpoint, the shape in current model is torch.Size([64, 1024]).
    size mismatch for decoder.surf_head.bias: copying a param with shape torch.Size([400]) from checkpoint, the shape in current model is torch.Size([64]).
    size mismatch for decoder.atmos_head.weight: copying a param with shape torch.Size([500, 1024]) from checkpoint, the shape in current model is torch.Size([80, 1024]).
    size mismatch for decoder.atmos_head.bias: copying a param with shape torch.Size([500]) from checkpoint, the shape in current model is torch.Size([80]).
```

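
The sizes in this error can be cross-checked with a little arithmetic. The sketch below is a reading of the shapes, not something confirmed in this thread: it assumes that the last two dimensions of a token-embedding weight are the (square) patch size, and that each decoder head maps a token to `n_vars * patch * patch` outputs. `patch_size_from_embed` and `vars_from_head` are hypothetical helpers; the shapes are copied from the traceback.

```python
from typing import Dict, Tuple

Shape = Tuple[int, ...]

# Shapes copied from the RuntimeError above: (checkpoint, current model).
MISMATCHES: Dict[str, Tuple[Shape, Shape]] = {
    "encoder.surf_token_embeds.weight": ((512, 7, 2, 10, 10), (512, 7, 2, 4, 4)),
    "encoder.atmos_token_embeds.weight": ((512, 5, 2, 10, 10), (512, 5, 2, 4, 4)),
    "decoder.surf_head.weight": ((400, 1024), (64, 1024)),
    "decoder.atmos_head.weight": ((500, 1024), (80, 1024)),
}


def patch_size_from_embed(shape: Shape) -> int:
    # Assumption: the last two dims of a token-embedding weight are (patch, patch).
    p, q = shape[-2], shape[-1]
    if p != q:
        raise ValueError(f"non-square patch in {shape}")
    return p


def vars_from_head(shape: Shape, patch: int) -> int:
    # Assumption: a head maps each token to n_vars * patch * patch outputs.
    out_features = shape[0]
    n_vars, rem = divmod(out_features, patch * patch)
    if rem:
        raise ValueError(f"{out_features} is not a multiple of {patch}**2")
    return n_vars


for label, idx in (("checkpoint", 0), ("current model", 1)):
    p = patch_size_from_embed(MISMATCHES["encoder.surf_token_embeds.weight"][idx])
    surf = vars_from_head(MISMATCHES["decoder.surf_head.weight"][idx], p)
    atmos = vars_from_head(MISMATCHES["decoder.atmos_head.weight"][idx], p)
    print(f"{label}: patch_size={p}, surface vars={surf}, atmospheric vars={atmos}")
```

Running it prints:

```
checkpoint: patch_size=10, surface vars=4, atmospheric vars=5
current model: patch_size=4, surface vars=4, atmospheric vars=5
```

Under these assumptions the variable counts agree and only the patch size differs (10 in the checkpoint versus 4 in the constructed model), which is consistent with the model being constructed with the wrong configuration rather than the checkpoint being corrupt.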
wesselb commented 2 months ago

Hey both! Thanks for opening an issue. You're right that the high-res model is configured incorrectly. We'll fix this shortly and release an update to the package.