carlosgmartin opened 7 months ago
Hi, thanks for the request! We generally follow the NumPy API in `jax.numpy`, and as far as I can tell, NumPy does not support a `where` argument for any of these functions.
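For context, reductions such as `sum` and `max` already take `where` in both NumPy and `jax.numpy`, while `argmax` does not. Until it does, a masked `argmax` can be emulated by replacing excluded entries with `-inf` before the reduction. A minimal sketch, assuming floating-point input with at least one allowed entry per reduced slice; `masked_argmax` is a hypothetical helper, not a JAX API, and the values are arbitrary:

```python
import jax.numpy as jnp

x = jnp.array([[3.0, 1.0, 2.0],
               [0.5, 4.0, 9.0]])
mask = jnp.array([[False, True, True],
                  [True, True, False]])

# Existing API: max has no identity element, so `initial` is supplied with `where`.
row_max = jnp.max(x, axis=1, where=mask, initial=-jnp.inf)

def masked_argmax(x, mask, axis=None):
    # Hypothetical helper (not part of jax.numpy): entries where mask is False
    # become -inf so they can never be selected by argmax.
    return jnp.argmax(jnp.where(mask, x, -jnp.inf), axis=axis)

row_argmax = masked_argmax(x, mask, axis=1)
print(row_max)     # maxima over the allowed entries of each row
print(row_argmax)  # positions of those maxima
```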
@jakevdp Recalling what we did for `jax.numpy.fill_diagonal` with `inplace` (vs. `numpy.fill_diagonal`): I think that, since the API will be identical/compatible once this functionality is added to NumPy (see the linked issue in the OP), implementing it in JAX ahead of time should not cause any issues.
Personally, I'd find it very helpful to have this functionality.
`fill_diagonal` is somewhat different: it's impossible to implement in JAX without changing the API to avoid in-place modification.
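To make the contrast concrete, a short sketch of the two spellings: NumPy mutates its argument, while JAX arrays are immutable, so `jnp.fill_diagonal` only works with `inplace=False` and returns a new array. The array sizes and fill value are arbitrary:

```python
import numpy as np
import jax.numpy as jnp

# NumPy modifies `a` in place and returns None.
a = np.zeros((3, 3))
np.fill_diagonal(a, 1.0)

# JAX requires inplace=False and returns a modified copy instead.
b = jnp.fill_diagonal(jnp.zeros((3, 3)), 1.0, inplace=False)

print(np.array_equal(a, np.asarray(b)))  # True
```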
Adding additional functionality to existing APIs is a qualitatively different discussion. There is some upside in flexibility, but it has some downsides too, namely:

- … `jax.numpy` harder to test, because we must also implement the ground truth version to test against.
- … `arr.device` method deprecation that we're doing currently).
- … `jax.numpy` and original `numpy`.

For those reasons, I lean toward not adding these sorts of keywords until they are part of either the NumPy API or the Python Array API.
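As an illustration of the testing cost: a keyword that NumPy lacks has no NumPy function to compare against, so a reference has to be assembled by hand. A sketch that uses `numpy.ma` as the ground truth for a masked `argmax`; both helper names are hypothetical, and the check assumes every reduced slice has at least one allowed entry:

```python
import numpy as np
import jax.numpy as jnp

def masked_argmax_jax(x, mask, axis):
    # Hypothetical JAX-side implementation: excluded entries become -inf.
    return jnp.argmax(jnp.where(mask, x, -jnp.inf), axis=axis)

def masked_argmax_reference(x, mask, axis):
    # Hand-built ground truth: numpy.ma treats True in its mask as "excluded",
    # so the `where`-style mask is inverted.
    return np.ma.masked_array(x, mask=~mask).argmax(axis=axis)

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 5)).astype(np.float32)
mask = rng.random((4, 5)) < 0.6
mask[0, :] = True  # guarantee an allowed entry in every column
mask[:, 0] = True  # guarantee an allowed entry in every row

for axis in (0, 1):
    got = np.asarray(masked_argmax_jax(x, mask, axis))
    want = masked_argmax_reference(x, mask, axis)
    assert np.array_equal(got, want), (axis, got, want)
```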
@jakevdp Thanks for your response! Is there a namespace for useful array functionality that has not yet been incorporated into NumPy? Perhaps `jax.lax`, or some theoretical `jax.numpy_experimental`?
Somewhat off-topic, opinionated tangent:
It seems detrimental to JAX's evolution to tie itself so strictly to NumPy. 😔
This tying has also caused issues in the past, when NumPy's way of doing things was awkward.
Is there any existing theoretical discussion on the benefits, drawbacks, and future of JAX tying itself strictly to NumPy's API?
This is perhaps a provocative opinion, and purely speculative at this point, but I'd love to see JAX liberate itself from NumPy's straitjacket. 🙂 Especially if it eclipses NumPy in terms of userbase (which I'm hoping and rooting for).
For example, I find certain aspects of PyTorch's tensor API to be superior to NumPy's array API, from a usability perspective. It's also worth noting that NumPy itself currently breaks with the Python Array API.
We have a related discussion here: https://jax.readthedocs.io/en/latest/jep/18137-numpy-scipy-scope.html, although it kind of takes as given that we won't implement things in `jax.numpy` or `jax.scipy` that are not in `numpy` or `scipy`, respectively.
JAX has plenty of APIs that exist outside of the NumPy API; they are just not generally found in the `jax.numpy` namespace.
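One small example of functionality that lives outside `jax.numpy`: `jax.lax.cumsum` takes a `reverse` flag that `numpy.cumsum` lacks. A sketch comparing it with the NumPy spelling (the input values are arbitrary):

```python
import numpy as np
import jax.numpy as jnp
from jax import lax

x = jnp.array([1.0, 2.0, 3.0, 4.0])

# lax.cumsum can accumulate from the end of the axis directly (suffix sums).
suffix_sums = lax.cumsum(x, axis=0, reverse=True)

# The NumPy equivalent needs to reverse the input and the result.
expected = np.cumsum(np.asarray(x)[::-1])[::-1]

print(suffix_sums)  # suffix sums: 10, 9, 7, 4
```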
> For example, I find certain aspects of PyTorch's tensor API to be superior to NumPy's array API, from a usability perspective.
Very interesting! Do you have examples in mind?
@mattjj One example is how PyTorch makes it easier to chain array operations. Consider the following:
```python
jnp.sqrt(jnp.square(abs(x.sum(2)).max(1)).sum(0))
```
Notice the "back-and-forth" or "spiral" pattern of the function calls. Compare it to
```python
x.sum(2).abs().max(1).square().sum(0).sqrt()
```
which is much easier to write and read:

- … `jnp.`, which reduces code clutter.

In my experience, this kind of pattern (with varying lengths) is common.
Therefore, from a usability perspective, it would be convenient to attach more methods to JAX arrays themselves.
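Without adding methods to JAX arrays, a comparable left-to-right reading order can be approximated in user code with a small composition helper. A sketch only: `pipe` is a hypothetical helper, not a JAX API, and the input array is arbitrary. The nested expression is the one from the comment above.

```python
from functools import reduce
import jax.numpy as jnp

def pipe(value, *fns):
    # Hypothetical helper: apply each function in order, left to right.
    return reduce(lambda acc, fn: fn(acc), fns, value)

x = jnp.arange(24.0).reshape(2, 3, 4) - 10.0

# Nested form: read inside-out.
nested = jnp.sqrt(jnp.square(abs(x.sum(2)).max(1)).sum(0))

# Same computation, written in the order the operations are applied.
chained = pipe(
    x,
    lambda a: a.sum(2),
    jnp.abs,
    lambda a: a.max(1),
    jnp.square,
    lambda a: a.sum(0),
    jnp.sqrt,
)
```

The lambdas are needed wherever an axis argument must be bound, which is part of the clutter a method-based API would avoid.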
Thanks for the elaboration. I've seen this kind of idea come up before in other contexts (e.g. https://github.com/numpy/numpy/issues/24081).
I personally don't feel strongly about this (it would be convenient in some cases, but there are costs to that large an expansion of the API surface, both in terms of maintenance burden and cognitive overhead), but I'm happy to hear what others might think.
Thanks for the link. It's indeed very relevant. It also brings up another advantage I forgot to include:
With a fluent API in common, users can more easily write library-agnostic code.
I saw this in practice here: https://github.com/cvxpy/cvxpy/issues/2237.
Anecdotally, since it was mentioned in that thread, I've personally also found pandas's method-chaining API a pleasure to work with.
Unfortunately, I think NumPy's devs are making the wrong call here. If I had to choose between standalone functions and methods, I'd choose the latter, due to the aforementioned advantages from the perspective of user experience.
It's also easy to get a standalone function from the method (e.g. `sin = lambda x: x.sin()`), whereas the converse is not true.
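For methods that `jax.Array` already provides (such as `sum` and `max`), the standard library's `operator.methodcaller` is one way to get a standalone function from a method; a short sketch with arbitrary values:

```python
import operator
import jax.numpy as jnp

# Standalone functions built from existing jax.Array methods.
sum_rows = operator.methodcaller("sum", axis=1)
max_all = operator.methodcaller("max")

x = jnp.array([[1.0, 5.0],
               [2.0, 3.0]])
print(sum_rows(x))  # equivalent to x.sum(axis=1)
print(max_all(x))   # equivalent to x.max()
```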
I understand the choice the numpy devs have made. For what it’s worth, I wouldn’t suggest trying to re-open that discussion. They’ve made their opinions pretty clear, and it’s a very small team supporting a very large user community.
> I understand the choice the numpy devs have made. [...] They’ve made their opinions pretty clear
But do you agree with it? Does the user community agree with it?
I won't try to steer NumPy's course on this design choice; it looks like that ship has sailed. I'm more interested in JAX's future. (Again, speaking theoretically at this point.)
It might be wise to carefully survey the community for these kinds of impactful, long-lasting design decisions. And to ensure surveyees have a good understanding of the issue, each option could be presented with a list of its pros and cons. Perhaps even an argument map.
I'd also create a space for the community to discuss the options, and keep it open for a reasonably extended period of time before committing.
Together, these could help the project establish a confident, community-wide consensus on the best direction for JAX's evolution.
Any thoughts welcome. 🙂
> But do you agree with it?
On the whole, yes I do.
I understand the reason that you are advocating for method-based access to all ufuncs and reductions: there are certain situations in which it could make for cleaner, more concise code. But I think the biggest danger to programming languages and DSLs in the long term is bloat in the API surface, and so it is prudent to be conservative when considering addition of new APIs or duplicate spellings of existing APIs.
NumPy is quite conservative in this way, and I think that's inextricably tied to why it's remained a successful project for a quarter century. On the JAX side, we have benefitted from using numpy (and more recently the Python Array API) as a standard, because it lets us take the prudent stance without having to debate the individual merits of every single proposed API extension. I believe that is good for JAX as a whole.
Is the JAX API perfect? No. Are we willing to change it? Yes, and we add features every day! But it would not be wise to approach changes like the ones proposed here without careful consideration of the costs as well as the benefits.
Just making a note of this here so I don't forget when https://github.com/numpy/numpy/issues/26336 is completed:

… `where` to `argmax`, we can add `where` to `random.categorical` (which uses `argmax` internally), removing the need to manually mask logits with `-jnp.inf`.

@jakevdp While this is being worked on in NumPy, would it be acceptable in the meantime to add an optional mask argument to `lax.argmax`, which `jnp.argmax` uses internally? That way, as soon as NumPy adds the `where` argument, JAX can quickly do the same.
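For reference, a sketch of the manual masking that a `where` argument would replace: excluded categories get a logit of `-jnp.inf`, so `jax.random.categorical` never selects them. This assumes at least one allowed category; the logits, mask, and key are arbitrary:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
logits = jnp.array([2.0, 0.5, 1.0, 3.0])
allowed = jnp.array([False, True, True, False])  # categories 0 and 3 are excluded

# Masked-out categories get -inf logits, i.e. zero probability.
masked_logits = jnp.where(allowed, logits, -jnp.inf)
samples = jax.random.categorical(key, masked_logits, shape=(1000,))
```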
`lax` functions are, by design, more-or-less direct wrappers for the underlying XLA operations; I don't think it's appropriate to add masking support there when the XLA op doesn't implement it.
I didn't find an XLA op for argmax. Is the argmax reduction defined in JAX itself?
I found the following chain of definitions: argmax > argmax_p > _compute_argminmax > _ArgMinMaxReducer.
`lax.argmax` is essentially a convenience wrapper that generates a single call to `lax.reduce`.
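To make this concrete, an argmax can be expressed as one variadic `lax.reduce` over (value, index) pairs. This is a simplified sketch, not JAX's actual `_ArgMinMaxReducer`: it assumes floating-point, NaN-free input and breaks ties toward the lower index, which matches `numpy.argmax` for such input. `argmax_via_reduce` is a hypothetical name.

```python
import numpy as np
import jax.numpy as jnp
from jax import lax

def argmax_via_reduce(x, axis):
    # Hypothetical helper: reduce (value, index) pairs along one axis.
    # Simplified: assumes floating-point, NaN-free input.
    indices = lax.broadcasted_iota(jnp.int32, x.shape, axis)

    def reducer(a, b):
        a_val, a_idx = a
        b_val, b_idx = b
        # Keep the larger value; on ties keep the smaller index.
        pick_a = (a_val > b_val) | ((a_val == b_val) & (a_idx < b_idx))
        return (jnp.where(pick_a, a_val, b_val), jnp.where(pick_a, a_idx, b_idx))

    # (-inf, INT32_MAX) is an identity element for this reducer.
    init = (jnp.array(-jnp.inf, dtype=x.dtype),
            jnp.array(jnp.iinfo(jnp.int32).max, dtype=jnp.int32))
    _, idx = lax.reduce((x, indices), init, reducer, (axis,))
    return idx

x = jnp.array([[0.3, 2.0, -1.0],
               [5.0, 2.0, 5.0]])
print(argmax_via_reduce(x, axis=1))      # ties resolve to the lower index
print(np.argmax(np.asarray(x), axis=1))  # NumPy reference
```

A masked variant could carry a third boolean operand through the same reducer; that is the kind of extension that would live in user code rather than in `lax`.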
The following functions receive a `where` argument, which limits the reduction to a given boolean mask:

…

Feature request: Add a `where` argument to the following functions:

…

Related: