google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Add `where` argument to `argmax`, `argmin`, `ptp`, `cumsum`, `cumprod` #20177

Open carlosgmartin opened 3 months ago

carlosgmartin commented 3 months ago

The following functions receive a where argument, which limits the reduction to a given boolean mask:
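To illustrate the existing behaviour, here is a minimal sketch using `jnp.sum` and `jnp.max`, two reductions that accept `where` in current JAX releases. The array and mask values are made up for the example. Note that `jnp.max` has no identity element, so it needs an explicit `initial` whenever `where` is given.

```python
import jax.numpy as jnp

# Illustrative data, not taken from the issue.
x = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 6.0]])
mask = jnp.array([[True, False, True],
                  [False, True, True]])

# Only entries where the mask is True take part in the reduction.
total = jnp.sum(x, where=mask)

# max has no identity element, so `initial` must accompany `where`.
row_max = jnp.max(x, axis=1, where=mask, initial=-jnp.inf)
```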

Feature request: Add a where argument to the following functions: argmax, argmin, ptp, cumsum, and cumprod.
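Until such an argument exists, the requested behaviour can be approximated by substituting a neutral value for the excluded entries. The sketch below is one possible emulation, not a proposed implementation: the helper names `masked_argmax`, `masked_ptp`, and `masked_cumsum` are hypothetical, the inputs are assumed to be floating point, and the semantics (for example, what to return when every entry is masked out) are an assumption, since NumPy has not defined them.

```python
import jax.numpy as jnp

def masked_argmax(x, where, axis=None):
    # Hypothetical helper: excluded entries become -inf so they cannot win.
    # Assumes floating-point input; an all-False mask yields an arbitrary index.
    return jnp.argmax(jnp.where(where, x, -jnp.inf), axis=axis)

def masked_ptp(x, where, axis=None):
    # Hypothetical helper: peak-to-peak range built from the existing
    # `where`-aware max and min reductions.
    hi = jnp.max(x, axis=axis, where=where, initial=-jnp.inf)
    lo = jnp.min(x, axis=axis, where=where, initial=jnp.inf)
    return hi - lo

def masked_cumsum(x, where, axis=None):
    # Hypothetical helper: excluded entries contribute 0, the additive
    # identity, so the running total is simply carried past them.
    return jnp.cumsum(jnp.where(where, x, 0), axis=axis)

x = jnp.array([3.0, 9.0, 1.0, 7.0])
m = jnp.array([True, False, True, True])

best = masked_argmax(x, m)
spread = masked_ptp(x, m)
running = masked_cumsum(x, m)
```

The same substitution idea covers `argmin` (use `jnp.inf`) and `cumprod` (use 1).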

Related:

jakevdp commented 3 months ago

Hi - thanks for the request! We generally follow the NumPy API in jax.numpy, and as far as I can tell, numpy does not support a where argument for any of these functions.

carlosgmartin commented 3 months ago

@jakevdp Recalling what we did for jax.numpy.fill_diagonal with inplace (vs. numpy.fill_diagonal):

I think that, since the API will be identical/compatible when this functionality is added to numpy (see the linked issue in the OP), JAX implementing it ahead of time should not cause any issues.

Personally, I'd find it very helpful to have this functionality.

jakevdp commented 3 months ago

fill_diagonal is somewhat different: it's impossible to implement that in JAX without changing the API to avoid in-place modification.
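For context, a short sketch of the `fill_diagonal` difference being discussed: because JAX arrays are immutable, `jnp.fill_diagonal` is called with `inplace=False` and returns a new array. The array shape and fill value here are arbitrary choices for illustration.

```python
import jax.numpy as jnp

a = jnp.zeros((3, 3))

# JAX arrays are immutable, so the NumPy-style in-place update is not
# available; inplace=False returns a modified copy instead.
b = jnp.fill_diagonal(a, 1.0, inplace=False)

# The same result written as a functional update with the .at[] helper.
idx = jnp.arange(3)
c = a.at[idx, idx].set(1.0)
```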

Adding additional functionality to existing APIs is a qualitatively different discussion. There is some upside in flexibility, but it has some downsides too, namely:

  1. It increases the user-facing API surface that we need to support for the rest of time.
  2. It makes jax.numpy harder to test, because we must also implement the ground truth version to test against.
  3. If NumPy or the Array API eventually adds the same keyword with different semantics, it makes for an awkward deprecation (see e.g. the arr.device method deprecation that we're doing currently).
  4. It increases, if only marginally, the cognitive burden of switching between jax.numpy and original numpy.

For those reasons, I lean toward not adding these sorts of keywords until they are part of either the NumPy API or the Python Array API.

carlosgmartin commented 3 months ago

@jakevdp Thanks for your response! Is there a namespace for useful array functionality that has not yet been incorporated into NumPy? Perhaps jax.lax or some theoretical jax.numpy_experimental?


Somewhat off-topic, opinionated tangent:

It seems detrimental to JAX's evolution to tie itself so strictly to NumPy. 😔

This tying has also caused issues in the past, when NumPy's way of doing things is awkward.

Is there any existing theoretical discussion on the benefits, drawbacks, and future of JAX tying itself strictly to NumPy's API?

This is perhaps a provocative opinion, and purely speculative at this point, but I'd love to see JAX liberate itself from NumPy's straitjacket. 🙂 Especially if it eclipses NumPy in terms of userbase (which I'm hoping and rooting for).

For example, I find certain aspects of PyTorch's tensor API to be superior to NumPy's array API, from a usability perspective. It's also worth noting that NumPy itself currently breaks with the Python Array API.

jakevdp commented 3 months ago

We have a related discussion here: https://jax.readthedocs.io/en/latest/jep/18137-numpy-scipy-scope.html, although it kind of takes as given that we won't implement things in jax.numpy or jax.scipy that are not in numpy or scipy respectively.

JAX has plenty of APIs that exist outside of the numpy API, they are just not generally found in the jax.numpy namespace.
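As a small illustration of functionality that lives outside `jax.numpy`, the sketch below uses two `jax.lax` functions: `lax.cumsum` with its `reverse` flag, and `lax.top_k`. The input values are arbitrary, and these two were picked only as examples.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([1.0, 2.0, 3.0, 4.0])

# Cumulative sum accumulated from the end of the axis toward the start.
rev = lax.cumsum(x, axis=0, reverse=True)

# The k largest values along the last axis, together with their indices.
values, indices = lax.top_k(x, 2)
```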

mattjj commented 3 months ago

> For example, I find certain aspects of PyTorch's tensor API to be superior to NumPy's array API, from a usability perspective.

Very interesting! Do you have examples in mind?

carlosgmartin commented 3 months ago

@mattjj One example is how PyTorch makes it easier to chain array operations. Consider the following:

jnp.sqrt(jnp.square(abs(x.sum(2)).max(1)).sum(0))

Notice the "back-and-forth" or "spiral" pattern of the function calls. Compare it to

x.sum(2).abs().max(1).square().sum(0).sqrt()

which is much easier to write and read:

  1. It follows the logical order of the operations.
  2. It avoids using (and importing) jnp., which reduces code clutter.

In my experience, this kind of pattern (with varying lengths) is common.

Therefore, from a usability perspective, it would be convenient to attach more methods to JAX arrays themselves.
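For comparison, a left-to-right reading order can be approximated today without new array methods. The sketch below uses a small `pipe` helper, which is a hypothetical name and not part of JAX, and checks that it matches the nested expression above. The input array is made up for the example.

```python
import jax.numpy as jnp

def pipe(value, *fns):
    # Hypothetical helper: apply each function in turn, left to right.
    for fn in fns:
        value = fn(value)
    return value

# Illustrative 3-D input.
x = jnp.arange(24.0).reshape(2, 3, 4) - 12.0

# The nested spelling from the comment above.
nested = jnp.sqrt(jnp.square(abs(x.sum(2)).max(1)).sum(0))

# The same computation, written in the order the operations happen.
chained = pipe(
    x,
    lambda a: a.sum(2),
    jnp.abs,
    lambda a: a.max(1),
    jnp.square,
    lambda a: a.sum(0),
    jnp.sqrt,
)
```

This keeps the logical order but still needs `jnp.` and the lambdas, so it only partly addresses the readability points listed above.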

jakevdp commented 3 months ago

Thanks for the elaboration. I've seen this kind of idea come up before in other contexts (e.g. https://github.com/numpy/numpy/issues/24081).

I personally don't feel strongly about this (it would be convenient in some cases, but there are costs to that large an expansion of the API surface, both in terms of maintenance burden and cognitive overhead) but I'm happy to hear what others might think.

carlosgmartin commented 3 months ago

Thanks for the link. It's indeed very relevant. It also brings up another advantage I forgot to include:

With a fluent API in common, users may more easily write library-agnostic code.

I saw this in practice here: https://github.com/cvxpy/cvxpy/issues/2237.

Anecdotally, since it was mentioned in that thread, I've personally also found pandas's method-chaining API a pleasure to work with.

Unfortunately, I think NumPy's devs are making the wrong call here. If I had to choose between standalone functions and methods, I'd choose the latter, due to the aforementioned advantages from the perspective of user experience.

It's also easy to get a standalone function from the method (e.g. sin = lambda x: x.sin()), whereas the converse is not true.

jakevdp commented 3 months ago

I understand the choice the numpy devs have made. For what it’s worth, I wouldn’t suggest trying to re-open that discussion. They’ve made their opinions pretty clear, and it’s a very small team supporting a very large user community.

carlosgmartin commented 3 months ago

> I understand the choice the numpy devs have made. [...] They’ve made their opinions pretty clear

But do you agree with it? Does the user community agree with it?

I won't try to steer NumPy's course on this design choice; it looks like that ship has sailed. I'm more interested in JAX's future. (Again, speaking theoretically at this point.)

It might be wise to carefully survey the community for these kinds of impactful, long-lasting design decisions. And to ensure surveyees have a good understanding of the issue, each option could be presented with a list of its pros and cons. Perhaps even an argument map.

I'd also create a space for the community to discuss the options, and keep it open for a reasonably extended period of time before committing.

Together, these could help the project establish a confident, community-wide consensus on the best direction for JAX's evolution.

Any thoughts welcome. 🙂

jakevdp commented 3 months ago

> But do you agree with it?

On the whole, yes I do.

I understand the reason that you are advocating for method-based access to all ufuncs and reductions: there are certain situations in which it could make for cleaner, more concise code. But I think the biggest danger to programming languages and DSLs in the long term is bloat in the API surface, and so it is prudent to be conservative when considering addition of new APIs or duplicate spellings of existing APIs.

NumPy is quite conservative in this way, and I think that's inextricably tied to why it's remained a successful project for a quarter century. On the JAX side, we have benefitted from using numpy (and more recently the Python Array API) as a standard, because it lets us take the prudent stance without having to debate the individual merits of every single proposed API extension. I believe that is good for JAX as a whole.

Is the JAX API perfect? No. Are we willing to change it? Yes, and we add features every day! But it would not be wise to approach changes like the ones proposed here without careful consideration of the costs as well as the benefits.

carlosgmartin commented 3 weeks ago

Just making a note of this here so I don't forget when https://github.com/numpy/numpy/issues/26336 is completed: