ecmwf-lab / ai-models-graphcast

Apache License 2.0

Graphcast fails at xarray_jax #2

Open EricLeer opened 10 months ago

EricLeer commented 10 months ago

Trying to run graphcast with the following command:

ai-models --input cds --date 20231001 --time 0000 graphcast

but I get the following error:

Traceback (most recent call last):
  File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/bin/ai-models", line 8, in <module>
    sys.exit(main())
  File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/ai_models/__main__.py", line 291, in main
    _main()
  File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/ai_models/__main__.py", line 264, in _main
    model.run()
  File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/ai_models_graphcast/model.py", line 237, in run
    output = self.model(
  File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/ai_models_graphcast/model.py", line 114, in <lambda>
    return lambda **kw: fn(**kw)[0]
  File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 177, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/jax/_src/pjit.py", line 255, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
  File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/jax/_src/pjit.py", line 161, in _python_pjit_helper
    args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
  File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/jax/_src/api.py", line 325, in infer_params
    return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
  File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/jax/_src/pjit.py", line 485, in common_infer_params
    jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
  File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/jax/_src/pjit.py", line 962, in _pjit_jaxpr
    jaxpr, final_consts, out_type = _create_pjit_jaxpr(
  File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/jax/_src/linear_util.py", line 348, in memoized_fun
    ans = call(fun, *args)
  File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/jax/_src/pjit.py", line 915, in _create_pjit_jaxpr
    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
  File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2203, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2225, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/jax/_src/linear_util.py", line 190, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/haiku/_src/transform.py", line 457, in apply_fn
    out = f(*args, **kwargs)
  File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/ai_models_graphcast/model.py", line 168, in run_forward
    return predictor(
  File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/graphcast/autoregressive.py", line 169, in __call__
    target_template = targets_template.isel(time=[0])
  File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/xarray/core/dataset.py", line 2920, in isel
    var = var.isel(var_indexers)
  File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/xarray/core/variable.py", line 1135, in isel
    return self[key]
  File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/xarray/core/variable.py", line 811, in __getitem__
    data = as_indexable(self._data)[indexer]
  File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/xarray/core/indexing.py", line 1336, in __getitem__
    return array[key]
  File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/graphcast/xarray_jax.py", line 352, in wrapped_func
    args, kwargs = tree.map_structure(unwrap, (args, kwargs))
AttributeError: module 'tree' has no attribute 'map_structure'

I'm running Python 3.10 with the following package versions:

ai-models           0.2.14
ai-models-graphcast 0.0.4
jax                 0.4.19
jaxlib              0.4.19

Any idea what the problem might be? It seems to originate in xarray_jax, which tries to access an attribute that doesn't exist.
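The last frame shows `graphcast/xarray_jax.py` calling `tree.map_structure`. That function comes from DeepMind's `dm-tree` package, which is imported as `tree`. The following is a minimal diagnostic sketch, assuming the cause is some other module named `tree` shadowing dm-tree. It prints where `tree` was imported from and whether that module has `map_structure`. The helper names `has_map_structure` and `describe_tree_module` are hypothetical and are not part of graphcast or ai-models.

```python
import importlib
import types
from typing import Optional


def has_map_structure(module: Optional[types.ModuleType]) -> bool:
    """True if `module` exposes a callable `map_structure`, as dm-tree's `tree` does."""
    return module is not None and callable(getattr(module, "map_structure", None))


def describe_tree_module() -> str:
    """Report which top-level `tree` module this Python environment imports."""
    try:
        module = importlib.import_module("tree")
    except ImportError:
        return "no importable 'tree' module (dm-tree does not appear to be installed)"
    location = getattr(module, "__file__", None) or "<unknown location>"
    if has_map_structure(module):
        return f"'tree' from {location} has map_structure (looks like dm-tree)"
    return f"'tree' from {location} lacks map_structure (probably not dm-tree)"


if __name__ == "__main__":
    print(describe_tree_module())
```

Run it with the same interpreter as `ai-models` (here, the `ec-graphcast-env` pyenv environment), because each environment can resolve `tree` differently.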

Dadoof commented 10 months ago

In my build, I did this:

sudo pip3 install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

I'm not sure how you installed JAX, but if you didn't install it that way, it might be worth a try. This is only a guess on my part.

mjwillson commented 8 months ago

Hello, this looks like it may be related to an issue where graphcast was pulling in the wrong tree library as a dependency. That should now be resolved, because the following PR has been merged: https://github.com/google-deepmind/graphcast/pull/25. Would you mind reinstalling graphcast from git and checking whether you still see the problem?
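After reinstalling (for example with `pip install --upgrade git+https://github.com/google-deepmind/graphcast.git`), you can confirm the environment by listing which of the relevant distributions are installed. This sketch uses only the standard library's `importlib.metadata`. It assumes the correct dependency is the `dm-tree` distribution and that a separately installed `tree` distribution is the wrong one. That is our reading of the linked PR, not something stated in this thread. The `installed_versions` helper is hypothetical.

```python
from importlib import metadata
from typing import Dict, Iterable, Optional

# Distributions relevant to this issue. "tree" is listed because, under our
# assumption, it is the unrelated package whose module shadows dm-tree's `tree`.
PACKAGES = ("ai-models", "ai-models-graphcast", "graphcast", "jax", "jaxlib", "dm-tree", "tree")


def installed_versions(names: Iterable[str] = PACKAGES) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if absent."""
    versions: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name:20} {version or 'not installed'}")
```

If `tree` reports a version while `dm-tree` reports `not installed`, that would be consistent with the dependency mix-up described above.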