google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Unimplemented NumPy core functions #70

Closed alexbw closed 1 year ago

alexbw commented 5 years ago

Remaining functions to be implemented:

The list above was made by inspecting jnp._NOT_IMPLEMENTED and excluding deprecated functions (such as np.alen, np.ipmt, etc.), functions not relevant to JAX (such as np.setbufsize, np.ascontiguousarray, etc.), and functions that modify buffers in-place (np.put, np.place, etc.).

Bugs for high-level categories:

mattjj commented 5 years ago

The list needs a bit of refinement, since some things were already implemented and other things are out of scope, such as functions for loading, saving, and string manipulation. I deleted a few of the out-of-scope ones, and marked as done some things that were already included and some things that are coming in #96.

rainwoodman commented 5 years ago

Is there a guide on how to write custom primitive functions (or is that even possible)? I am looking for the equivalent of autograd's defvjp.

A quick glance into the code suggests everything is delegated to XLA -- if an operator is nontrivial then it has to be first implemented in XLA?

mattjj commented 5 years ago

Great question. It's possible, and JAX's internals are actually very similar to Autograd's in this respect, but the API is a bit different. We need to write up how to do this, and maybe add a convenience layer to the API.

There are a few different use cases that we'd want to cover. One is that you just want to define a custom VJP for a function that is otherwise implemented in terms of NumPy code, like in the Autograd tutorial section. But another use case is to define a custom primitive and VJP for some external routine, like a Cython or Fortran function that isn't implemented in terms of NumPy. If you have an intended use case, does it fit into one of those categories, or is it another?

Let's track this in #116. If you have a simplified example of what you want to do, post it there and we'll use it to help guide our convenience API and/or explanation.
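For readers arriving later, here is a minimal sketch of the first use case (a custom VJP for a function otherwise written in NumPy-style code). It assumes a JAX version that provides `jax.custom_vjp`, which postdates this comment; the function `log1pexp` and its helper names are illustrative choices, not something from this thread.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def log1pexp(x):
    # Forward computation written in ordinary jax.numpy code.
    return jnp.log1p(jnp.exp(x))


def log1pexp_fwd(x):
    # Return the primal output plus residuals saved for the backward pass.
    return log1pexp(x), x


def log1pexp_bwd(x, g):
    # x is the saved residual, g is the incoming cotangent.
    # d/dx log(1 + exp(x)) = sigmoid(x); return one cotangent per input.
    return (g * jax.nn.sigmoid(x),)


# Register the hand-written forward and backward rules.
log1pexp.defvjp(log1pexp_fwd, log1pexp_bwd)

grad_at_zero = jax.grad(log1pexp)(0.0)
```

The second use case (wrapping an external routine as a new primitive) needs more machinery than this and is not covered by the sketch.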

kovasb commented 5 years ago

Looking at contributing a function or two from this list.

What is the approach to functions with non-statically-sized return values? For example, setxor1d.

My understanding is that XLA needs to compile for each specific shape. Is there a pattern for avoiding worst-case behavior, or should functions like setxor1d be removed from this list?

souravsingh commented 5 years ago

@mattjj I am looking to take a stab at the issue. How do I start?

hawkinsp commented 5 years ago

@souravsingh Awesome! In general, lots of these are fairly easy to do. Pick one you like off the list, and take a look at previous PRs implementing numpy ops. For example, I just sent out https://github.com/google/jax/pull/298 adding np.cumprod/cumsum support. Usually, it's just a question of implementing a numpy op in terms of the primitives in the lax.py library.
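As a rough illustration of what "implementing a numpy op in terms of lax primitives" can look like, here is a sketch of a one-dimensional first difference. The name `diff_1d` is hypothetical and this is not the implementation that ended up in JAX; it only assumes `lax.slice_in_dim` and `lax.sub`.

```python
import jax
import jax.numpy as jnp
from jax import lax


def diff_1d(x):
    # Hypothetical helper: first difference of a 1-D array, like np.diff(x).
    x = jnp.asarray(x)
    n = x.shape[0]  # static under jit, so the slice bounds are static too
    upper = lax.slice_in_dim(x, 1, n, axis=0)      # x[1:]
    lower = lax.slice_in_dim(x, 0, n - 1, axis=0)  # x[:-1]
    return lax.sub(upper, lower)


squares = jnp.array([1.0, 4.0, 9.0, 16.0])
steps = jax.jit(diff_1d)(squares)
```

Because the output shape depends only on the input shape, this kind of function compiles once per input shape.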

navneet-nmk commented 5 years ago

@mattjj I am looking to implement np.unique. Is this something which is a priority at the moment?

mattjj commented 5 years ago

Sounds like a great idea! The core team isn't working on any numpy operations at the moment (we're mostly focused on improving performance in the XLA runtime, parallel computation, and Cloud TPU support) so that's a great place to make contributions.

np.unique is tricky though because AIUI the shape of the output depends on the values of the input. Since XLA can only express computations with static shapes (meaning independent of the input values), we'd have to re-compile the XLA computation for every new value of the input, and moreover it might be tricky to express in terms of XLA HLO (i.e. in terms of lax library calls).

@hawkinsp any thoughts on implementing np.unique?
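One way around the static-shape constraint is to have the caller fix the output size up front and pad the result. The sketch below assumes a JAX version in which `jnp.unique` accepts the `size` and `fill_value` keyword arguments (added well after this comment); the wrapper name `unique_padded` and the size of 4 are illustrative.

```python
import jax
import jax.numpy as jnp


@jax.jit
def unique_padded(x):
    # The output length is fixed at 4 regardless of the values in x,
    # so XLA sees a static shape. Unused slots are filled with -1.
    return jnp.unique(x, size=4, fill_value=-1)


result = unique_padded(jnp.array([3, 1, 3, 2, 1]))
```

The trade-off is that the caller has to know an upper bound on the number of unique values; inputs with more unique values than `size` are truncated.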

navneet-nmk commented 5 years ago

@mattjj Since this would be my first attempt at this, I should probably look at NumPy functions that are less tricky than unique?

Any suggestions would be really helpful.

navneet-nmk commented 5 years ago

@mattjj Apologies for the constant posts. I have decided to implement np.diff, which seems considerably easier. I was just wondering whether there is a guide to developing the test cases in the lax_numpy_test.py file. Thanks!
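The general idea behind such tests is to run the JAX function and the NumPy reference on the same random inputs and compare the results. The sketch below shows that pattern only; the helper `check_against_numpy` is hypothetical and is not the harness used in lax_numpy_test.py.

```python
import numpy as np
import jax.numpy as jnp


def check_against_numpy(jax_fn, numpy_fn, shapes, seed=0, rtol=1e-5):
    # Hypothetical helper: compare a JAX op with its NumPy reference
    # on random float32 inputs of several shapes.
    rng = np.random.RandomState(seed)
    for shape in shapes:
        x = rng.randn(*shape).astype(np.float32)
        expected = numpy_fn(x)
        actual = np.asarray(jax_fn(jnp.asarray(x)))
        assert actual.shape == expected.shape
        np.testing.assert_allclose(actual, expected, rtol=rtol, atol=1e-6)
    return len(shapes)


num_checked = check_against_numpy(jnp.diff, np.diff, [(5,), (3, 4), (2, 3, 4)])
```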

murphyk commented 5 years ago

+1 to adding np.cov

murphyk commented 5 years ago

+1 to np.percentile

mattjj commented 5 years ago

@murphyk np.cov coming in #983

hawkinsp commented 5 years ago

@murphyk quantile and percentile are now present, at least for interpolation='linear' and floating-point types.

adarob commented 5 years ago

fyi, np.rank is deprecated in favor of np.ndim

jessebett commented 4 years ago

For anyone interested in np.convolve: I've commented on possibly using lax.conv. It would require reshaping and flipping the inputs. I've described it in #1561.
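A sketch of the reshape-and-flip idea, for the 'full' mode of np.convolve only. It uses `lax.conv_general_dilated` (rather than `lax.conv`) so that explicit padding can be passed; the function name `convolve_full` is hypothetical, and this is not necessarily what #1561 or the eventual implementation does.

```python
import numpy as np
import jax.numpy as jnp
from jax import lax


def convolve_full(x, k):
    # Hypothetical helper: 1-D convolution matching np.convolve(x, k, mode="full").
    x = jnp.asarray(x, dtype=jnp.float32)
    k = jnp.asarray(k, dtype=jnp.float32)
    lhs = x[None, None, :]        # (batch=1, channels=1, width)
    rhs = k[::-1][None, None, :]  # flip: the lax convolution is a cross-correlation
    pad = k.shape[0] - 1          # zero-pad both ends to get the "full" output
    out = lax.conv_general_dilated(
        lhs, rhs, window_strides=(1,), padding=[(pad, pad)]
    )
    return out[0, 0]


signal = np.array([1.0, 2.0, 3.0], dtype=np.float32)
kernel = np.array([0.0, 1.0, 0.5], dtype=np.float32)
result = convolve_full(signal, kernel)
```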

michiboo commented 4 years ago

@jessebett I just created a PR for it: #1831

fedden commented 4 years ago

I'd also like to get involved in this process of completing the remaining numpy ops. As a first pass, I am looking at implementing something really simple, to dip my toes in the water so I can be more familiar with the codebase before tackling something more serious!

I will look at np.alen unless there are any issues with that :)

fedden commented 4 years ago

Created a PR for alen and isscalar in #1924

shoyer commented 4 years ago

I took another pass at deleting irrelevant functions:

I think most of the rest should be relevant, though some will be tricky (e.g., those with data-dependent output shapes).

StephenHogg commented 4 years ago

Hi - I'm interested in getting nanmean done. Am I right in suggesting that this could be done by somehow piggybacking off nansum? If so, are there any tips that may be worth heeding in this regard?

shoyer commented 4 years ago

NumPy should be a pretty good model for most NaN functions; all of its NaN functions are written in pure Python.
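To make the nansum idea concrete, here is a minimal sketch: sum the non-NaN entries and divide by how many there are. The name `nanmean_sketch` is hypothetical, and this ignores details (dtype handling, keepdims, warnings) that a real implementation would need to match NumPy.

```python
import jax.numpy as jnp


def nanmean_sketch(x, axis=None):
    # Hypothetical helper: mean over non-NaN entries, built on nansum.
    x = jnp.asarray(x)
    total = jnp.nansum(x, axis=axis)             # NaNs are treated as zero
    count = jnp.sum(~jnp.isnan(x), axis=axis)    # number of non-NaN entries
    return total / count                         # 0 / 0 gives NaN for all-NaN input


data = jnp.array([[1.0, jnp.nan, 3.0],
                  [4.0, 5.0, jnp.nan]])
row_means = nanmean_sketch(data, axis=1)
```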

shoyer commented 4 years ago

To make things a little easier to keep track of, I split some functions off into a handful of sub-issues:

oliverastrand commented 4 years ago

Hi, should np.asanyarray be removed since it simply becomes np.asarray when there is no subclassing?

jameskirkpatrick commented 4 years ago

np.cast is not implemented in JAX. This is a dict of lambdas that allows casting to certain types.

hawkinsp commented 4 years ago

@jameskirkpatrick I agree that np.cast exists in my copy of NumPy, but it is not a documented API as far as I can tell. Are we sure it's an intentional, non-deprecated NumPy API? @shoyer any thoughts?

bhushan-borole commented 3 years ago

Hello @hawkinsp @jakevdp, I am willing to take up np.insert and np.min_scalar_type. Is someone else working on them?

jakevdp commented 3 years ago

I don't know of anyone working on these functions, although I think min_scalar_type cannot be made compatible with JIT because of its semantics. Please let us know if you have questions!
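To illustrate the semantics issue with plain NumPy: the dtype returned by min_scalar_type depends on the value of its argument, not just on its type.

```python
import numpy as np

# The result dtype is chosen from the *value* of the scalar.
small = np.min_scalar_type(10)       # fits in an unsigned 8-bit integer
negative = np.min_scalar_type(-260)  # needs a signed 16-bit integer

print(small, negative)  # uint8 int16
```

Under jit, arguments are traced abstractly without concrete values, so there is no value available from which to pick an output dtype at trace time.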

jakevdp commented 3 years ago

I updated the list to only include remaining functions that would be good candidates for implementation in JAX.

avani17101 commented 3 years ago

np.poly, np.polyfit and np.poly1d have transitioned to the newer numpy.polynomial API, as mentioned in https://numpy.org/doc/stable/reference/routines.polynomials.html. Shall we transition them as well?

jakevdp commented 3 years ago

I'm not sure whether we'd want to try to replicate numpy's polynomial interface in JAX. The nice thing about np.poly* functions is that they basically just operate on arrays, which fits JAX's model more readily than does manipulation of custom polynomial objects.
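As a small sketch of why the array-based functions fit well: a polynomial is just a coefficient array, so evaluating it composes directly with jit and grad. The wrapper name `eval_poly` and the example polynomial are illustrative.

```python
import jax
import jax.numpy as jnp

# Coefficients of x**2 - 3*x + 2, highest degree first (np.polyval convention).
coeffs = jnp.array([1.0, -3.0, 2.0])


def eval_poly(p, x):
    # Plain array in, array out: no custom polynomial object involved.
    return jnp.polyval(p, x)


value = jax.jit(eval_poly)(coeffs, 3.0)             # polynomial evaluated at x = 3
slope = jax.grad(eval_poly, argnums=1)(coeffs, 3.0)  # derivative with respect to x
```

Here the derivative with respect to x is 2*x - 3, so it can be checked by hand.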

ntlm1686 commented 2 years ago

Hi @jakevdp @hawkinsp! I am interested in implementing np.polydiv. Is someone else working on it?

jakevdp commented 2 years ago

There was a start on it in #7729, but that PR seems to have stalled. Feel free to work on it!

Gairick52 commented 1 year ago

@alexbw Hello, I want to contribute. Can you point me to some resources?

jakevdp commented 1 year ago

Hi @Gairick52 - I removed the "good first issue" label because this issue is nearly complete, and the remaining TODOs are pretty difficult (they'd likely involve updates to XLA).

As far as contributing to JAX, the best approach I think would be to find something that overlaps with your own areas of expertise. What kinds of things do you use JAX for?

alexbw commented 1 year ago

https://youtu.be/xP8tFFJrtXU