jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

`jax` and `xarray` integration for automatic differentiation? #17107

Open tylerflex opened 1 year ago

tylerflex commented 1 year ago

I've been wondering if there has been any recent progress in integrating jax and xarray, specifically for automatic differentiation. For context, we have a simulation project that relies on xarray for our simulation output data but recently added jax support so users can automatically differentiate through these simulations. To make this work, we added code to emulate xr.DataArray functionality but with jax internals. However, this approach has been a headache to maintain and extend. It would be amazing if xarray had native support for gradient tracking in jax.

As an example, the code snippet below multiplies a JAX-traced value by an xarray.DataArray, does an interpolation, and then applies a JAX-traced operation. It would be great if we could differentiate through this. The forward pass works, but the backward pass raises a TracerArrayConversionError.

I've tried many other workarounds based on related issues and discussions, but without any luck. Are there any updates on the status of this, whether it will eventually be possible, or suggestions for possible workarounds? Any discussion or pointers towards a good approach would be really appreciated.

@shoyer

import numpy as np
import jax
import jax.numpy as jnp
import xarray as xr

shape = (3, 4, 5)
values = np.random.random(shape)
coords = {dim: np.arange(length) for dim, length in zip('xyz', shape)}
xarr = xr.DataArray(values, coords=coords)

def f(x):
    xarr_multiplied = x * xarr
    val = xarr_multiplied.interp(x=1, y=1, z=1)
    return jnp.sqrt(val.values)

f(1.0)
jax.grad(f)(1.0)
---------------------------------------------------------------------------
TracerArrayConversionError                Traceback (most recent call last)
Cell In[30], line 17
     14     return jnp.sqrt(val.values)
     16 f(1.0)
---> 17 jax.grad(f)(1.0)

    [... skipping hidden 10 frame]

Cell In[30], line 12, in f(x)
     11 def f(x):
---> 12     xarr_multiplied = x * xarr
     13     val = xarr_multiplied.interp(x=1, y=1, z=1)
     14     return jnp.sqrt(val.values)

File ~/.pyenv/versions/3.10.9/lib/python3.10/site-packages/xarray/core/_typed_ops.py:282, in DataArrayOpsMixin.__rmul__(self, other)
    281 def __rmul__(self, other):
--> 282     return self._binary_op(other, operator.mul, reflexive=True)

File ~/.pyenv/versions/3.10.9/lib/python3.10/site-packages/xarray/core/dataarray.py:4622, in DataArray._binary_op(self, other, f, reflexive)
   4616 other_variable = getattr(other, "variable", other)
   4617 other_coords = getattr(other, "coords", None)
   4619 variable = (
   4620     f(self.variable, other_variable)
   4621     if not reflexive
-> 4622     else f(other_variable, self.variable)
   4623 )
   4624 coords, indexes = self.coords._merge_raw(other_coords, reflexive)
   4625 name = self._result_name(other)

File ~/.pyenv/versions/3.10.9/lib/python3.10/site-packages/xarray/core/_typed_ops.py:488, in VariableOpsMixin.__rmul__(self, other)
    487 def __rmul__(self, other):
--> 488     return self._binary_op(other, operator.mul, reflexive=True)

File ~/.pyenv/versions/3.10.9/lib/python3.10/site-packages/xarray/core/variable.py:2707, in Variable._binary_op(self, other, f, reflexive)
   2703 with np.errstate(all="ignore"):
   2704     new_data = (
   2705         f(self_data, other_data) if not reflexive else f(other_data, self_data)
   2706     )
-> 2707 result = Variable(dims, new_data, attrs=attrs)
   2708 return result

File ~/.pyenv/versions/3.10.9/lib/python3.10/site-packages/xarray/core/variable.py:366, in Variable.__init__(self, dims, data, attrs, encoding, fastpath)
    346 def __init__(self, dims, data, attrs=None, encoding=None, fastpath=False):
    347     """
    348     Parameters
    349     ----------
   (...)
    364         unrecognized encoding items.
    365     """
--> 366     self._data = as_compatible_data(data, fastpath=fastpath)
    367     self._dims = self._parse_dimensions(dims)
    368     self._attrs = None

File ~/.pyenv/versions/3.10.9/lib/python3.10/site-packages/xarray/core/variable.py:293, in as_compatible_data(data, fastpath)
    290     return data
    292 # validate whether the data is valid data types.
--> 293 data = np.asarray(data)
    295 if isinstance(data, np.ndarray) and data.dtype.kind in "OMm":
    296     data = _possibly_convert_objects(data)

File ~/.pyenv/versions/3.10.9/lib/python3.10/site-packages/jax/_src/core.py:605, in Tracer.__array__(self, *args, **kw)
    604 def __array__(self, *args, **kw):
--> 605   raise TracerArrayConversionError(self)

TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float32[3,4,5].
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

jakevdp commented 1 year ago

Hi, I think I recall this kind of thing coming up before. I don't know of any effort to do the full integration of JAX and xarray that you have in mind. The problem is that xarray is fundamentally built on the assumption that its arrays are NumPy arrays, so, for example, np.asarray is used frequently in its implementation. For traced operations in JAX, np.asarray will raise a TracerArrayConversionError, because traced values cannot be converted to NumPy arrays.
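The failure mode described above can be reproduced without xarray at all. The sketch below is only an illustration; the function names `via_numpy` and `via_jax` are made up for this example. It shows that calling np.asarray on a value works for a concrete array but fails once jax.grad replaces that value with a tracer, while the same computation written with jax.numpy differentiates fine.

```python
import numpy as np
import jax
import jax.numpy as jnp

def via_numpy(x):
    # np.asarray calls x.__array__(); a traced value refuses this conversion.
    return np.asarray(x).sum()

def via_jax(x):
    # Staying inside jax.numpy keeps the computation traceable.
    return jnp.sum(x ** 2)

x = jnp.arange(3.0)
forward = via_numpy(x)  # works: x is a concrete array here

try:
    jax.grad(via_numpy)(x)  # x is a tracer inside grad
    conversion_failed = False
except jax.errors.TracerArrayConversionError:
    conversion_failed = True

grad_ok = jax.grad(via_jax)(x)  # d/dx sum(x**2) = 2 * x
```

Any library code that calls np.asarray on its inputs hits the same error as soon as one of those inputs is traced, which is what the traceback above shows happening inside xarray's Variable constructor.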

To move forward, one of two things would have to happen:

  1. xarray would have to loosen its assumptions about the internal array representation. One way this could happen is if xarray adopted the Python Array API standard. It looks like there is some thought about this, but it would be a very big project.
  2. Somebody could write an entirely new xarray-like wrapper for JAX. I don't know of any projects like this (though I wouldn't be surprised if folks have experimented with it), but it would also be a very big project.

Short of a team of people undertaking one of those very big projects, I don't think there's any good way to do what you have in mind.

shoyer commented 1 year ago

I think this would be quite exciting!

I think the Python Array API standard would probably be the way to go. Xarray's support for the API standard is pretty close to complete, and most missing features would not be hard to add. Xarray in fact already supports wrapping many types of non-NumPy arrays, so supporting JAX arrays as well would not be a big lift.

To get Xarray objects working with JAX transforms like jax.grad, they need to be registered with tree_util. But I think that is also straightforward.
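As a rough illustration of what registering a container with tree_util involves, here is a minimal sketch. `LabeledArray` is a hypothetical stand-in invented for this example, not xarray's or GraphCast's API; a real integration would also need to handle coordinates, attributes, and Datasets.

```python
import jax
import jax.numpy as jnp
from jax import tree_util

class LabeledArray:
    """Hypothetical xarray-like container: array data plus dimension names."""
    def __init__(self, data, dims):
        # No validation here: during tracing, `data` may be a tracer.
        self.data = data
        self.dims = tuple(dims)

def _flatten(la):
    # Children are the leaves JAX should trace; aux data must be hashable.
    return (la.data,), la.dims

def _unflatten(dims, children):
    return LabeledArray(children[0], dims)

tree_util.register_pytree_node(LabeledArray, _flatten, _unflatten)

def loss(la):
    return jnp.sum(la.data ** 2)

la = LabeledArray(jnp.arange(3.0), dims=("x",))
g = jax.grad(loss)(la)  # gradient comes back as a LabeledArray
```

Once the container is registered, jax.grad and jax.jit treat `data` as a leaf and carry `dims` along as static metadata, so the gradient has the same structure as the input.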

DeepMind's GraphCast project contains a bundled Xarray-JAX wrapper, which I think already does some version of both of these (maybe in a non-ideal way): https://github.com/deepmind/graphcast/blob/main/graphcast/xarray_jax.py

jakevdp commented 1 year ago

(Side note: for the Array API approach, we'd also have to land some version of #16099 to make JAX compliant)

shoyer commented 1 year ago

CC @mjwillson who wrote the Xarray-JAX wrapper in GraphCast.

tylerflex commented 1 year ago

Thanks @shoyer! I'll have to study that GraphCast code. I tried something similar but never could get it working properly.

tylerflex commented 1 year ago

I played around a bit with this GraphCast wrapper. It worked for the intended use case of applying @jax.jit to functions mapping from DataArray -> DataArray.

Unfortunately, jax.grad() still seems to give a TracerArrayConversionError. It seems like it might be occurring in the VJP function for multiplying the JAX-traced scalar by the DataArray.

It's pretty likely I'm doing something wrong here, so if @mjwillson / @shoyer spot something wrong, let me know!

import numpy as np
import jax
import jax.numpy as jnp
import xarray as xr

shape = (3, 4, 5)
values = np.random.random(shape)
coords = {dim: np.arange(length).tolist() for dim, length in zip('xyz', shape)}
xarr_jax = DataArray(values, dims=('x', 'y', 'z'), coords=coords) # note: GraphCast wrapper class

def f(x):
    val = x * xarr_jax
    val = val.interp(x=1, y=1, z=1)
    val = jnp.array(val.data)
    return jnp.sum(val)

f(1.0) # works
jax.grad(f)(1.0) # TracerArrayConversionError
---------------------------------------------------------------------------
TracerArrayConversionError                Traceback (most recent call last)
Cell In[16], line 18
     15     return jnp.sum(val)
     17 f(1.0) # works
---> 18 jax.grad(f)(1.0) # TracerArrayConversionError

    [... skipping hidden 10 frame]

Cell In[16], line 12, in f(x)
     11 def f(x):
---> 12     val = x * xarr_jax
     13     val = val.interp(x=1, y=1, z=1)
     14     val = jnp.array(val.data)

File ~/.pyenv/versions/3.10.9/lib/python3.10/site-packages/xarray/core/_typed_ops.py:282, in DataArrayOpsMixin.__rmul__(self, other)
    281 def __rmul__(self, other):
--> 282     return self._binary_op(other, operator.mul, reflexive=True)

File ~/.pyenv/versions/3.10.9/lib/python3.10/site-packages/xarray/core/dataarray.py:4622, in DataArray._binary_op(self, other, f, reflexive)
   4616 other_variable = getattr(other, "variable", other)
   4617 other_coords = getattr(other, "coords", None)
   4619 variable = (
   4620     f(self.variable, other_variable)
   4621     if not reflexive
-> 4622     else f(other_variable, self.variable)
   4623 )
   4624 coords, indexes = self.coords._merge_raw(other_coords, reflexive)
   4625 name = self._result_name(other)

File ~/.pyenv/versions/3.10.9/lib/python3.10/site-packages/xarray/core/_typed_ops.py:488, in VariableOpsMixin.__rmul__(self, other)
    487 def __rmul__(self, other):
--> 488     return self._binary_op(other, operator.mul, reflexive=True)

File ~/.pyenv/versions/3.10.9/lib/python3.10/site-packages/xarray/core/variable.py:2707, in Variable._binary_op(self, other, f, reflexive)
   2703 with np.errstate(all="ignore"):
   2704     new_data = (
   2705         f(self_data, other_data) if not reflexive else f(other_data, self_data)
   2706     )
-> 2707 result = Variable(dims, new_data, attrs=attrs)
   2708 return result

File ~/.pyenv/versions/3.10.9/lib/python3.10/site-packages/xarray/core/variable.py:366, in Variable.__init__(self, dims, data, attrs, encoding, fastpath)
    346 def __init__(self, dims, data, attrs=None, encoding=None, fastpath=False):
    347     """
    348     Parameters
    349     ----------
   (...)
    364         unrecognized encoding items.
    365     """
--> 366     self._data = as_compatible_data(data, fastpath=fastpath)
    367     self._dims = self._parse_dimensions(dims)
    368     self._attrs = None

File ~/.pyenv/versions/3.10.9/lib/python3.10/site-packages/xarray/core/variable.py:293, in as_compatible_data(data, fastpath)
    290     return data
    292 # validate whether the data is valid data types.
--> 293 data = np.asarray(data)
    295 if isinstance(data, np.ndarray) and data.dtype.kind in "OMm":
    296     data = _possibly_convert_objects(data)

File ~/.pyenv/versions/3.10.9/lib/python3.10/site-packages/jax/_src/core.py:605, in Tracer.__array__(self, *args, **kw)
    604 def __array__(self, *args, **kw):
--> 605   raise TracerArrayConversionError(self)

TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float32[3,4,5].
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

jakevdp commented 1 year ago

The error is happening because the gradient computation results in calling xarray's __rmul__, which attempts to cast the inputs to NumPy arrays, because xarray is built on the assumption that all its buffers are NumPy arrays. Casting a traced JAX array to a NumPy array results in a TracerArrayConversionError.

There's no way to fix this without changing how xarray is implemented.

shoyer commented 1 year ago

> xarray is built on the assumption that all its buffers are numpy arrays

This isn't true -- xarray supports a number of duck arrays. As soon as JAX implements __array_namespace__ from the array API, you'll be able to wrap JAX arrays directly into xarray objects.
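For readers unfamiliar with the protocol, the sketch below shows roughly how array-API-aware code uses __array_namespace__ to stay agnostic about the array type. The helpers `get_namespace` and `normalize` are hypothetical names for this example, and the NumPy fallback is an assumption made so that the snippet also runs on array types that do not implement the method.

```python
import numpy as np

def get_namespace(x):
    # Array API objects expose their function namespace via this method.
    if hasattr(x, "__array_namespace__"):
        return x.__array_namespace__()
    # Assumption for this sketch: fall back to NumPy for other array types.
    return np

def normalize(x):
    # Only functions looked up on `xp` are used, never np.* directly.
    xp = get_namespace(x)
    return x / xp.sqrt(xp.sum(x * x))

unit = normalize(np.array([3.0, 4.0]))
```

The point of the design is that no call forces a conversion to a NumPy array, so an array type that supplies its own namespace can pass through such code unchanged.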

If you use the GraphCast Xarray-JAX wrapper, you need to use its special constructors for DataArray/Dataset.

jakevdp commented 1 year ago

Oh, good to know! Progress on __array_namespace__ is in #16099, though it's been hampered by the fact that JAX arrays are immutable, and some corners of the Python array API and its primary testing framework assume mutability (hopefully xarray doesn't depend on any of these mutation APIs).
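To make the immutability point concrete, here is a small sketch of how in-place item assignment behaves on a JAX array, next to the functional update JAX provides instead.

```python
import jax.numpy as jnp

x = jnp.zeros(3)

try:
    x[0] = 1.0  # in-place item assignment is not supported on JAX arrays
    mutation_failed = False
except TypeError:
    mutation_failed = True

# Functional update: returns a new array and leaves `x` untouched.
y = x.at[0].set(1.0)
```

Code that relies on `x[i] = value` therefore needs a different path for JAX arrays, which is why the mutation-dependent corners of the array API are the sticking point mentioned above.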

tylerflex commented 1 year ago

> If you use the GraphCast Xarray-JAX wrapper, you need to use its special constructors for DataArray/Dataset.

Could you explain a bit more?

shoyer commented 1 year ago

@jakevdp Indeed, Xarray doesn't rely on the mutation APIs (unless a user tries to mutate an array).

@tylerflex I see, it looks like you were already using the GraphCast wrapper. I don't know exactly what's going on, then.

mjwillson commented 1 year ago

Hiya,

Firstly, just to note that xarray_jax isn't something we're officially supporting outside the GraphCast project for now. It does have some rough edges, and it is in part a stop-gap measure until JAX supports the new array protocol, which will allow it to integrate better with xarray.

That said, about your example, you'll find the following very similar code works:

import numpy as np
import jax
import jax.numpy as jnp
import xarray as xr
from graphcast import xarray_jax

shape = (3, 4, 5)
values = jnp.asarray(np.random.random(shape))
coords = {dim: np.arange(length).tolist() for dim, length in zip('xyz', shape)}
xarr_jax = xarray_jax.DataArray(values, dims=('x', 'y', 'z'), coords=coords)

def f(x):
    val = x * xarr_jax
    val = xarray_jax.unwrap_data(val)
    return jnp.sum(val)

f(1.0)
jax.jit(f)(1.0)
jax.grad(f)(1.0)

Some issues in your code were: