ecmwf-lab / ai-models-graphcast

Apache License 2.0
64 stars, 19 forks

Trying to run the following command: ai-models --input cds --date 20230110 --time 0000 graphcast --assets graphcast --download-assets #14

Open: ckn161487 opened this issue 7 months ago

ckn161487 commented 7 months ago

I get the following error:

    jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

    The above exception was the direct cause of the following exception:

    Traceback (most recent call last):
      File "/media/data2/stu2/anaconda3/envs/pre-graphcast/bin/ai-models", line 33, in <module>
        sys.exit(load_entry_point('ai-models==0.4.3', 'console_scripts', 'ai-models')())
      File "/media/data2/stu2/anaconda3/envs/pre-graphcast/lib/python3.10/site-packages/ai_models/__main__.py", line 322, in main
        _main(sys.argv[1:])
      File "/media/data2/stu2/anaconda3/envs/pre-graphcast/lib/python3.10/site-packages/ai_models/__main__.py", line 270, in _main
        run(vars(args), unknownargs)
      File "/media/data2/stu2/anaconda3/envs/pre-graphcast/lib/python3.10/site-packages/ai_models/__main__.py", line 295, in run
        model.run()
      File "/media/data2/stu2/anaconda3/envs/pre-graphcast/lib/python3.10/site-packages/ai_models_graphcast/model.py", line 241, in run
        rng=jax.random.PRNGKey(0),
      File "/media/data2/stu2/anaconda3/envs/pre-graphcast/lib/python3.10/site-packages/jax/_src/random.py", line 190, in PRNGKey
        return _return_prng_keys(True, _key('PRNGKey', seed, impl))
      File "/media/data2/stu2/anaconda3/envs/pre-graphcast/lib/python3.10/site-packages/jax/_src/random.py", line 152, in _key
        return prng.seed_with_impl(impl, seed)
      File "/media/data2/stu2/anaconda3/envs/pre-graphcast/lib/python3.10/site-packages/jax/_src/prng.py", line 413, in seed_with_impl
        return random_seed(seed, impl=impl)
      File "/media/data2/stu2/anaconda3/envs/pre-graphcast/lib/python3.10/site-packages/jax/_src/prng.py", line 695, in random_seed
        return random_seed_p.bind(seeds_arr, impl=impl)
      File "/media/data2/stu2/anaconda3/envs/pre-graphcast/lib/python3.10/site-packages/jax/_src/core.py", line 386, in bind
        return self.bind_with_trace(find_top_trace(args), args, params)
      File "/media/data2/stu2/anaconda3/envs/pre-graphcast/lib/python3.10/site-packages/jax/_src/core.py", line 389, in bind_with_trace
        out = trace.process_primitive(self, map(trace.full_raise, args), params)
      File "/media/data2/stu2/anaconda3/envs/pre-graphcast/lib/python3.10/site-packages/jax/_src/core.py", line 821, in process_primitive
        return primitive.impl(*tracers, **params)
      File "/media/data2/stu2/anaconda3/envs/pre-graphcast/lib/python3.10/site-packages/jax/_src/prng.py", line 707, in random_seed_impl
        base_arr = random_seed_impl_base(seeds, impl=impl)
      File "/media/data2/stu2/anaconda3/envs/pre-graphcast/lib/python3.10/site-packages/jax/_src/prng.py", line 712, in random_seed_impl_base
        return seed(seeds)
      File "/media/data2/stu2/anaconda3/envs/pre-graphcast/lib/python3.10/site-packages/jax/_src/prng.py", line 941, in threefry_seed
        return _threefry_seed(seed)
    jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.func.launch' failed: Failed to get module function: CUDA_ERROR_NOT_FOUND: named symbol not found; current tracing scope: fusion; current profiling annotation: XlaModule:#hlo_module=jit__threefry_seed,program_id=1#.
    I0000 00:00:1711361332.558439 6111 tfrt_cpu_pjrt_client.cc:352] TfrtCpuClient destroyed.

How can I solve this error?
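The traceback shows that the failure happens inside jax.random.PRNGKey(0), called from ai_models_graphcast/model.py line 241, before any GraphCast computation runs. JAX also notes that its internal frames are hidden unless JAX_TRACEBACK_FILTERING=off is set. A minimal sketch, assuming a standard Python 3 environment, that runs that one JAX call in a fresh interpreter with filtering turned off; the helper name run_with_full_traceback is hypothetical and not part of ai-models or JAX.

```python
import os
import subprocess
import sys

# The call that fails in the traceback above, run on its own.
# Printing jax.devices() first shows which GPU/CPU devices JAX picked up.
PRNG_CHECK = "import jax; print(jax.devices()); print(jax.random.PRNGKey(0))"


def run_with_full_traceback(cmd):
    """Run cmd with JAX_TRACEBACK_FILTERING=off, as the JAX message suggests.

    Returns the subprocess.CompletedProcess so the caller can inspect
    returncode, stdout and stderr.
    """
    env = dict(os.environ, JAX_TRACEBACK_FILTERING="off")
    return subprocess.run(cmd, env=env, capture_output=True, text=True)


if __name__ == "__main__":
    result = run_with_full_traceback([sys.executable, "-c", PRNG_CHECK])
    print("exit code:", result.returncode)
    print(result.stdout)
    print(result.stderr, file=sys.stderr)
```

To get the unfiltered traceback from the original command instead, pass its argument list, e.g. run_with_full_traceback(["ai-models", "--input", "cds", "--date", "20230110", "--time", "0000", "graphcast", "--assets", "graphcast", "--download-assets"]). If the standalone PRNGKey(0) check fails with the same CUDA_ERROR_NOT_FOUND, the problem most likely lies in the JAX/CUDA installation rather than in ai-models-graphcast itself.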

HCookie commented 1 week ago

JAX was updated earlier this year; see #12 for a fix.
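The fix in #12 is not reproduced here. Since the reply attributes the error to a JAX update, it can help to record which versions are actually installed in the environment before and after applying it. A minimal sketch using only the standard library; the list of distribution names is an assumption, and the CUDA plugin entries only exist for newer JAX releases, so they are reported as not installed otherwise.

```python
from importlib import metadata

# Distribution names to report. The jax-cuda12-* entries are assumptions that
# only apply to newer JAX releases; missing distributions are reported as None.
PACKAGES = (
    "ai-models",
    "ai-models-graphcast",
    "jax",
    "jaxlib",
    "jax-cuda12-plugin",
    "jax-cuda12-pjrt",
)


def installed_versions(names=PACKAGES):
    """Map each distribution name to its installed version string, or None if absent."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name:22} {version or 'not installed'}")
```

Running it inside the pre-graphcast environment gives a short list that can be pasted into the issue alongside the traceback.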