NVIDIA / earth2mip

Earth-2 Model Intercomparison Project (MIP) is a python framework that enables climate researchers and scientists to inter-compare AI models for weather and climate.
https://nvidia.github.io/earth2mip/
Apache License 2.0
183 stars 40 forks

Run graphcast on the correct device #153

Closed: nbren12 closed this 8 months ago

nbren12 commented 8 months ago

Earth-2 MIP Pull Request

Description

When multiple GPUs are present, graphcast was not running on the specified device. JAX runs a computation on the device where its input arrays reside. However, when the torch tensor was converted to an xarray.Dataset of jax arrays, the data were cast to CPU numpy arrays, so the model ended up on the default device instead of the requested one. Building the dataset from xarray_jax.Variable objects resolves this.
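The placement rule above can be illustrated with plain JAX. This is a minimal sketch, not the code from this PR: `step` is a hypothetical stand-in for the graphcast forward pass, and the torch-to-JAX conversion and `xarray_jax.Variable` are not shown. A host numpy array is uncommitted, so a jitted function called on it runs on the default device (normally `jax.devices()[0]`, i.e. GPU 0). An array committed with `jax.device_put` makes the same function run on the chosen device.

```python
import jax
import numpy as np


@jax.jit
def step(x):
    # Hypothetical stand-in for one model step; not the graphcast model.
    return x * 2.0 + 1.0


def run_on(device, host_array):
    # Commit the input to `device`. Jitted computations on committed
    # inputs execute on that device and return their outputs there.
    x = jax.device_put(host_array, device)
    return step(x)


# Pick the last visible device; on a multi-GPU node this is not GPU 0.
target = jax.devices()[-1]
data = np.arange(4, dtype=np.float32)

out = run_on(target, data)       # runs on `target`
fallback = step(data)            # uncommitted numpy input: runs on the default device
print(out.devices(), fallback.devices())
```

On a CPU-only machine both calls land on the single CPU device. The difference only shows up when several accelerators are visible, which matches the multi-GPU condition described in this PR.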

Checklist

Dependencies

nbren12 commented 8 months ago

/blossom-ci

nbren12 commented 8 months ago

I just verified this works by running the time_collection script on a Selene node with 8 GPUs and checking the GPU usage with nvidia-smi.

nbren12 commented 8 months ago

/blossom-ci