jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Overrides of NumPy functions on JAX arrays #1565

Open shoyer opened 4 years ago

shoyer commented 4 years ago

NumPy has protocols, based on the __array_ufunc__ and __array_function__ methods, that allow other array types to override what NumPy functions like np.sin() and np.concatenate() do when called on them.

In practice, this means users can write import numpy as np to get NumPy functions that work on JAX arrays instead of needing to write import jax.numpy as np.

It might make sense to implement these methods on JAX's array objects. A working prototype of this can be found in https://github.com/google/jax/pull/611.
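A minimal sketch of what implementing these two methods on an array type looks like. `DuckArray`, `HANDLED_FUNCTIONS` and `implements` are hypothetical names modeled on the registry pattern in NumPy's own NEP 18 documentation. This toy wrapper simply delegates back to NumPy rather than to JAX, and it is not the code from the linked prototype PR.

```python
import numpy as np

# Registry of NumPy API functions overridden for DuckArray (hypothetical name).
HANDLED_FUNCTIONS = {}


def implements(np_function):
    """Register a NEP 18 override for `np_function`."""
    def decorator(func):
        HANDLED_FUNCTIONS[np_function] = func
        return func
    return decorator


def _unwrap(x):
    return x.data if isinstance(x, DuckArray) else x


class DuckArray:
    """Hypothetical array type that intercepts NumPy calls and delegates to NumPy."""

    def __init__(self, data):
        self.data = np.asarray(data)

    def __repr__(self):
        return f"DuckArray({self.data!r})"

    # NEP 13: invoked for ufuncs such as np.sin or np.add.
    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        if method != "__call__" or "out" in kwargs:
            return NotImplemented
        return DuckArray(ufunc(*[_unwrap(x) for x in inputs], **kwargs))

    # NEP 18: invoked for other NumPy API functions such as np.concatenate.
    def __array_function__(self, func, types, args, kwargs):
        if func not in HANDLED_FUNCTIONS:
            return NotImplemented
        if not all(issubclass(t, (DuckArray, np.ndarray)) for t in types):
            return NotImplemented
        return HANDLED_FUNCTIONS[func](*args, **kwargs)


@implements(np.concatenate)
def _concatenate(arrays, axis=0):
    return DuckArray(np.concatenate([_unwrap(a) for a in arrays], axis=axis))


x = DuckArray([0.0, 1.0])
print(np.sin(x))               # dispatched through __array_ufunc__
print(np.concatenate([x, x]))  # dispatched through __array_function__
```

With both methods in place, user code that only does `import numpy as np` gets a `DuckArray` back from `np.sin` and `np.concatenate`. That is the user-facing effect the issue describes for JAX arrays.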

Reason to do this:

Reasons not to do this:

Decision by @mattjj and myself: We're not going to merge this yet, because it's not clear that anyone would even use it and it imposes a maintenance burden.

If you have compelling use-cases, please speak up. We could relatively easily make this happen, but would need someone who could commit to being a passionate user first.

lukasheinrich commented 4 years ago

Adherence to NEP 13 and NEP 18 would make it useful to integrate JAX into projects that rely on them for portability. Specifically, we're looking to integrate JAX with scale-out systems like Dask and with particle physics libraries like https://github.com/scikit-hep/awkward-array. @jpivarski can probably comment better on the technical details, but we'd very much be a passionate user :)

Hoeze commented 4 years ago

I love the idea of xarray with JAX as the backend... That would be so awesome! Also, it's quite unfortunate that TensorFlow/JAX/... all have APIs that differ from NumPy's.

sursu commented 4 years ago

An example:

import numpy as np
from scipy import stats

N = lambda x: stats.norm.cdf(x)  # plain SciPy, not jax.scipy.stats

def test(a, b):
    return N((b-a)/np.sqrt(a))

Jake's function (from the issue mentioned above), which is meant only for illustrative purposes, lets me @jaxify only the test function. That function calls N, which does not use jax.scipy.stats, so I get an error when I try to compute the gradient.

Would it be possible to override all NumPy and SciPy calls made inside the function I want to differentiate, including those in every function it calls?

cranmer commented 4 years ago

In the context of a large software effort for the LHC (http://iris-hep.org), we are discussing this, as @lukasheinrich mentioned above. We have jagged arrays, and we have been able to override ufuncs so that NumPy functions run over our data structures. We would like to be able to do the same with JAX.
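For readers unfamiliar with how a jagged type can "override ufuncs", here is a toy sketch. `JaggedArray` is a hypothetical class that stores variable-length rows as a Python list of NumPy arrays and applies any plain ufunc call row by row. Awkward Array's real implementation is far more sophisticated. This sketch only reproduces the user-visible behavior of `np.power(a, 2)` shown in the next comment.

```python
import numpy as np


class JaggedArray:
    """Toy jagged array: variable-length 1-D rows (hypothetical, not awkward's design)."""

    def __init__(self, rows):
        self.rows = [np.asarray(r) for r in rows]

    def tolist(self):
        return [r.tolist() for r in self.rows]

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        # Handle only plain elementwise calls such as np.power(a, 2).
        if method != "__call__" or kwargs:
            return NotImplemented
        jagged = [x for x in inputs if isinstance(x, JaggedArray)]
        if any(len(j.rows) != len(self.rows) for j in jagged):
            raise ValueError("jagged inputs must have the same number of rows")
        # Apply the ufunc row by row; non-jagged inputs (e.g. the scalar 2)
        # are passed unchanged to every row.
        out_rows = []
        for i in range(len(self.rows)):
            args = [x.rows[i] if isinstance(x, JaggedArray) else x for x in inputs]
            out_rows.append(ufunc(*args))
        return JaggedArray(out_rows)


a = JaggedArray([[1, 2, 3], [], [4, 5]])
print(np.power(a, 2).tolist())  # [[1, 4, 9], [], [16, 25]]
```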

lukasheinrich commented 4 years ago

As a minimal example, this should work:

pip install jax jaxlib numpy awkward
python
>>> import awkward1
>>> import numpy as np
>>> import jax.numpy as jnp
>>> a = awkward1.from_iter([[1,2,3],[],[4,5]])
>>> np.power(a,2)
<Array [[1, 4, 9], [], [16, 25]] type='3 * var * int64'>
>>> jnp.power(a,2)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/lukasheinrich/Code/analysis/rpv1l-downstram/venv/lib/python3.7/site-packages/jax/numpy/lax_numpy.py", line 532, in power
    return lax.integer_pow(x1, x2)
  File "/Users/lukasheinrich/Code/analysis/rpv1l-downstram/venv/lib/python3.7/site-packages/jax/lax/lax.py", line 265, in integer_pow
    return integer_pow_p.bind(x, y=y)
  File "/Users/lukasheinrich/Code/analysis/rpv1l-downstram/venv/lib/python3.7/site-packages/jax/core.py", line 211, in bind
    return self.impl(*args, **kwargs)
  File "/Users/lukasheinrich/Code/analysis/rpv1l-downstram/venv/lib/python3.7/site-packages/jax/interpreters/xla.py", line 217, in apply_primitive
    compiled_fun = xla_primitive_callable(prim, *map(arg_spec, args), **params)
  File "/Users/lukasheinrich/Code/analysis/rpv1l-downstram/venv/lib/python3.7/site-packages/jax/interpreters/xla.py", line 209, in arg_spec
    aval = abstractify(x)
  File "/Users/lukasheinrich/Code/analysis/rpv1l-downstram/venv/lib/python3.7/site-packages/jax/interpreters/xla.py", line 159, in abstractify
    raise TypeError(f"No abstraction handler for type: {type(x)}")
TypeError: No abstraction handler for type: <class 'awkward1.highlevel.Array'>

The error message suggests that there are pluggable "abstraction handlers". If there is a well-defined protocol, we could maybe implement one for awkward1.highlevel.Array arrays.
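Whatever form such a handler takes, it would have to present JAX with rectangular buffers, since JAX's abstractions describe arrays by shape and dtype. The sketch below shows one such translation. It is an assumption, not JAX's abstraction-handler mechanism (whose registration is internal) and not awkward's API. It splits the jagged data into a flat `content` buffer plus row `offsets`, similar in spirit to the list-offset layout Awkward Array uses. `flatten_jagged` and `unflatten_jagged` are hypothetical helpers.

```python
import numpy as np


def flatten_jagged(rows):
    """Split variable-length rows into a flat `content` buffer and row `offsets`."""
    offsets = np.zeros(len(rows) + 1, dtype=np.int64)
    offsets[1:] = np.cumsum([len(r) for r in rows])
    content = np.array([v for r in rows for v in r])
    return content, offsets


def unflatten_jagged(content, offsets):
    """Rebuild the rows: row i is content[offsets[i]:offsets[i + 1]]."""
    return [content[offsets[i]:offsets[i + 1]].tolist() for i in range(len(offsets) - 1)]


rows = [[1, 2, 3], [], [4, 5]]
content, offsets = flatten_jagged(rows)
# Any rectangular, elementwise kernel can now run on `content` alone.
squared = np.power(content, 2)
print(offsets.tolist())                    # [0, 3, 3, 5]
print(unflatten_jagged(squared, offsets))  # [[1, 4, 9], [], [16, 25]]
```

Because `content` is an ordinary 1-D numeric array, `jnp.power` should accept it in place of `np.power`. Only the elementwise step would then run in JAX, while the offsets carry the jagged structure alongside.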

mhlr commented 4 years ago

If this automated, or at least simplified, porting scikit code to JAX, it would be huge!

peterdsharpe commented 2 years ago

Edit: never mind this comment! I updated JAX to find that __array_module__ has been implemented. Thank you!

shoyer commented 2 years ago


JAX has __array_module__, but I don't think NEP 37 is ever going to be accepted. NEP 47 (__array_namespace__ / array API standard) has much more momentum behind it, e.g., a PyTorch implementation.
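Under NEP 47 and the array API standard, library code does not import a specific backend. Instead, it asks its inputs for their namespace via `__array_namespace__()`. `get_namespace` below is a hypothetical minimal helper; the `array-api-compat` package provides a maintained equivalent, `array_namespace`. The sketch assumes the input array type implements `__array_namespace__()`, as the array API work for JAX aims to provide, and falls back to NumPy otherwise.

```python
import numpy as np


def get_namespace(*arrays):
    """Return the single array API namespace used by `arrays` (NEP 47 style).

    Objects defining __array_namespace__() report their own namespace; anything
    else (Python scalars, lists, ndarrays on NumPy < 2.0) falls back to NumPy.
    """
    found = []
    for a in arrays:
        ns = a.__array_namespace__() if hasattr(a, "__array_namespace__") else np
        if not any(ns is f for f in found):
            found.append(ns)
    if len(found) != 1:
        raise TypeError(f"expected exactly one array namespace, found {len(found)}")
    return found[0]


def softplus(x):
    # Library code written once against whatever namespace the input brings.
    xp = get_namespace(x)
    return xp.log(1 + xp.exp(x))


print(softplus(0.0))  # a plain float falls back to NumPy: log(2)
```

Compared with `__array_module__` (NEP 37), the lookup is the same shape: the array tells the library which module to use. The difference is that the returned namespace is expected to follow the standardized array API.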

raj-magesh commented 5 months ago

I'm curious whether NEP 47 is supported (or planned) for JAX. It would be nice to use xarray transparently on top of JAX primitives.

NeilGirdhar commented 5 months ago

@raj-magesh https://github.com/google/jax/issues/18353

raj-magesh commented 5 months ago

That's excellent, thank you! Looks like it's shaping up brilliantly. I'm especially happy that the linear algebra primitives are almost all done!

NeilGirdhar commented 5 months ago

@raj-magesh I'm excited too! The JAX team is finishing it so fast.