alexbw closed this issue 1 year ago.
The list needs a bit of refinement, since some things were already implemented and other things are out of scope, such as anything to do with loading/saving/string manipulation. I deleted a few of the out-of-scope ones, and marked as done some things that were already included and some things that are coming in #96.
Is there a guide on how to write custom primitive functions (or is that even possible)? I am looking for the equivalent of Autograd's defvjp.
A quick glance at the code suggests everything is delegated to XLA. Does that mean a nontrivial operator first has to be implemented in XLA?
Great question. It's possible, and JAX's internals are actually very similar to Autograd's in this respect, but the API is a bit different. We need to write up how to do this, and maybe add a convenience layer to the API.
There are a few different use cases that we'd want to cover. One is that you just want to define a custom VJP for a function that is otherwise implemented in terms of NumPy code, like in the Autograd tutorial section. But another use case is to define a custom primitive and VJP for some external routine, like a Cython or Fortran function that isn't implemented in terms of NumPy. If you have an intended use case, does it fit into one of those categories, or is it another?
Let's track this in #116. If you have a simplified example of what you want to do, post it there and we'll use it to help guide our convenience API and/or explanation.
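For the first use case (a custom VJP for a function otherwise written in NumPy-style code), current JAX releases provide jax.custom_vjp, whose defvjp method plays roughly the role of Autograd's defvjp. The sketch below is an illustration based on that later API, not something settled in this thread; log1pexp is a hypothetical example function. Its custom backward rule uses sigmoid(x), which stays finite for large x, whereas differentiating log1p(exp(x)) directly can produce nan there.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def log1pexp(x):
    # Hypothetical example function: log(1 + e^x).
    return jnp.log1p(jnp.exp(x))


def log1pexp_fwd(x):
    # Forward pass: primal output plus residuals saved for the backward pass.
    return log1pexp(x), x


def log1pexp_bwd(x, g):
    # Backward pass: d/dx log(1 + e^x) = sigmoid(x).
    # Must return a tuple with one cotangent per primal input.
    return (g * jax.nn.sigmoid(x),)


log1pexp.defvjp(log1pexp_fwd, log1pexp_bwd)

print(jax.grad(log1pexp)(0.0))    # 0.5
print(jax.grad(log1pexp)(100.0))  # 1.0
```

The second use case (wrapping an external Cython or Fortran routine) needs more machinery than this and is not covered by the sketch.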
I'm looking at contributing a function or two from this list.
What is the approach for functions whose return values don't have a static size, for example setxor1d?
My understanding is that XLA needs to compile for each specific shape. Is there a pattern for avoiding worst-case behavior, or should functions like setxor1d be removed from this list?
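One pattern JAX later adopted for value-dependent output shapes is to have the caller pass a static upper bound on the output size and pad the result. For example, jnp.nonzero accepts size and fill_value arguments. The sketch below assumes the caller knows such a bound; nonzero_padded is a hypothetical wrapper name.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="size")
def nonzero_padded(x, size):
    # The output shape depends only on the static `size`, not on the values of x,
    # so one compiled program serves every input of the same shape.
    # Unused slots are filled with -1; extra nonzero entries beyond `size` are dropped.
    (idx,) = jnp.nonzero(x, size=size, fill_value=-1)
    return idx


x = jnp.array([0, 5, 0, 7])
print(nonzero_padded(x, size=3))  # [ 1  3 -1]
```

The trade-off is that the caller must choose size; too small truncates the result, too large wastes work on padding.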
@mattjj I am looking to take a stab at the issue. How do I start?
@souravsingh Awesome! In general, lots of these are fairly easy to do. Pick one you like off the list, and take a look at previous PRs implementing numpy ops. For example, I just sent out https://github.com/google/jax/pull/298 adding np.cumprod/cumsum support. Usually, it's just a question of implementing a numpy op in terms of the primitives in the lax.py library.
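As an illustration of the "NumPy op in terms of lax" pattern, here is a minimal sketch of a NumPy-style cumulative sum built on jax.lax.cumsum. The name my_cumsum is hypothetical, and this is not claimed to be how the PR above implements it; it only shows the usual wrapper duties (argument conversion, NumPy's axis=None flattening rule, negative-axis handling) layered over a lax primitive.

```python
import jax.numpy as jnp
import numpy as np
from jax import lax


def my_cumsum(a, axis=None):
    # Hypothetical NumPy-style wrapper around lax.cumsum.
    a = jnp.asarray(a)
    if axis is None:
        # NumPy flattens the input when no axis is given.
        a = a.ravel()
        axis = 0
    # Normalize negative axes before handing them to lax.
    return lax.cumsum(a, axis=axis % a.ndim)


x = np.arange(6.0).reshape(2, 3)
np.testing.assert_allclose(my_cumsum(x), np.cumsum(x))
np.testing.assert_allclose(my_cumsum(x, axis=-1), np.cumsum(x, axis=-1))
```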
@mattjj I am looking to implement np.unique. Is this something which is a priority at the moment?
Sounds like a great idea! The core team isn't working on any numpy operations at the moment (we're mostly focused on improving performance in the XLA runtime, parallel computation, and Cloud TPU support) so that's a great place to make contributions.
np.unique is tricky though because AIUI the shape of the output depends on the values of the input. Since XLA can only express computations with static shapes (meaning independent of the input values), we'd have to re-compile the XLA computation for every new value of the input, and moreover it might be tricky to express in terms of XLA HLO (i.e. in terms of lax library calls).
@hawkinsp any thoughts on implementing np.unique?
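Later JAX versions address this by letting jnp.unique take an optional static size argument, with fill_value for padding, so the output shape no longer depends on the input values. A minimal sketch, assuming the caller can bound the number of unique values; unique_padded is a hypothetical name.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="size")
def unique_padded(x, size):
    # Sorted unique values, padded at the end with fill_value up to the static `size`.
    return jnp.unique(x, size=size, fill_value=-1)


x = jnp.array([3, 1, 3, 2, 1])
print(unique_padded(x, size=5))  # [ 1  2  3 -1 -1]
```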
@mattjj since this would be my first attempt at this, should I look at numpy functions that are less tricky than unique?
Any suggestions would be really helpful.
@mattjj Apologies for the constant posts. I have decided to implement np.diff, which seems easier. I was just wondering whether there is a guide to developing the test cases in the lax_numpy_test.py file. Thanks!
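For np.diff, one straightforward approach is to mirror NumPy's own slicing-based algorithm, then test by comparing the JAX result against NumPy on the same inputs. This is a hypothetical sketch (my_diff is an illustrative name), and the check at the end is a plain NumPy comparison rather than the helpers used in lax_numpy_test.py.

```python
import jax.numpy as jnp
import numpy as np


def my_diff(a, n=1, axis=-1):
    # Hypothetical NumPy-style diff: repeatedly subtract neighbouring slices.
    a = jnp.asarray(a)
    axis = axis % a.ndim  # normalize negative axes
    upper = [slice(None)] * a.ndim
    lower = [slice(None)] * a.ndim
    upper[axis] = slice(1, None)   # elements 1..N-1 along `axis`
    lower[axis] = slice(None, -1)  # elements 0..N-2 along `axis`
    for _ in range(n):
        a = a[tuple(upper)] - a[tuple(lower)]
    return a


# Compare against NumPy for several (n, axis) combinations.
rng = np.random.default_rng(0)
x = rng.normal(size=(3, 5)).astype(np.float32)
for n, axis in [(1, -1), (2, 0), (3, 1)]:
    np.testing.assert_allclose(
        my_diff(x, n=n, axis=axis), np.diff(x, n=n, axis=axis), rtol=1e-5, atol=1e-6
    )
```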
+1 to adding np.cov
+1 to np.percentile
@murphyk np.cov coming in #983
@murphyk quantile and percentile are now present, at least for interpolation='linear' and floating-point types.
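The 'linear' rule itself is short: for the n sorted values x, compute the fractional index p = q * (n - 1) and interpolate between x[floor(p)] and the next element. A minimal 1-D sketch, assuming a scalar q in [0, 1] and floating-point input; linear_quantile is a hypothetical name, not JAX's implementation.

```python
import jax.numpy as jnp
import numpy as np


def linear_quantile(a, q):
    # Sort the flattened input; quantiles are read off the sorted values.
    x = jnp.sort(jnp.ravel(jnp.asarray(a)))
    n = x.shape[0]
    pos = q * (n - 1)                      # fractional index into x
    lo = jnp.floor(pos).astype(jnp.int32)  # index just below pos
    hi = jnp.minimum(lo + 1, n - 1)        # index just above, clamped at the end
    frac = pos - lo                        # weight given to x[hi]
    return x[lo] + frac * (x[hi] - x[lo])


data = np.array([4.0, 1.0, 3.0, 2.0])
for q in [0.0, 0.25, 0.5, 1.0]:
    np.testing.assert_allclose(linear_quantile(data, q), np.quantile(data, q), rtol=1e-6)
```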
FYI, np.rank is deprecated in favor of np.ndim.
Is anyone interested in np.convolve? I've commented on possibly using lax.conv. It would require reshaping and flipping the inputs. I've described it in #1561.
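A sketch of the reshape-and-flip idea for 1-D inputs in mode='full'. The name convolve_full is hypothetical, and this is not necessarily how #1561 or #1831 implement it. lax.conv computes a cross-correlation (it does not reverse the kernel), so the kernel is flipped first, and padding k - 1 on each side yields all n + k - 1 outputs.

```python
import jax.numpy as jnp
import numpy as np
from jax import lax


def convolve_full(a, v):
    # Hypothetical 1-D np.convolve(a, v, mode="full") built on lax.conv.
    a = jnp.asarray(a, dtype=jnp.float32)
    v = jnp.asarray(v, dtype=jnp.float32)
    k = v.shape[0]
    lhs = a.reshape(1, 1, -1)             # (batch, in_channels, width)
    rhs = jnp.flip(v).reshape(1, 1, -1)   # (out_channels, in_channels, width), reversed
    out = lax.conv(lhs, rhs, window_strides=(1,), padding=[(k - 1, k - 1)])
    return out.reshape(-1)


a = np.array([1.0, 2.0, 3.0])
v = np.array([0.0, 1.0, 0.5])
np.testing.assert_allclose(convolve_full(a, v), np.convolve(a, v), atol=1e-6)
```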
@jessebett I just created a PR for it #1831
I'd also like to get involved in this process of completing the remaining numpy ops. As a first pass, I am looking at implementing something really simple, to dip my toes in the water so I can be more familiar with the codebase before tackling something more serious!
I will look at np.alen unless there are any issues with that :)
Created a PR for alen and isscalar in #1924
I took another pass at deleting irrelevant functions: np.matrix. I think most of the rest should be relevant, though some will be tricky (e.g., those with output-dependent shapes).
Hi - I'm interested in getting nanmean done. Am I right in suggesting that this could be done by somehow piggybacking off nansum? If so, are there any tips that may be worth heeding in this regard?
NumPy should be a pretty good model for most NaN functions; all of its NaN functions are written in pure Python.
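Following that suggestion, nanmean can piggyback on nansum by dividing by the count of non-NaN entries. A minimal sketch; my_nanmean is a hypothetical name, and it ignores NumPy's dtype, out, and keepdims arguments. An all-NaN slice gives 0 / 0 = nan, which matches NumPy's result (NumPy also emits a warning).

```python
import jax.numpy as jnp
import numpy as np


def my_nanmean(a, axis=None):
    # Hypothetical nanmean: sum of the non-NaN entries divided by their count.
    a = jnp.asarray(a)
    total = jnp.nansum(a, axis=axis)           # NaNs contribute 0 to the sum
    count = jnp.sum(~jnp.isnan(a), axis=axis)  # number of non-NaN entries
    return total / count


x = np.array([[1.0, np.nan, 3.0],
              [4.0, 5.0, np.nan]])
np.testing.assert_allclose(my_nanmean(x), np.nanmean(x), rtol=1e-6)
np.testing.assert_allclose(my_nanmean(x, axis=0), np.nanmean(x, axis=0), rtol=1e-6)
```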
To make things a little easier to keep track of, I split some functions off into a handful of sub-issues:
Hi, should np.asanyarray be removed since it simply becomes np.asarray when there is no subclassing?
np.cast is not implemented in JAX. It is a dict of lambdas that allows casting to certain types.
@jameskirkpatrick I agree that np.cast exists in my copy of NumPy, but it is not a documented API as far as I can tell. Are we sure it's an intentional, non-deprecated NumPy API? @shoyer any thoughts?
Hello @hawkinsp @jakevdp
I am willing to take up np.insert and np.min_scalar_type.
Is anyone else working on them?
I don't know of anyone working on these functions, although I think min_scalar_type cannot be made compatible with JIT because of its semantics. Please let us know if you have questions!
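To see why, note that the dtype returned by np.min_scalar_type depends on the value of a scalar argument, while under jax.jit only shapes and dtypes (not values) are known when the function is traced. Plain NumPy shows the value dependence:

```python
import numpy as np

# The returned dtype changes with the *value* of the argument, not just its type.
print(np.min_scalar_type(10))    # uint8
print(np.min_scalar_type(1000))  # uint16
print(np.min_scalar_type(-1))    # int8
print(np.min_scalar_type(3.0))   # float16
```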
I updated the list to only include remaining functions that would be good candidates for implementation in JAX.
np.poly, np.polyfit and np.poly1d have been superseded by the newer numpy.polynomial API, as mentioned in https://numpy.org/doc/stable/reference/routines.polynomials.html. Shall we transition JAX's versions as well?
I'm not sure whether we'd want to try to replicate numpy's polynomial interface in JAX. The nice thing about np.poly* functions is that they basically just operate on arrays, which fits JAX's model more readily than does manipulation of custom polynomial objects.
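As an illustration of how array-based polynomial functions map onto JAX, here is a minimal Horner's-method polyval written with lax.scan over the coefficient array. The name horner_polyval is hypothetical; coefficients are ordered highest degree first, as in np.polyval.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import lax


def horner_polyval(p, x):
    # Hypothetical polyval: evaluates p[0]*x**(N-1) + ... + p[N-1] via Horner's rule.
    p = jnp.asarray(p, dtype=jnp.float32)
    x = jnp.asarray(x, dtype=jnp.float32)

    def step(acc, c):
        # One Horner step: multiply the running value by x, then add the next coefficient.
        return acc * x + c, None

    result, _ = lax.scan(step, jnp.zeros_like(x), p)
    return result


p = jnp.asarray([3.0, 0.0, 1.0])  # 3x^2 + 1
xs = np.array([0.0, 1.0, 2.0, 5.0])
np.testing.assert_allclose(horner_polyval(p, xs), np.polyval(np.asarray(p), xs))
print(jax.jit(horner_polyval)(p, 5.0))  # 76.0
```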
Hi @jakevdp @hawkinsp!
I am interested in implementing np.polydiv. Is someone else working on it?
There was a start on it in #7729, but that PR seems to have stalled. Feel free to work on it!
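A minimal sketch of polynomial long division that follows the same loop structure as NumPy's pure-Python np.polydiv. The name polydiv_sketch is hypothetical. Because the loop bounds depend only on the input lengths, it can be traced under jit; unlike np.polydiv, it does not strip leading zeros from the remainder, since that step depends on values.

```python
import jax.numpy as jnp
import numpy as np


def polydiv_sketch(u, v):
    # Hypothetical polynomial long division u / v (coefficients highest degree first).
    u = jnp.asarray(u, dtype=jnp.float32)
    v = jnp.asarray(v, dtype=jnp.float32)
    m = u.shape[0] - 1  # degree of the dividend
    n = v.shape[0] - 1  # degree of the divisor
    scale = 1.0 / v[0]
    q = jnp.zeros(max(m - n + 1, 1), dtype=u.dtype)
    r = u
    for k in range(m - n + 1):
        d = scale * r[k]                   # next quotient coefficient
        q = q.at[k].set(d)
        r = r.at[k:k + n + 1].add(-d * v)  # subtract d * v aligned at position k
    # NumPy trims leading (near-)zeros from r here; that step is value-dependent
    # and is intentionally left out of this sketch.
    return q, r


q, r = polydiv_sketch([1.0, 0.0, 1.0], [1.0, 1.0])  # (x^2 + 1) / (x + 1)
q_np, r_np = np.polydiv([1.0, 0.0, 1.0], [1.0, 1.0])
np.testing.assert_allclose(q, q_np)
np.testing.assert_allclose(r[-r_np.shape[0]:], r_np)
```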
@alexbw Hello, I want to contribute. Can you point me to some resources?
Hi @Gairick52 - I removed the "good first issue" label because this issue is nearly complete, and the remaining TODOs are pretty difficult (they'd likely involve updates to XLA).
As far as contributing to JAX, the best approach I think would be to find something that overlaps with your own areas of expertise. What kinds of things do you use JAX for?
Remaining functions to be implemented:
The list above was made by inspecting jnp._NOT_IMPLEMENTED and excluding deprecated functions (such as np.alen, np.ipmt, etc.), functions not relevant to JAX (such as np.setbufsize, np.ascontiguousarray, etc.), and functions that modify buffers in place (np.put, np.place, etc.):
Bugs for high-level categories: