google-deepmind / graphcast

Apache License 2.0

JAX error while training #30

Closed: illuSION-crypto closed this issue 5 months ago

illuSION-crypto commented 7 months ago

Hello, I wonder whether anyone else has run into the same problem. When I run graphcast_demo.ipynb, every step works until training, where I get the error shown in two screenshots (images not preserved). I use the example data from the Google Cloud bucket and changed nothing in the code. Does anyone know how to solve it? My JAX version is 0.4.20 with CUDA 11.8.
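When reporting an error like this, it helps to list the exact package versions involved. A minimal sketch using only the standard library; the package list (jax, jaxlib, xarray) is an assumption about what matters for this issue, and `installed_versions` is a hypothetical helper name.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Return {package: version string, or None if it is not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            # Package is not installed in the current environment.
            found[name] = None
    return found


if __name__ == "__main__":
    # Packages assumed relevant to this issue; adjust as needed.
    for name, ver in installed_versions(["jax", "jaxlib", "xarray"]).items():
        print(f"{name}: {ver or 'not installed'}")
```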

AndrewYangnb commented 6 months ago

I'm getting the same error. I'd also like to know the cause and how to fix it.

ChrisAGBlake commented 6 months ago

Me too. I'm on JAX 0.4.20 and CUDA 12.3 with an NVIDIA GPU (4090).

ChrisAGBlake commented 6 months ago

I managed to fix this issue by installing an older version of xarray. I had version 2023.12.0 installed, and when I downgraded to 2023.7.0 to match the Colab example, it worked.

pip uninstall xarray
pip install xarray==2023.7.0
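To catch this before reaching the training cells, a notebook could check the installed xarray version up front. A minimal sketch, assuming (based only on the reports in this thread) that 2023.7.0 works and 2023.12.0 fails with the code as of this issue; `KNOWN_GOOD`, `parse_version`, and `xarray_is_newer_than_known_good` are hypothetical names, not part of graphcast.

```python
import re
from importlib.metadata import version

# Version reported working in this thread; an assumption, not an official pin.
KNOWN_GOOD = (2023, 7, 0)


def parse_version(text):
    """Parse leading numeric components, e.g. '2023.12.0' -> (2023, 12, 0)."""
    parts = []
    for piece in text.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
        if match.end() != len(piece):
            # Stop at suffixes such as 'rc1' or '+g123'.
            break
    return tuple(parts)


def xarray_is_newer_than_known_good(installed):
    """True if the installed version string sorts after KNOWN_GOOD."""
    return parse_version(installed) > KNOWN_GOOD


if __name__ == "__main__":
    installed = version("xarray")
    if xarray_is_newer_than_known_good(installed):
        print(f"xarray {installed} is newer than 2023.7.0; if training fails, "
              "try: pip install xarray==2023.7.0")
```

A dependency such as `packaging` would give more complete version parsing; the hand-rolled parser here only keeps the sketch to the standard library.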

illuSION-crypto commented 6 months ago

That's great, I solved the issue the same way. Thank you!

mjwillson commented 5 months ago

After this commit: https://github.com/google-deepmind/graphcast/commit/8debd7289bb2c498485f79dbd98d8b4933bfc6a7 we should be compatible with more recent versions of xarray too. Closing, but please reopen if you still see the issue when running from HEAD.
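To test against HEAD as suggested, graphcast can be installed from the repository's default branch instead of a pinned copy. A minimal sketch that runs pip for the current interpreter (useful inside a notebook kernel); it assumes the repository is pip-installable from source, and `head_install_command` is a hypothetical helper.

```python
import subprocess
import sys

REPO_URL = "https://github.com/google-deepmind/graphcast.git"


def head_install_command(repo_url=REPO_URL):
    """Build a pip command that installs graphcast from the repo's default branch."""
    return [sys.executable, "-m", "pip", "install", "--upgrade", f"git+{repo_url}"]


if __name__ == "__main__":
    # Runs pip with the same interpreter that is executing this script.
    subprocess.run(head_install_command(), check=True)
```

If pip reports the package as already satisfied without picking up new commits, adding `--force-reinstall` may be needed; restart the notebook kernel afterwards so the new code is imported.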