ecmwf-lab / ai-models-graphcast

Apache License 2.0

Trying to run graphcast with the following command: ai-models --input cds --date 20231001 --time 0000 graphcast #12

Open tbwcy opened 7 months ago

tbwcy commented 7 months ago

When I run it, I get the following error:

    Traceback (most recent call last):
      File "/home/wcy/anaconda3/envs/largecpu/bin/ai-models", line 8, in
        sys.exit(main())
      File "/home/wcy/anaconda3/envs/largecpu/lib/python3.10/site-packages/ai_models/__main__.py", line 322, in main
        _main(sys.argv[1:])
      File "/home/wcy/anaconda3/envs/largecpu/lib/python3.10/site-packages/ai_models/__main__.py", line 270, in _main
        run(vars(args), unknownargs)
      File "/home/wcy/anaconda3/envs/largecpu/lib/python3.10/site-packages/ai_models/__main__.py", line 274, in run
        model = load_model(cfg["model"], cfg, model_args=model_args)
      File "/home/wcy/anaconda3/envs/largecpu/lib/python3.10/site-packages/ai_models/model.py", line 480, in load_model
        return available_models()[name].load()(kwargs)
      File "/home/wcy/anaconda3/envs/largecpu/lib/python3.10/site-packages/entrypoints.py", line 79, in load
        mod = import_module(self.module_name)
      File "/home/wcy/anaconda3/envs/largecpu/lib/python3.10/importlib/__init__.py", line 126, in import_module
        return _bootstrap._gcd_import(name[level:], package, level)
      File "", line 1050, in _gcd_import
      File "", line 1027, in _find_and_load
      File "", line 1006, in _find_and_load_unlocked
      File "", line 688, in _load_unlocked
      File "", line 883, in exec_module
      File "", line 241, in _call_with_frames_removed
      File "/home/wcy/anaconda3/envs/largecpu/lib/python3.10/site-packages/ai_models_graphcast/model.py", line 27, in
        import haiku as hk
      File "/home/wcy/anaconda3/envs/largecpu/lib/python3.10/site-packages/haiku/__init__.py", line 20, in
        from haiku import experimental
      File "/home/wcy/anaconda3/envs/largecpu/lib/python3.10/site-packages/haiku/experimental/__init__.py", line 34, in
        from haiku._src.dot import abstract_to_dot
      File "/home/wcy/anaconda3/envs/largecpu/lib/python3.10/site-packages/haiku/_src/dot.py", line 163, in
        @jax.linear_util.transformation
      File "/home/wcy/anaconda3/envs/largecpu/lib/python3.10/site-packages/jax/_src/deprecations.py", line 54, in __getattr__
        raise AttributeError(f"module {module!r} has no attribute {name!r}")
    AttributeError: module 'jax' has no attribute 'linear_util'
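To confirm that the installed JAX no longer exposes linear_util at the top level, a quick standard-library check like the one below may help. This is a minimal sketch: `has_attribute` is a hypothetical helper name, not part of ai-models, Haiku, or JAX. It relies on the fact that `hasattr` returns False whenever attribute access raises `AttributeError`, which is what JAX's deprecation shim does in the traceback above.

```python
import importlib


def has_attribute(module_name: str, attr: str) -> bool:
    """Return True if `module_name` imports cleanly and exposes `attr`.

    Hypothetical diagnostic helper for errors such as
    "module 'jax' has no attribute 'linear_util'".
    """
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        # The module is not installed in this environment.
        return False
    # hasattr() treats an AttributeError raised by a module-level
    # __getattr__ (such as a deprecation shim) as "attribute missing".
    return hasattr(module, attr)


if __name__ == "__main__":
    # On the JAX versions affected by this issue, this is expected to print False.
    print("jax.linear_util available:", has_attribute("jax", "linear_util"))
```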

msgomez06 commented 7 months ago

I don't know whether you've managed to solve this since you posted it 5 days ago, but it turns out this is an issue with the JAX version. The CUDA 11 build of JAX was updated, and linear_util was moved out of the top-level jax namespace (to jax.extend.linear_util), so the installed Haiku fails at import time on @jax.linear_util.transformation. To address this, you can modify requirements-gpu.txt to:

jax[cuda11_pip]==0.4.23
git+https://github.com/deepmind/graphcast.git
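Before re-running ai-models, it may be worth checking which JAX version pip actually installed. The sketch below uses only the standard library (`importlib.metadata`). `parse_release` and `is_newer_than` are hypothetical helper names, and the 0.4.23 value is simply the pin from the requirements lines above, not a documented compatibility boundary. The parser assumes plain numeric release strings such as 0.4.23; versions with suffixes (for example rc or dev builds) are reported as needing a manual check.

```python
from importlib.metadata import PackageNotFoundError, version


def parse_release(ver: str) -> tuple[int, ...]:
    """Turn a plain release string like '0.4.23' into (0, 4, 23).

    Assumption: dot-separated integers only; anything else raises ValueError.
    """
    return tuple(int(part) for part in ver.split("."))


def is_newer_than(installed: str, pinned: str) -> bool:
    """True if `installed` is a strictly later release than `pinned`."""
    return parse_release(installed) > parse_release(pinned)


if __name__ == "__main__":
    PINNED_JAX = "0.4.23"  # value pinned in requirements-gpu.txt above
    try:
        jax_version = version("jax")
    except PackageNotFoundError:
        print("jax is not installed in this environment")
    else:
        try:
            newer = is_newer_than(jax_version, PINNED_JAX)
        except ValueError:
            print(f"jax {jax_version}: non-numeric version string, compare manually")
        else:
            status = "newer than" if newer else "not newer than"
            print(f"jax {jax_version} is {status} the pinned {PINNED_JAX}")
```

Comparing integer tuples instead of raw strings avoids the lexicographic trap where "0.4.9" would sort after "0.4.23".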