shoyer opened this issue 4 years ago
Adherence to NEP 13 and NEP 18 would make it easier to integrate JAX into projects that rely on them for portability. Specifically, we're looking to integrate JAX with scale-out systems like Dask and with particle physics libraries like https://github.com/scikit-hep/awkward-array. @jpivarski can probably comment better on the technical details, but we'd very much be passionate users :)
I love the idea of xarray with JAX as the backend... That would be so awesome! Also, it's quite unfortunate that TensorFlow/JAX/... all have APIs that differ from NumPy's.
An example:
```python
import numpy as np
from scipy import stats
N = lambda x: stats.norm.cdf(x)
def test(a, b):
    return N((b - a) / np.sqrt(a))
```
Jake's function (in the issue mentioned above), being meant only for illustrative purposes, allows me to `@jaxify` only the `test` function. That function calls `N`, which does not use `jax.scipy.stats`, so I get an error if I try to compute the `grad`.
Would it be possible to override all the `numpy` and `scipy` calls inside the function I want to differentiate, including those in every other function it calls?
In the context of a large software effort for the LHC (http://iris-hep.org), we are discussing this, as @lukasheinrich mentioned above. We have jagged arrays, and we have been able to override ufuncs so that NumPy runs over our data structures. We would like to be able to do the same with JAX.
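For readers unfamiliar with NEP 13, here is a minimal sketch of the ufunc override mechanism. `Jagged` is a hypothetical class (a plain list of 1-D NumPy arrays, far simpler than Awkward Array's real implementation); because it defines `__array_ufunc__`, NumPy hands calls like `np.power` to it, and it applies the ufunc row by row:

```python
import numpy as np

class Jagged:
    """Hypothetical minimal jagged array: a list of 1-D NumPy arrays."""

    def __init__(self, rows):
        self.rows = [np.asarray(r) for r in rows]

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        # Only plain calls like np.power(a, 2); not reduce, accumulate, etc.
        if method != "__call__":
            return NotImplemented
        out = []
        for i in range(len(self.rows)):
            # Unwrap Jagged inputs to their i-th row; pass scalars through
            args = [x.rows[i] if isinstance(x, Jagged) else x for x in inputs]
            out.append(ufunc(*args, **kwargs))
        return Jagged(out)

    def tolist(self):
        return [r.tolist() for r in self.rows]

a = Jagged([[1, 2, 3], [], [4, 5]])
print(np.power(a, 2).tolist())  # [[1, 4, 9], [], [16, 25]]
```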
As a minimal example, this should work:

```shell
pip install jax jaxlib numpy awkward
python
```

```python
>>> import awkward1
>>> import numpy as np
>>> import jax.numpy as jnp
>>> a = awkward1.from_iter([[1,2,3],[],[4,5]])
>>> np.power(a,2)
<Array [[1, 4, 9], [], [16, 25]] type='3 * var * int64'>
>>> jnp.power(a,2)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/lukasheinrich/Code/analysis/rpv1l-downstram/venv/lib/python3.7/site-packages/jax/numpy/lax_numpy.py", line 532, in power
    return lax.integer_pow(x1, x2)
  File "/Users/lukasheinrich/Code/analysis/rpv1l-downstram/venv/lib/python3.7/site-packages/jax/lax/lax.py", line 265, in integer_pow
    return integer_pow_p.bind(x, y=y)
  File "/Users/lukasheinrich/Code/analysis/rpv1l-downstram/venv/lib/python3.7/site-packages/jax/core.py", line 211, in bind
    return self.impl(*args, **kwargs)
  File "/Users/lukasheinrich/Code/analysis/rpv1l-downstram/venv/lib/python3.7/site-packages/jax/interpreters/xla.py", line 217, in apply_primitive
    compiled_fun = xla_primitive_callable(prim, *map(arg_spec, args), **params)
  File "/Users/lukasheinrich/Code/analysis/rpv1l-downstram/venv/lib/python3.7/site-packages/jax/interpreters/xla.py", line 209, in arg_spec
    aval = abstractify(x)
  File "/Users/lukasheinrich/Code/analysis/rpv1l-downstram/venv/lib/python3.7/site-packages/jax/interpreters/xla.py", line 159, in abstractify
    raise TypeError(f"No abstraction handler for type: {type(x)}")
TypeError: No abstraction handler for type: <class 'awkward1.highlevel.Array'>
```
The error message suggests that there are pluggable "abstraction handlers". If there is a well-defined protocol, we could maybe implement one for `awkward1.highlevel.Array` arrays.
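Until such a protocol exists, one workaround sidesteps the question rather than answering it. Awkward Array's variable-length lists are stored roughly as a flat content buffer plus an offsets array, so an elementwise op can run on the flat buffer with `jnp` and the offsets re-split the result. A minimal sketch, assuming we unpack the jagged data ourselves (the variable names are illustrative, not Awkward's API):

```python
import numpy as np
import jax.numpy as jnp

# Hypothetical manual layout: flat content buffer plus row offsets
rows = [[1, 2, 3], [], [4, 5]]
offsets = np.cumsum([0] + [len(r) for r in rows])  # [0, 3, 3, 5]
content = jnp.asarray([x for r in rows for x in r], dtype=jnp.int32)

# Elementwise ops only touch the flat buffer, which JAX can handle
squared = jnp.power(content, 2)

# Re-split the result using the unchanged offsets
result = [np.asarray(squared[offsets[i]:offsets[i + 1]]).tolist()
          for i in range(len(rows))]
print(result)  # [[1, 4, 9], [], [16, 25]]
```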
If this automated, or at least simplified, porting scikit to JAX, it would be huge!
Edit: never mind this comment! I updated JAX and found that `__array_module__` has been implemented. Thank you!
JAX has `__array_module__`, but I don't think NEP 37 is ever going to be accepted. NEP 47 (`__array_namespace__` / array API standard) has much more momentum behind it, e.g., a PyTorch implementation.
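To show what NEP 47 buys library authors, here is a minimal sketch of backend-agnostic code: the function asks its input for its array API namespace and uses only functions from that namespace. `get_namespace` and `standardize` are hypothetical helpers, and the fallback to NumPy covers arrays that don't define `__array_namespace__` (whether a given JAX version does is what the issue linked below tracks):

```python
import numpy as np

def get_namespace(x):
    """Return the array API namespace of x (NEP 47 / array API standard),
    falling back to NumPy if x doesn't define __array_namespace__."""
    method = getattr(x, "__array_namespace__", None)
    return method() if method is not None else np

def standardize(x):
    # Only uses mean/std, which the array API standard namespace provides
    xp = get_namespace(x)
    return (x - xp.mean(x)) / xp.std(x)

print(standardize(np.asarray([1.0, 2.0, 3.0])))
```

The design choice is that the array, not an `import` statement, decides which library runs the computation.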
I'm curious whether NEP 47 is supported (or planned) for JAX. It would be nice to transparently use xarray on top of JAX primitives.
@raj-magesh https://github.com/google/jax/issues/18353
That's excellent, thank you! Looks like it's shaping up brilliantly. I'm especially happy that the linear algebra primitives are almost all done!
@raj-magesh I'm excited too! The JAX team is finishing it so fast.
NumPy has protocols, based on the `__array_ufunc__` and `__array_function__` methods, that allow for overriding what NumPy functions like `np.sin()` and `np.concatenate` do when called on other array types.

In practice, this means users can write `import numpy as np` to get NumPy functions that work on JAX arrays, instead of needing to write `import jax.numpy as np`.

It might make sense to implement these methods on JAX's array objects. A working prototype of this can be found in https://github.com/google/jax/pull/611.
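To make the `__array_function__` protocol concrete, here is a minimal sketch of NEP 18 dispatch on a hypothetical `WrappedArray` class. It is not JAX's array type and not the code in #611; it follows the registry-plus-decorator pattern used in the NEP's own examples. When `np.concatenate` sees an argument that defines `__array_function__`, it hands the call to that method instead of running NumPy's implementation:

```python
import numpy as np

# Registry mapping NumPy functions to our overrides
HANDLED_FUNCTIONS = {}

def implements(numpy_function):
    """Register an override for a NumPy function."""
    def decorator(func):
        HANDLED_FUNCTIONS[numpy_function] = func
        return func
    return decorator

class WrappedArray:
    """Hypothetical array type that opts into NEP 18 dispatch."""

    def __init__(self, data):
        self.data = np.asarray(data)

    def __array_function__(self, func, types, args, kwargs):
        if func not in HANDLED_FUNCTIONS:
            return NotImplemented
        # Only handle calls where every overriding type is WrappedArray
        if not all(issubclass(t, WrappedArray) for t in types):
            return NotImplemented
        return HANDLED_FUNCTIONS[func](*args, **kwargs)

@implements(np.concatenate)
def concatenate(arrays, axis=0):
    return WrappedArray(np.concatenate([a.data for a in arrays], axis=axis))

result = np.concatenate([WrappedArray([1, 2]), WrappedArray([3])])
print(type(result).__name__, result.data)  # WrappedArray [1 2 3]
```

Every NumPy function the array type wants to support needs an entry in the registry, which hints at the maintenance burden discussed below.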
Reason to do this: users can write `import numpy as np` and it will probably work. This is particularly advantageous for third-party libraries (e.g., projects like opt-einsum or xarray) that want to support multiple backends in a clean, composable way.

Reasons not to do this:
`onp.asarray()`. https://github.com/google/jax/pull/611 includes a handful of examples of this internally in JAX.

Decision by @mattjj and myself: We're not going to merge this yet, because it's not clear that anyone would even use it, and it imposes a maintenance burden.
If you have compelling use-cases, please speak up. We could relatively easily make this happen, but would need someone who could commit to being a passionate user first.