NVIDIA / earth2mip

Earth-2 Model Intercomparison Project (MIP) is a python framework that enables climate researchers and scientists to inter-compare AI models for weather and climate.
Apache License 2.0

🐛[BUG]: e2mip://fcn broken #171

Closed. nbren12 closed this issue 6 months ago.

nbren12 commented 7 months ago



On which installation method(s) does this occur?

No response

Describe the issue

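The `e2mip://fcn` checkpoint appears to be incompatible with newer versions of Modulus. Loading it fails inside `modulus.Module.from_checkpoint` because the saved constructor arguments include `img_size`, which the model's `__init__` no longer accepts.

Here is the error message: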

In [1]: import earth2mip.networks

In [2]: earth2mip.networks.get_model("e2mip://fcn")
TypeError                                 Traceback (most recent call last)
Cell In[2], line 1
----> 1 earth2mip.networks.get_model("e2mip://fcn")

File /usr/local/lib/python3.10/dist-packages/earth2mip/networks/__init__.py:345, in get_model(model, registry, device, metadata)
    343 if url.scheme == "e2mip":
    344     package = registry.get_model(model)
--> 345     return _load_package_builtin(package, device, name=url.netloc)
    346 elif url.scheme == "":
    347     package = registry.get_model(model)

File /usr/local/lib/python3.10/dist-packages/earth2mip/networks/__init__.py:291, in _load_package_builtin(package, device, name)
    289     if entry_point.name == name:
    290         inference_loader = entry_point.load()
--> 291         return inference_loader(package, device=device)
    292 raise ValueError(f"{name} not in {names_found}.")

File /usr/local/lib/python3.10/dist-packages/earth2mip/networks/fcn.py:64, in load(package, pretrained, device)
     61 local_center = np.load(package.get("global_means.npy"))
     62 local_std = np.load(package.get("global_stds.npy"))
---> 64 core_model = modulus.Module.from_checkpoint(package.get("fcn.mdlus"))
     66 dt = datetime.timedelta(hours=6)
     67 grid = earth2mip.grid.equiangular_lat_lon_grid(720, 1440, includes_south_pole=False)

File /usr/local/lib/python3.10/dist-packages/modulus/models/module.py:333, in Module.from_checkpoint(cls, file_name)
    331 with open(local_path.joinpath("args.json"), "r") as f:
    332     args = json.load(f)
--> 333 model = cls.instantiate(args)
    335 # Load the model weights
    336 model_dict = torch.load(
    337     local_path.joinpath("model.pt"), map_location=model.device
    338 )

File /usr/local/lib/python3.10/dist-packages/modulus/models/module.py:156, in Module.instantiate(cls, arg_dict)
    153     except AttributeError:
    154         # Cross fingers and hope for the best (maybe the class name changed)
    155         _cls = cls
--> 156 return _cls(**arg_dict["__args__"])

File /usr/local/lib/python3.10/dist-packages/modulus/models/module.py:59, in Module.__new__(cls, *args, **kwargs)
     57 out = super().__new__(cls)
     58 sig = inspect.signature(cls.__init__)
---> 59 bound_args = sig.bind_partial(
     60     *([None] + list(args)), **kwargs
     61 )  # Add None to account for self
     62 bound_args.apply_defaults()
     63 bound_args.arguments.pop("self", None)

File /usr/lib/python3.10/inspect.py:3186, in Signature.bind_partial(self, *args, **kwargs)
   3181 def bind_partial(self, /, *args, **kwargs):
   3182     """Get a BoundArguments object, that partially maps the
   3183     passed `args` and `kwargs` to the function's signature.
   3184     Raises `TypeError` if the passed arguments can not be bound.
   3185     """
-> 3186     return self._bind(args, kwargs, partial=True)

File /usr/lib/python3.10/inspect.py:3168, in Signature._bind(self, args, kwargs, partial)
   3166         arguments[kwargs_param.name] = kwargs
   3167     else:
-> 3168         raise TypeError(
   3169             'got an unexpected keyword argument {arg!r}'.format(
   3170                 arg=next(iter(kwargs))))
   3172 return self._bound_arguments_cls(self, arguments)

TypeError: got an unexpected keyword argument 'img_size'
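
For readers unfamiliar with this failure mode, here is a minimal sketch of what the last few frames are doing, using only the standard library. `RenamedArgModel`, its `inp_shape` parameter, and the helper `bind_saved_args` are hypothetical stand-ins, not the actual Modulus classes; the only parts taken from the traceback are the `inspect.signature(...).bind_partial(...)` call with `None` standing in for `self`, and the `img_size` keyword.

```python
import inspect


class RenamedArgModel:
    """Hypothetical stand-in for a model whose constructor argument was renamed."""

    def __init__(self, inp_shape=(720, 1440)):
        self.inp_shape = inp_shape


def bind_saved_args(cls, saved_kwargs):
    """Bind saved constructor kwargs against the current __init__ signature."""
    sig = inspect.signature(cls.__init__)
    # None stands in for `self`, mirroring the call shown in the traceback.
    bound = sig.bind_partial(None, **saved_kwargs)
    bound.apply_defaults()
    bound.arguments.pop("self", None)
    return dict(bound.arguments)


# Arguments as an older checkpoint might have recorded them (illustrative values).
old_args = {"img_size": (720, 1440)}

try:
    bind_saved_args(RenamedArgModel, old_args)
    error_message = None
except TypeError as exc:
    # bind_partial rejects any keyword that the signature does not declare.
    error_message = str(exc)

print(error_message)
```

The point of the sketch is that the error comes from a mismatch between the keyword names stored with the checkpoint and the names the installed class declares, not from the weights themselves.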

Environment details

- nvidia-modulus-0.5.0

NickGeneva commented 6 months ago

I wasn't able to reproduce this with Modulus on my machine:

>>> import modulus
>>> modulus.__version__
>>> from modulus.models.afno import AFNO
>>> model = AFNO.from_checkpoint("fcn.mdlus")
>>> model
AFNO(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(26, 768, kernel_size=(8, 8), stride=(8, 8))
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): ModuleList(
    (0-11): 12 x Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (filter): AFNO2DLayer()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): AFNOMlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
  )
  (head): Linear(in_features=768, out_features=1664, bias=False)
)

Could it be that you have an old AFNO checkpoint in a model registry? After a small API change in Modulus, I updated the checkpoint to a new version (the automatic download should pull this one). I believe the v0.2 model has the fix for the error you're seeing.

I think this was updated in Modulus 0.4.0, so it's interesting that it's only showing up now.
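
One way to check whether a stale checkpoint is the cause would be to compare the constructor arguments stored with the checkpoint against the installed class's signature. The sketch below assumes only what the traceback shows: that the checkpoint carries an `args.json` whose `"__args__"` entry holds the constructor keyword arguments. The helper `unexpected_checkpoint_args`, the class `CurrentModel`, and the sample JSON are hypothetical and do not reflect the real contents of `fcn.mdlus`.

```python
import inspect
import json


def unexpected_checkpoint_args(cls, args_json_text):
    """Return saved constructor keys that cls.__init__ no longer accepts."""
    saved = json.loads(args_json_text)["__args__"]
    params = inspect.signature(cls.__init__).parameters
    # A **kwargs parameter would accept any key, so nothing is unexpected.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return []
    return sorted(key for key in saved if key not in params)


class CurrentModel:
    """Hypothetical model class; only the parameter names matter here."""

    def __init__(self, inp_shape, patch_size=(8, 8)):
        self.inp_shape = inp_shape
        self.patch_size = patch_size


# Illustrative args.json content, not the real contents of fcn.mdlus.
stale_args = json.dumps(
    {"__args__": {"img_size": [720, 1440], "patch_size": [8, 8]}}
)

print(unexpected_checkpoint_args(CurrentModel, stale_args))
```

If a check like this reports leftover keys for the cached file, that would point to an older checkpoint in the registry rather than a problem with the installed Modulus version.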