google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Array dispatching with __array_ufunc__ in JAX #22090

Open gautierronan opened 1 week ago

gautierronan commented 1 week ago

Is it possible to dispatch jax functions to custom array-like classes?

For instance, in the example below, I have a class that represents an array whose only non-zero elements lie on the main diagonal (only those are stored). I would like to use my custom _exp method, which computes the exponential of the diagonal elements only.

I can achieve this in NumPy using __array_ufunc__. Is there any equivalent way in JAX?

import jax.numpy as jnp
import numpy as np

class DiagonalArray:
    def __init__(self, diagonal):
        self.diagonal = jnp.asarray(diagonal)

    def __jax_array__(self):
        return jnp.diag(self.diagonal)

    def _exp(self):
        print('_exp called')
        return DiagonalArray(jnp.exp(self.diagonal))

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        if ufunc in (jnp.exp, np.exp):
            return self._exp()
        else:
            return NotImplemented

x = DiagonalArray([1, 2, 3])
np.exp(x)  # _exp called
jnp.exp(x) # _exp not called
jakevdp commented 1 week ago

Thanks for the question!

We've talked about this, but have not implemented any support for __array_ufunc__ or other custom dispatch mechanisms. The thinking is that it adds too much complexity and indirection to the API, and for the user makes it harder to reason about what a particular line of code might be doing.

Instead, I'd suggest going the Array API route – for example, you can currently do from jax.experimental import array_api as xp, and then get an array API namespace that works on JAX arrays. Similarly, you could create an xp namespace that implements the numpy namespace you need for your own types, and then write your code in a way that dispatches to the correct namespace.

This is the solution that the numpy community landed on after experimenting with __array_ufunc__, __array_function__, and other dispatch mechanisms, and finding them difficult to support. Libraries like scipy and sklearn are now working on array API support within their own APIs, so doing this would make your custom dispatch compatible with libraries beyond JAX as well.

What do you think?

gautierronan commented 1 week ago

Thanks for your quick answer! I have been looking at the Array API documentation you linked, but it's not clear to me how to achieve the example above using it.

Is there any example of usage of jax.experimental.array_api? Or could you share a short implementation of my use case above?

Our general need is quite simple. We have a custom class that implements a fancy JAX array, and would like to keep compatibility with the rest of JAX/NumPy/etc., i.e. allow calling jnp.any_function(my_custom_array) which would return either a JAX array or a custom class array (both are fine).

jakevdp commented 1 week ago

Here's an example of how this might look using the array API:

import jax.experimental.array_api  # side-effecting import required now, but won't be needed in the future

import jax.numpy as jnp

class DiagonalArrayNamespace:
    @staticmethod
    def exp(x):
        return DiagonalArray(jnp.exp(x.diagonal))

class DiagonalArray:
    def __init__(self, diagonal):
        self.diagonal = jnp.asarray(diagonal)

    def __array_namespace__(self):
        return DiagonalArrayNamespace()

def func(x):
    xp = x.__array_namespace__()
    return xp.exp(x)

x = jnp.arange(4)
print(func(x))
# [ 1.         2.7182817  7.389056  20.085537 ]

d = DiagonalArray(x)
print(func(d))
# <__main__.DiagonalArray object at 0x7a3b784592a0>

The benefit of this approach is that func will be compatible with any type that implements the array API. Additionally, as long as you implement the relevant parts of the namespace for your type, it could be used directly with other packages that are array API aware, such as future versions of scipy and scikit-learn.
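To make that portability concrete, here is a minimal sketch of the same namespace-dispatch pattern applied to plain NumPy arrays. It assumes NumPy is available; note that `np.ndarray` only gained `__array_namespace__` in NumPy 2.0, so the sketch falls back to the `numpy` module on older versions:

```python
import numpy as np

def func(x):
    # Ask the array itself for its array API namespace. NumPy >= 2.0
    # ndarrays implement __array_namespace__; older versions do not,
    # so fall back to the numpy module in that case.
    xp = x.__array_namespace__() if hasattr(x, "__array_namespace__") else np
    return xp.exp(x)

print(func(np.arange(4.0)))  # same values as np.exp(np.arange(4.0))
```

The point is that `func` never names a specific array library: any object that answers `__array_namespace__` with a conforming namespace gets correct behavior.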

> We have a custom class that implements a fancy JAX array, and would like to keep compatibility with the rest of JAX/NumPy/etc., i.e. allow calling jnp.any_function(my_custom_array) which would return either a JAX array or a custom class array (both are fine).

We do not have any plans to support this kind of custom overloading of the jax.numpy API. Instead, you should write your code to use the Array API, and pass it types that provide __array_namespace__, such as JAX arrays, numpy arrays, or your own custom arrays.

gautierronan commented 6 days ago

Thanks for the example @jakevdp, it's very useful.

Indeed, it seems there are two different use cases:

  1. writing custom functions (e.g. fancy_exp) that accept arrays from any library (JAX, NumPy, PyTorch, SciPy, ...)
  2. writing custom arrays that are compatible with functions from any library (jnp.exp, np.exp, torch.exp, ...), and that can override certain of these functions when useful (e.g. for a speedup)

Your solution indeed supports 1, but we're rather looking for 2. What is the reasoning behind not wanting to support 2? NumPy seems to allow it.

jakevdp commented 6 days ago

(2) will never be feasible in the long term. The numpy team tried many approaches over the years to make this work (starting with Python operator dispatch & __array__, then dispatching numpy functions to object methods (this is why, e.g., np.reshape(x) first tries x.reshape()), then __array_function__, then __array_ufunc__, etc.) and in the end found that none of these approaches were sufficient in practice. The partially implemented mechanisms still exist in the codebase for backward compatibility, but they are not a mechanism that the NumPy team recommends.
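The method-dispatch fallback mentioned above is easy to observe: np.reshape first looks for a reshape method on its argument before coercing it to an ndarray. A minimal illustration (the Reshapeable class is hypothetical, made up just to expose the dispatch):

```python
import numpy as np

class Reshapeable:
    # Not an ndarray: np.reshape will find and call this method
    # instead of converting the object with np.asarray.
    def reshape(self, shape, **kwargs):
        return f"custom reshape to {shape}"

print(np.reshape(Reshapeable(), (2, 2)))  # -> custom reshape to (2, 2)
```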

The problem is that the numpy API has no formal specification (there are many strange, underspecified implementation details that may be depended upon in the logic of any particular library that uses it), and this makes the above approaches too brittle to be used broadly and depended upon by the larger ecosystem. Further, it makes it much more difficult to reason about what a particular piece of code is doing when it may be opaquely dispatching to arbitrary implementations defined by the object you pass to it.

The solution in the end was to define a well-specified subset of the numpy API for which downstream libraries can implement the full semantics in a well-defined and transparent way: this is the Array API, and that is the solution that numpy developers will point you to if you ask them the same question today.

For JAX, we want to learn from the experience of the NumPy team and just implement the final, working solution rather than one of the many half-working false starts that eventually led to that working solution.

gautierronan commented 6 days ago

Ok that makes sense. Thanks a lot for your detailed answer.

jakevdp commented 6 days ago

I'd add that (1) is probably the best way forward for what you have in mind too: if the downstream tools that you use adapt their implementations to make use of the Array API standard, then you actually have a hope of supporting this set of functionality for your own custom type. Of course, this requires downstream implementations to change, but without that I don't think you'd ever land on a robust solution to implementing all the corner cases of these similar-but-different APIs for your own type.

patrick-kidger commented 6 days ago

I'll offer one other option here, which is Quax. This allows you to define array-ish types and then perform dispatch on them. One of the examples we have is for LoRA.

This is basically a direct solution to (2): define how your custom type interacts with each JAX primitive, and in principle you have compatibility with arbitrary JAX code. However, the downsides are (a) needing to define how this works with every primitive you want to interact with, and (b) it's JAX-only: no NumPy/PyTorch/etc. So no free lunch, I guess.