Open tylerflex opened 1 year ago
Hi - I think I recall this kind of thing coming up before. I don't know of any effort to do the full integration of JAX and xarray that you have in mind. The problem is that xarray is fundamentally built on the assumption that its arrays are numpy arrays, so, for example, np.asarray is used frequently in its implementation. For traced operations in JAX, np.asarray will raise a TracerArrayConversionError, because traced values cannot be converted to NumPy arrays.
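The failure mode described above can be reproduced without xarray at all. The following is a minimal sketch, not code from this thread: the function names `uses_numpy`, `uses_jnp`, and `grad_hits_conversion_error` are made up for illustration. It shows that calling np.asarray on a value being traced by jax.grad raises the error, while the jnp equivalent does not.

```python
import numpy as np
import jax
import jax.numpy as jnp


def uses_numpy(x):
    # np.asarray calls x.__array__(), which a JAX tracer refuses to do.
    return jnp.sum(jnp.asarray(np.asarray(x)) ** 2)


def uses_jnp(x):
    # jnp.asarray keeps the value inside JAX, so tracing works.
    return jnp.sum(jnp.asarray(x) ** 2)


def grad_hits_conversion_error(fn, x):
    """Return True if differentiating fn raises TracerArrayConversionError."""
    try:
        jax.grad(fn)(x)
    except jax.errors.TracerArrayConversionError:
        return True
    return False


failed_with_numpy = grad_hits_conversion_error(uses_numpy, 2.0)
failed_with_jnp = grad_hits_conversion_error(uses_jnp, 2.0)
grad_value = float(jax.grad(uses_jnp)(2.0))  # d/dx of x**2 at x = 2
```

Any library that converts its inputs with np.asarray internally would be expected to hit the same wall when handed a traced value.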
To move forward, one of two things would have to happen:
- xarray would have to loosen its assumptions about the internal array representation. One way this could happen is if xarray adopted the Python Array API standard. It looks like there is some thought about this, but it would be a very big project.
- Someone would have to write an xarray-like wrapper for JAX. I don't know of any projects like this (though I wouldn't be surprised if folks have experimented with it), but it would also be a very big project.

Short of a team of people undertaking one of those very big projects, I don't think there's any good way to do what you have in mind.
I think this would be quite exciting!
I think the Python Array API standard would probably be the way to go. Xarray's support for the API standard is pretty close to complete, and most missing features would not be hard to add. Xarray in fact already supports wrapping many types of non-NumPy arrays, so supporting JAX arrays as well would not be a big lift.
To get Xarray objects working with JAX transforms like jax.grad, they need to be registered with tree_util. But I think that is also straightforward.
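As a rough illustration of what registering a container with tree_util involves, here is a minimal sketch using a hypothetical `LabeledArray` class as a stand-in for a DataArray. It is not how Xarray or the GraphCast wrapper does it; it only shows the general pattern with jax.tree_util.register_pytree_node_class, where the array data is a traced child and the dimension names are static auxiliary data.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class LabeledArray:
    """Hypothetical stand-in for a DataArray: one data array plus dimension names."""

    def __init__(self, data, dims):
        self.data = data
        self.dims = tuple(dims)

    def tree_flatten(self):
        # Children are traced by JAX; aux data must be static and hashable.
        return (self.data,), self.dims

    @classmethod
    def tree_unflatten(cls, dims, children):
        (data,) = children
        return cls(data, dims)


def loss(arr):
    # Rebuild the container around the new data, then reduce to a scalar.
    scaled = LabeledArray(2.0 * arr.data, arr.dims)
    return jnp.sum(scaled.data ** 2)


arr = LabeledArray(jnp.arange(3.0), dims=("x",))
grads = jax.grad(loss)(arr)  # gradient comes back as a LabeledArray
```

Because the container is a pytree, jax.grad returns the gradient in the same structure as the input, with the dimension names carried along as auxiliary data.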
DeepMind's GraphCast project contains a bundled Xarray-JAX wrapper, which I think already does some version of both of these (maybe in a non-ideal way): https://github.com/deepmind/graphcast/blob/main/graphcast/xarray_jax.py
(Side note: for the Array API approach, we'd also have to land some version of #16099 to make JAX compliant)
CC @mjwillson who wrote the Xarray-JAX wrapper in GraphCast.
Thanks @shoyer ! I'll have to study that graph cast code, I tried something similar but never could get it working properly.
I played around a bit with this GraphCast wrapper. It worked for the intended use case of applying @jax.jit to functions mapping from DataArray -> DataArray.
Unfortunately, jax.grad() still seems to give a TracerArrayConversionError. It seems like it might be occurring in the VJP function for multiplying the JAX-traced scalar by the DataArray.
It's pretty likely I'm doing something wrong here, so if @mjwillson / @shoyer spot a mistake, let me know!
import numpy as np
import jax
import jax.numpy as jnp
import xarray as xr
shape = (3, 4, 5)
values = np.random.random(shape)
coords = {dim: np.arange(length).tolist() for dim, length in zip('xyz', shape)}
xarr_jax = DataArray(values, dims=('x', 'y', 'z'), coords=coords) # note: GraphCast wrapper class
def f(x):
    val = x * xarr_jax
    val = val.interp(x=1, y=1, z=1)
    val = jnp.array(val.data)
    return jnp.sum(val)
f(1.0) # works
jax.grad(f)(1.0) # TracerArrayConversionError
---------------------------------------------------------------------------
TracerArrayConversionError Traceback (most recent call last)
Cell In[16], line 18
15 return jnp.sum(val)
17 f(1.0) # works
---> 18 jax.grad(f)(1.0) # TracerArrayConversionError
[... skipping hidden 10 frame]
Cell In[16], line 12, in f(x)
11 def f(x):
---> 12 val = x * xarr_jax
13 val = val.interp(x=1, y=1, z=1)
14 val = jnp.array(val.data)
File ~/.pyenv/versions/3.10.9/lib/python3.10/site-packages/xarray/core/_typed_ops.py:282, in DataArrayOpsMixin.__rmul__(self, other)
281 def __rmul__(self, other):
--> 282 return self._binary_op(other, operator.mul, reflexive=True)
File ~/.pyenv/versions/3.10.9/lib/python3.10/site-packages/xarray/core/dataarray.py:4622, in DataArray._binary_op(self, other, f, reflexive)
4616 other_variable = getattr(other, "variable", other)
4617 other_coords = getattr(other, "coords", None)
4619 variable = (
4620 f(self.variable, other_variable)
4621 if not reflexive
-> 4622 else f(other_variable, self.variable)
4623 )
4624 coords, indexes = self.coords._merge_raw(other_coords, reflexive)
4625 name = self._result_name(other)
File ~/.pyenv/versions/3.10.9/lib/python3.10/site-packages/xarray/core/_typed_ops.py:488, in VariableOpsMixin.__rmul__(self, other)
487 def __rmul__(self, other):
--> 488 return self._binary_op(other, operator.mul, reflexive=True)
File ~/.pyenv/versions/3.10.9/lib/python3.10/site-packages/xarray/core/variable.py:2707, in Variable._binary_op(self, other, f, reflexive)
2703 with np.errstate(all="ignore"):
2704 new_data = (
2705 f(self_data, other_data) if not reflexive else f(other_data, self_data)
2706 )
-> 2707 result = Variable(dims, new_data, attrs=attrs)
2708 return result
File ~/.pyenv/versions/3.10.9/lib/python3.10/site-packages/xarray/core/variable.py:366, in Variable.__init__(self, dims, data, attrs, encoding, fastpath)
346 def __init__(self, dims, data, attrs=None, encoding=None, fastpath=False):
347 """
348 Parameters
349 ----------
(...)
364 unrecognized encoding items.
365 """
--> 366 self._data = as_compatible_data(data, fastpath=fastpath)
367 self._dims = self._parse_dimensions(dims)
368 self._attrs = None
File ~/.pyenv/versions/3.10.9/lib/python3.10/site-packages/xarray/core/variable.py:293, in as_compatible_data(data, fastpath)
290 return data
292 # validate whether the data is valid data types.
--> 293 data = np.asarray(data)
295 if isinstance(data, np.ndarray) and data.dtype.kind in "OMm":
296 data = _possibly_convert_objects(data)
File ~/.pyenv/versions/3.10.9/lib/python3.10/site-packages/jax/_src/core.py:605, in Tracer.__array__(self, *args, **kw)
604 def __array__(self, *args, **kw):
--> 605 raise TracerArrayConversionError(self)
TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float32[3,4,5].
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
The error is happening because the gradient computation results in calling xarray's __rmul__, which attempts to cast the inputs to numpy arrays, because xarray is built on the assumption that all its buffers are numpy arrays. Casting a traced JAX array to a numpy array results in a TracerArrayConversionError.
There's no way to fix this without changing how xarray is implemented.
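Short of that, one possible workaround is to keep the labelled container out of the traced function entirely: resolve labels to integer positions up front, then differentiate through plain jnp operations. The sketch below is an assumption-laden simplification of the example in this thread, not a drop-in replacement. It uses plain Python lists for the coordinates instead of xarray, and it replaces interp with an exact lookup, which only matches when the requested point lies on the grid.

```python
import numpy as np
import jax
import jax.numpy as jnp

shape = (3, 4, 5)
values = np.random.default_rng(0).random(shape)
# Labels kept as plain Python data, outside anything JAX traces.
coords = {dim: list(range(n)) for dim, n in zip("xyz", shape)}

# Resolve the label-based lookup to integer positions before tracing.
point = {"x": 1, "y": 1, "z": 1}
index = tuple(coords[dim].index(point[dim]) for dim in "xyz")

data = jnp.asarray(values)


def f(x):
    val = x * data  # pure jnp arithmetic, safe to trace
    return jnp.sum(val[index])


grad_value = jax.grad(f)(1.0)  # equals the data value at the selected point
```

The trade-off is that the label handling has to be done by hand, which is exactly the maintenance burden that native support would remove.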
xarray is built on the assumption that all its buffers are numpy arrays
This isn't true -- xarray supports a number of duck arrays. As soon as JAX implements __array_namespace__ from the array API, you'll be able to wrap JAX arrays directly into xarray objects.
If you use the GraphCast Xarray-JAX wrapper, you need to use its special constructors for DataArray/Dataset.
Oh, good to know! Progress on __array_namespace__ is in #16099, though it's been hampered by the fact that JAX arrays are immutable, and some corners of the Python array API and its primary testing framework assume mutability (hopefully xarray doesn't depend on any of these mutation APIs).
If you use the GraphCast Xarray-JAX wrapper, you need to use its special constructors for DataArray/Dataset.
Could you explain a bit more?
DataArray.__rmul__?
@jakevdp Indeed, Xarray doesn't rely on the mutation APIs (unless a user tries to mutate an array).
@tylerflex I see, it looks like you were already using the GraphCast wrapper. I don't know exactly what's going on, then.
Hiya,
Firstly just to note that xarray_jax isn't something we're officially supporting outside the GraphCast project for now, as it does have some rough edges and is in part a bit of a stop-gap measure until JAX supports the new array protocol which will allow it to integrate better with xarray.
That said, about your example, you'll find the following very similar code works:
import numpy as np
import jax
import jax.numpy as jnp
import xarray as xr
from graphcast import xarray_jax
shape = (3, 4, 5)
values = jnp.asarray(np.random.random(shape))
coords = {dim: np.arange(length).tolist() for dim, length in zip('xyz', shape)}
xarr_jax = xarray_jax.DataArray(values, dims=('x', 'y', 'z'), coords=coords)
def f(x):
    val = x * xarr_jax
    val = xarray_jax.unwrap_data(val)
    return jnp.sum(val)
f(1.0)
jax.jit(f)(1.0)
jax.grad(f)(1.0)
Some issues in your code were:
I've been wondering if there has been any recent progress in integrating jax and xarray, specifically for automatic differentiation. For context, we have a simulation project that relies on xarray for our simulation output data but recently added jax support so users can automatically differentiate through these simulations. To make this work, we added code to emulate xr.DataArray functionality but with jax internals. However, this approach has been a headache to maintain and extend. It would be amazing if xarray had native support for gradient tracking in jax.
As an example, the code snippet below multiplies a jax-traced value by an xarray.DataArray, does an interpolation, and then a jax-traced operation. It would be great if we could differentiate through this. The forward pass works, but the backward pass gives a TracerArrayConversionError.
I've tried many other workarounds based on issues such as this and some other discussions, but without any luck. Are there any updates on the status of this, whether it would be possible eventually, or suggestions for possible workarounds? Any discussion or pointers towards a good approach to this are really appreciated.
@shoyer