google-deepmind / graphcast

Apache License 2.0

How to solve TracerArrayConversionError (about xarray_jax)? #41

Closed AndrewYangnb closed 5 months ago

AndrewYangnb commented 6 months ago

When I run the notebook cells "Loss computation (autoregressive loss over multiple steps)" and "Gradient computation (backprop through time)" locally, I get the error below.

What is the cause of the problem, and how can it be solved locally? Is there something wrong with my environment setup? Why does the same code run without problems in Google Colab? Oh, I have so many questions...

Below is the complete error log:

{
    "name": "TracerArrayConversionError",
    "message": "The numpy.ndarray conversion method __array__() was called on traced array with shape bfloat16[1].
The error occurred while tracing the function apply_fn at c:\\Users\\90566\\miniconda3\\envs\\graphcast\\lib\\site-packages\\haiku\\_src\\transform.py:440 for jit. This value became a tracer due to JAX operations on these lines:

  operation a:f32[1,2,181,360] = sub b c
    from line d:\\code\\graphcast-0.1\\graphcast\\xarray_jax.py:353 (wrapped_func)

  operation a:f32[1,2,181,360] = div b c
    from line d:\\code\\graphcast-0.1\\graphcast\\xarray_jax.py:353 (wrapped_func)

  operation a:f32[1,2,181,360] = sub b c
    from line d:\\code\\graphcast-0.1\\graphcast\\xarray_jax.py:353 (wrapped_func)

  operation a:f32[1,2,181,360] = div b c
    from line d:\\code\\graphcast-0.1\\graphcast\\xarray_jax.py:353 (wrapped_func)

  operation a:f32[1,2,181,360] = sub b c
    from line d:\\code\\graphcast-0.1\\graphcast\\xarray_jax.py:353 (wrapped_func)

(Additional originating lines are not shown.)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError",
    "stack": "---------------------------------------------------------------------------
TracerArrayConversionError                Traceback (most recent call last)
Cell In[33], line 3
      1 # @title Loss computation (autoregressive loss over multiple steps)
      2 # 
----> 3 loss, diagnostics = loss_fn_jitted(
      4     rng=jax.random.PRNGKey(0),
      5     inputs=train_inputs,
      6     targets=train_targets,
      7     forcings=train_forcings)
      8 print(\"Loss:\", float(loss))

Cell In[29], line 68, in drop_state.<locals>.<lambda>(**kw)
     67 def drop_state(fn):
---> 68   return lambda **kw: fn(**kw)[0]

    [... skipping hidden 12 frame]

File c:\\Users\\90566\\miniconda3\\envs\\graphcast\\lib\\site-packages\\haiku\\_src\\transform.py:456, in transform_with_state.<locals>.apply_fn(params, state, rng, *args, **kwargs)
    454 with base.new_context(params=params, state=state, rng=rng) as ctx:
    455   try:
--> 456     out = f(*args, **kwargs)
    457   except jax.errors.UnexpectedTracerError as e:
    458     raise jax.errors.UnexpectedTracerError(unexpected_tracer_hint) from e

Cell In[29], line 37, in loss_fn(model_config, task_config, inputs, targets, forcings)
     34 @hk.transform_with_state
     35 def loss_fn(model_config, task_config, inputs, targets, forcings):
     36   predictor = construct_wrapped_graphcast(model_config, task_config)
---> 37   loss, diagnostics = predictor.loss(inputs, targets, forcings)
     38   return xarray_tree.map_structure(
     39       lambda x: xarray_jax.unwrap_data(x.mean(), require_jax=True),
     40       (loss, diagnostics))

File d:\\code\\graphcast-0.1\\graphcast\\autoregressive.py:236, in Predictor.loss(self, inputs, targets, forcings, **kwargs)
    230 \"\"\"The mean of the per-timestep losses of the underlying predictor.\"\"\"
    231 if targets.sizes['time'] == 1:
    232   # If there is only a single target timestep then we don't need any
    233   # autoregressive feedback and can delegate the loss directly to the
    234   # underlying single-step predictor. This means the underlying predictor
    235   # doesn't need to implement .loss_and_predictions.
--> 236   return self._predictor.loss(inputs, targets, forcings, **kwargs)
    238 constant_inputs = self._get_and_validate_constant_inputs(
    239     inputs, targets, forcings)
    240 self._validate_targets_and_forcings(targets, forcings)

File d:\\code\\graphcast-0.1\\graphcast\\normalization.py:174, in InputsAndResiduals.loss(self, inputs, targets, forcings, **kwargs)
    170 norm_forcings = normalize(forcings, self._scales, self._locations)
    171 norm_target_residuals = xarray_tree.map_structure(
    172     lambda t: self._subtract_input_and_normalize_target(inputs, t),
    173     targets)
--> 174 return self._predictor.loss(
    175     norm_inputs, norm_target_residuals, forcings=norm_forcings, **kwargs)

File d:\\code\\graphcast-0.1\\graphcast\\casting.py:77, in Bfloat16Cast.loss(self, inputs, targets, forcings, **kwargs)
     74   return self._predictor.loss(inputs, targets, forcings, **kwargs)
     76 with bfloat16_variable_view():
---> 77   loss, scalars = self._predictor.loss(
     78       *_all_inputs_to_bfloat16(inputs, targets, forcings), **kwargs)
     80 if loss.dtype != jnp.bfloat16:
     81   raise ValueError(f'Expected bfloat16 loss, got {loss.dtype}')

File d:\\code\\graphcast-0.1\\graphcast\\graphcast.py:424, in GraphCast.loss(self, inputs, targets, forcings)
    418 def loss(  # pytype: disable=signature-mismatch  # jax-ndarray
    419     self,
    420     inputs: xarray.Dataset,
    421     targets: xarray.Dataset,
    422     forcings: xarray.Dataset,
    423     ) -> predictor_base.LossAndDiagnostics:
--> 424   loss, _ = self.loss_and_predictions(inputs, targets, forcings)
    425   return loss

File d:\\code\\graphcast-0.1\\graphcast\\graphcast.py:400, in GraphCast.loss_and_predictions(self, inputs, targets, forcings)
    397 predictions = self(
    398     inputs, targets_template=targets, forcings=forcings, is_training=True)
    399 # Compute loss.
--> 400 loss = losses.weighted_mse_per_level(
    401     predictions, targets,
    402     per_variable_weights={
    403         # Any variables not specified here are weighted as 1.0.
    404         # A single-level variable, but an important headline variable
    405         # and also one which we have struggled to get good performance
    406         # on at short lead times, so leaving it weighted at 1.0, equal
    407         # to the multi-level variables:
    408         \"2m_temperature\": 1.0,
    409         # New single-level variables, which we don't weight too highly
    410         # to avoid hurting performance on other variables.
    411         \"10m_u_component_of_wind\": 0.1,
    412         \"10m_v_component_of_wind\": 0.1,
    413         \"mean_sea_level_pressure\": 0.1,
    414         \"total_precipitation_6hr\": 0.1,
    415     })
    416 return loss, predictions

File d:\\code\\graphcast-0.1\\graphcast\\losses.py:69, in weighted_mse_per_level(predictions, targets, per_variable_weights)
     66     loss *= normalized_level_weights(target).astype(loss.dtype)
     67   return _mean_preserving_batch(loss)
---> 69 losses = xarray_tree.map_structure(loss, predictions, targets)
     70 return sum_per_variable_losses(losses, per_variable_weights)

File d:\\code\\graphcast-0.1\\graphcast\\xarray_tree.py:56, in map_structure(func, *structures)
     54 first = structures[0]
     55 if isinstance(first, xarray.Dataset):
---> 56   data = {k: func(*[s[k] for s in structures]) for k in first.keys()}
     57   if all(isinstance(a, (type(None), xarray.DataArray))
     58          for a in data.values()):
     59     data_arrays = [v.rename(k) for k, v in data.items() if v is not None]

File d:\\code\\graphcast-0.1\\graphcast\\xarray_tree.py:56, in <dictcomp>(.0)
     54 first = structures[0]
     55 if isinstance(first, xarray.Dataset):
---> 56   data = {k: func(*[s[k] for s in structures]) for k in first.keys()}
     57   if all(isinstance(a, (type(None), xarray.DataArray))
     58          for a in data.values()):
     59     data_arrays = [v.rename(k) for k, v in data.items() if v is not None]

File d:\\code\\graphcast-0.1\\graphcast\\losses.py:67, in weighted_mse_per_level.<locals>.loss(prediction, target)
     65 if 'level' in target.dims:
     66   loss *= normalized_level_weights(target).astype(loss.dtype)
---> 67 return _mean_preserving_batch(loss)

File d:\\code\\graphcast-0.1\\graphcast\\losses.py:74, in _mean_preserving_batch(x)
     73 def _mean_preserving_batch(x: xarray.DataArray) -> xarray.DataArray:
---> 74   return x.mean([d for d in x.dims if d != 'batch'], skipna=False)

File c:\\Users\\90566\\miniconda3\\envs\\graphcast\\lib\\site-packages\\xarray\\core\\_aggregations.py:1663, in DataArrayAggregations.mean(self, dim, skipna, keep_attrs, **kwargs)
   1588 def mean(
   1589     self,
   1590     dim: Dims = None,
   (...)
   1594     **kwargs: Any,
   1595 ) -> Self:
   1596     \"\"\"
   1597     Reduce this DataArray's data by applying ``mean`` along some dimension(s).
   1598 
   (...)
   1661     array(nan)
   1662     \"\"\"
-> 1663     return self.reduce(
   1664         duck_array_ops.mean,
   1665         dim=dim,
   1666         skipna=skipna,
   1667         keep_attrs=keep_attrs,
   1668         **kwargs,
   1669     )

File c:\\Users\\90566\\miniconda3\\envs\\graphcast\\lib\\site-packages\\xarray\\core\\dataarray.py:3760, in DataArray.reduce(self, func, dim, axis, keep_attrs, keepdims, **kwargs)
   3716 def reduce(
   3717     self,
   3718     func: Callable[..., Any],
   (...)
   3724     **kwargs: Any,
   3725 ) -> Self:
   3726     \"\"\"Reduce this array by applying `func` along some dimension(s).
   3727 
   3728     Parameters
   (...)
   3757         summarized data and the indicated dimension(s) removed.
   3758     \"\"\"
-> 3760     var = self.variable.reduce(func, dim, axis, keep_attrs, keepdims, **kwargs)
   3761     return self._replace_maybe_drop_dims(var)

File c:\\Users\\90566\\miniconda3\\envs\\graphcast\\lib\\site-packages\\xarray\\core\\variable.py:1756, in Variable.reduce(self, func, dim, axis, keep_attrs, keepdims, **kwargs)
   1749 keep_attrs_ = (
   1750     _get_keep_attrs(default=False) if keep_attrs is None else keep_attrs
   1751 )
   1753 # Noe that the call order for Variable.mean is
   1754 #    Variable.mean -> NamedArray.mean -> Variable.reduce
   1755 #    -> NamedArray.reduce
-> 1756 result = super().reduce(
   1757     func=func, dim=dim, axis=axis, keepdims=keepdims, **kwargs
   1758 )
   1760 # return Variable always to support IndexVariable
   1761 return Variable(
   1762     result.dims, result._data, attrs=result._attrs if keep_attrs_ else None
   1763 )

File c:\\Users\\90566\\miniconda3\\envs\\graphcast\\lib\\site-packages\\xarray\\namedarray\\core.py:789, in NamedArray.reduce(self, func, dim, axis, keepdims, **kwargs)
    784         dims = tuple(
    785             adim for n, adim in enumerate(self.dims) if n not in removed_axes
    786         )
    788 # Return NamedArray to handle IndexVariable when data is nD
--> 789 return from_array(dims, data, attrs=self._attrs)

File c:\\Users\\90566\\miniconda3\\envs\\graphcast\\lib\\site-packages\\xarray\\namedarray\\core.py:203, in from_array(dims, data, attrs)
    200     return NamedArray(dims, to_0d_object_array(data), attrs)
    202 # validate whether the data is valid data types.
--> 203 return NamedArray(dims, np.asarray(data), attrs)

File d:\\code\\graphcast-0.1\\graphcast\\xarray_jax.py:468, in JaxArrayWrapper.__array__(self, dtype, context)
    467 def __array__(self, dtype=None, context=None):
--> 468   return np.asarray(self.jax_array, dtype=dtype)

File c:\\Users\\90566\\miniconda3\\envs\\graphcast\\lib\\site-packages\\jax\\_src\\core.py:668, in Tracer.__array__(self, *args, **kw)
    667 def __array__(self, *args, **kw):
--> 668   raise TracerArrayConversionError(self)

TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape bfloat16[1].
The error occurred while tracing the function apply_fn at c:\\Users\\90566\\miniconda3\\envs\\graphcast\\lib\\site-packages\\haiku\\_src\\transform.py:440 for jit. This value became a tracer due to JAX operations on these lines:

  operation a:f32[1,2,181,360] = sub b c
    from line d:\\code\\graphcast-0.1\\graphcast\\xarray_jax.py:353 (wrapped_func)

  operation a:f32[1,2,181,360] = div b c
    from line d:\\code\\graphcast-0.1\\graphcast\\xarray_jax.py:353 (wrapped_func)

  operation a:f32[1,2,181,360] = sub b c
    from line d:\\code\\graphcast-0.1\\graphcast\\xarray_jax.py:353 (wrapped_func)

  operation a:f32[1,2,181,360] = div b c
    from line d:\\code\\graphcast-0.1\\graphcast\\xarray_jax.py:353 (wrapped_func)

  operation a:f32[1,2,181,360] = sub b c
    from line d:\\code\\graphcast-0.1\\graphcast\\xarray_jax.py:353 (wrapped_func)

(Additional originating lines are not shown.)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError"
}

mjwillson commented 5 months ago

Hello, thanks for the report. This should be fixed by https://github.com/google-deepmind/graphcast/commit/8debd7289bb2c498485f79dbd98d8b4933bfc6a7. If you're able to verify, that'd be great, as the problem occurs with a newer version of xarray than we use internally.
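
For anyone landing here with the same trace: the last frames show xarray calling `np.asarray` on the wrapped JAX value while `jit` is still tracing, and that call is what raises. The snippet below is only a minimal sketch of that class of error outside GraphCast, not the code from the fix commit; the function names `mean_via_numpy` and `mean_via_jax` are hypothetical and chosen for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np


def mean_via_numpy(x):
    # Hypothetical helper: np.asarray calls __array__() on the traced value,
    # which JAX refuses while the function is being traced by jit.
    return np.asarray(x).mean()


def mean_via_jax(x):
    # Hypothetical helper: staying inside jax.numpy keeps the value traceable.
    return jnp.mean(x)


x = jnp.arange(4.0)

try:
    jax.jit(mean_via_numpy)(x)
    raised = False
except jax.errors.TracerArrayConversionError:
    raised = True

result = float(jax.jit(mean_via_jax)(x))

print("NumPy conversion under jit raised:", raised)
print("jnp.mean under jit:", result)
```

The same distinction applies to any library code that runs between your function and JAX: if it converts its data with NumPy during tracing, the error appears even though your own code never calls NumPy directly.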

ibrahimkaya754 commented 5 months ago

Hello, I was getting the same error, but it is fixed now. Thanks for the support.

mjwillson commented 5 months ago

Thanks. OK, I will close this, but do re-open it if the problem shows up again.