tbwcy opened 7 months ago
I don't know if you managed to solve this since you wrote it 5 days ago, but it turns out this is an issue with the jax version: the jax developers updated the CUDA 11 build, and linear_util was moved elsewhere in the library. To address this, you can change requirements-gpu.txt to:
jax[cuda11_pip]==0.4.23
git+https://github.com/deepmind/graphcast.git
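After editing the pins, it can help to confirm that the environment actually ended up with those versions. This is a minimal sketch, assuming the requirements file only uses simple name==version pins like the jax line above; the helper names parse_pins and check_pins are hypothetical, and lines such as the git+https URL are skipped because they carry no version to compare.

```python
import re
from importlib import metadata

# Hypothetical helper pattern: matches "name==version", allowing extras such as jax[cuda11_pip].
_PIN = re.compile(r"^\s*([A-Za-z0-9_.-]+)(?:\[[^\]]*\])?\s*==\s*([^\s;#]+)")


def parse_pins(lines):
    """Return {distribution_name: pinned_version} for exact '==' pins only.

    Comments, blank lines and VCS URLs (git+https://...) do not match and are skipped.
    """
    pins = {}
    for line in lines:
        match = _PIN.match(line)
        if match:
            pins[match.group(1).lower()] = match.group(2)
    return pins


def check_pins(pins):
    """Map each pinned name to (pinned, installed_or_None, versions_match)."""
    report = {}
    for name, wanted in pins.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None
        report[name] = (wanted, installed, installed == wanted)
    return report


if __name__ == "__main__":
    requirements = [
        "jax[cuda11_pip]==0.4.23",
        "git+https://github.com/deepmind/graphcast.git",
    ]
    for name, (wanted, installed, ok) in check_pins(parse_pins(requirements)).items():
        print(f"{name}: pinned {wanted}, installed {installed}, {'OK' if ok else 'MISMATCH'}")
```

Run it with the interpreter of the environment that ai-models uses; a MISMATCH line means the install did not pick up the pin.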
But I get the following error:

Traceback (most recent call last):
  File "/home/wcy/anaconda3/envs/largecpu/bin/ai-models", line 8, in <module>
    sys.exit(main())
  File "/home/wcy/anaconda3/envs/largecpu/lib/python3.10/site-packages/ai_models/__main__.py", line 322, in main
    _main(sys.argv[1:])
  File "/home/wcy/anaconda3/envs/largecpu/lib/python3.10/site-packages/ai_models/__main__.py", line 270, in _main
    run(vars(args), unknownargs)
  File "/home/wcy/anaconda3/envs/largecpu/lib/python3.10/site-packages/ai_models/__main__.py", line 274, in run
    model = load_model(cfg["model"], cfg, model_args=model_args)
  File "/home/wcy/anaconda3/envs/largecpu/lib/python3.10/site-packages/ai_models/model.py", line 480, in load_model
    return available_models()[name].load()(kwargs)
  File "/home/wcy/anaconda3/envs/largecpu/lib/python3.10/site-packages/entrypoints.py", line 79, in load
    mod = import_module(self.module_name)
  File "/home/wcy/anaconda3/envs/largecpu/lib/python3.10/importlib/__init__.py", line 126, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
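  File "<frozen importlib._bootstrap>", line 1050, in _gcd_import
  File "<frozen importlib._bootstrap>", line 1027, in _find_and_load
  File "<frozen importlib._bootstrap>", line 1006, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 688, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 883, in exec_module
  File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed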
  File "/home/wcy/anaconda3/envs/largecpu/lib/python3.10/site-packages/ai_models_graphcast/model.py", line 27, in <module>
    import haiku as hk
  File "/home/wcy/anaconda3/envs/largecpu/lib/python3.10/site-packages/haiku/__init__.py", line 20, in <module>
    from haiku import experimental
  File "/home/wcy/anaconda3/envs/largecpu/lib/python3.10/site-packages/haiku/experimental/__init__.py", line 34, in <module>
    from haiku._src.dot import abstract_to_dot
  File "/home/wcy/anaconda3/envs/largecpu/lib/python3.10/site-packages/haiku/_src/dot.py", line 163, in <module>
    @jax.linear_util.transformation
  File "/home/wcy/anaconda3/envs/largecpu/lib/python3.10/site-packages/jax/_src/deprecations.py", line 54, in getattr
    raise AttributeError(f"module {module!r} has no attribute {name!r}")
AttributeError: module 'jax' has no attribute 'linear_util'
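The last frames show where this fails: importing haiku pulls in haiku/_src/dot.py, which applies @jax.linear_util.transformation at import time, so ai_models_graphcast/model.py never gets past its "import haiku as hk" line when the installed jax no longer provides jax.linear_util. Below is a minimal diagnostic sketch that prints the installed versions and reproduces the failing import outside ai-models. It assumes the PyPI distribution names jax, jaxlib, dm-haiku, ai-models and ai-models-graphcast (an assumption, not taken from this thread); the helper names are hypothetical.

```python
import importlib
from importlib import metadata

# Assumed PyPI distribution names for the packages that appear in the traceback.
DISTRIBUTIONS = ("jax", "jaxlib", "dm-haiku", "ai-models", "ai-models-graphcast")


def installed_versions(names=DISTRIBUTIONS):
    """Return {distribution: version or None if it is not installed}."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def jax_has_linear_util():
    """True/False depending on whether jax exposes linear_util; None if jax is not installed."""
    try:
        jax = importlib.import_module("jax")
    except ImportError:
        return None
    # hasattr() returns False when jax's module-level __getattr__ hook raises AttributeError.
    return hasattr(jax, "linear_util")


def try_import(module_name):
    """Import a module on its own; return None on success, otherwise the error text."""
    try:
        importlib.import_module(module_name)
    except Exception as exc:  # e.g. the AttributeError raised while importing haiku
        return f"{type(exc).__name__}: {exc}"
    return None


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version}")
    print("jax exposes linear_util:", jax_has_linear_util())
    print("import haiku:", try_import("haiku") or "OK")
```

Run it with the same interpreter that ai-models uses (here the largecpu conda environment). If jax does not expose linear_util and importing haiku reproduces the AttributeError, the installed dm-haiku and jax versions are out of step; this thread does not establish which combination works.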