jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Feature request: Add support for Chebyshev and other polynomials #11055

Open · gderossi opened this issue 2 years ago

gderossi commented 2 years ago

I am working on a project that uses Chebyshev polynomials and might use other types of polynomials in the future, and I would like to leverage jit and automatic differentiation through JAX. NumPy supports Chebyshev and other polynomial types through the new numpy.polynomial package API, but this package has not been implemented in JAX. There was some brief discussion in #70 about whether it was even desirable to implement this interface in JAX, and I was wondering if there were any updates or a conclusion to that discussion.
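For illustration, here is a minimal sketch (not part of JAX, and not the code from the branch mentioned in this issue) of evaluating a Chebyshev series with Clenshaw's recurrence in jax.numpy, so that it can be wrapped in jax.jit and differentiated with jax.grad. The name chebval_jax is hypothetical; it is meant to mirror numpy.polynomial.chebyshev.chebval for 1-D coefficient arrays only.

```python
import jax
import jax.numpy as jnp
import numpy as np
from numpy.polynomial import chebyshev as C


def chebval_jax(x, c):
    """Evaluate sum_k c[k] * T_k(x) via Clenshaw's recurrence (hypothetical helper)."""
    x = jnp.asarray(x, dtype=float)
    c = jnp.asarray(c, dtype=float)

    def step(carry, ck):
        b1, b2 = carry  # b_{k+1}, b_{k+2}
        bk = ck + 2 * x * b1 - b2
        return (bk, b1), None

    zero = jnp.zeros_like(x)
    # Run the recurrence from the highest coefficient down to c[1].
    (b1, b2), _ = jax.lax.scan(step, (zero, zero), c[1:][::-1])
    return c[0] + x * b1 - b2


coeffs = jnp.array([1.0, 2.0, 3.0])  # 1*T0 + 2*T1 + 3*T2 = 6x^2 + 2x - 2
f = jax.jit(chebval_jax)
df = jax.jit(jax.grad(chebval_jax))  # derivative with respect to x

print(f(0.5, coeffs), C.chebval(0.5, [1.0, 2.0, 3.0]))  # both 0.5
print(df(0.5, coeffs))  # d/dx (6x^2 + 2x - 2) at x = 0.5 is 8.0
```

Because the loop is expressed with jax.lax.scan over a coefficient array of static length, the same function also accepts a 1-D array of x values and traces to a single compiled loop.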

I have already started to implement the Chebyshev convenience class and convenience functions (chebadd, chebmul, etc.) in a branch for internal use, but merging them into the main branch would be great for long-term support and ease of access.
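As a rough illustration of what such convenience functions could look like in jax.numpy, here is a sketch of chebadd and chebmul for 1-D coefficient arrays. These are hypothetical stand-ins, not the code from that branch. Multiplication uses the identity T_m(x) T_n(x) = (T_{m+n}(x) + T_{|m-n|}(x)) / 2 together with a scatter-add. One deliberate difference from NumPy: the results are not trimmed of trailing zeros, since jit needs output shapes that do not depend on coefficient values.

```python
import jax
import jax.numpy as jnp
import numpy as np
from numpy.polynomial import chebyshev as C


def chebadd(c1, c2):
    """Add two Chebyshev series (hypothetical helper; no trailing-zero trimming)."""
    c1 = jnp.asarray(c1, dtype=float)
    c2 = jnp.asarray(c2, dtype=float)
    n = max(c1.shape[0], c2.shape[0])  # shapes are static Python ints under jit
    return jnp.pad(c1, (0, n - c1.shape[0])) + jnp.pad(c2, (0, n - c2.shape[0]))


def chebmul(c1, c2):
    """Multiply two Chebyshev series using T_m*T_n = (T_{m+n} + T_{|m-n|}) / 2."""
    c1 = jnp.asarray(c1, dtype=float)
    c2 = jnp.asarray(c2, dtype=float)
    m, n = c1.shape[0], c2.shape[0]
    i = jnp.arange(m)[:, None]  # degree index of the first factor
    j = jnp.arange(n)[None, :]  # degree index of the second factor
    half = (0.5 * jnp.outer(c1, c2)).ravel()
    out = jnp.zeros(m + n - 1, dtype=half.dtype)
    # Scatter-add accumulates contributions that land on the same degree.
    out = out.at[(i + j).ravel()].add(half)
    out = out.at[jnp.abs(i - j).ravel()].add(half)
    return out


a_np = np.array([1.0, 2.0, 3.0])
b_np = np.array([3.0, 2.0, 1.0])
a, b = jnp.asarray(a_np), jnp.asarray(b_np)

print(jax.jit(chebadd)(a, jnp.array([1.0, 2.0])))  # expected coefficients: 2, 4, 3
print(jax.jit(chebmul)(a, b))  # expected coefficients: 6.5, 12, 12, 4, 1.5
print(C.chebmul(a_np, b_np))   # NumPy reference
```

Since m and n come from array shapes, the output length m + n - 1 is known at trace time, which is what lets both functions be jitted.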

Thanks,

hawkinsp commented 2 years ago

In general, if something is in the NumPy API and it's useful, we welcome PRs! I don't know whether or not those particular APIs will be ergonomic or useful given the various constraints of JAX, but the best way to find that out would be to write a prototype, as you are doing. Let us know!
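One constraint that plausibly matters here: numpy.polynomial routines trim trailing zero coefficients, so their output length depends on the values, while jax.jit requires output shapes to be known at trace time. A sketch of a jit-compatible alternative keeps the padded array as-is and reports the effective degree as a traced value; chebdegree is a hypothetical name, not a NumPy or JAX function.

```python
import jax
import jax.numpy as jnp


@jax.jit
def chebdegree(c):
    """Index of the last nonzero coefficient (0 for an all-zero series).

    Hypothetical helper: instead of trimming trailing zeros, which would give
    a value-dependent output shape, return the effective degree as a traced
    scalar and keep the coefficient array at its original, static length.
    """
    idx = jnp.arange(c.shape[0])
    return jnp.max(jnp.where(c != 0, idx, 0))


c = jnp.array([1.0, 2.0, 0.0, 0.0])
print(chebdegree(c))  # 1
```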

gderossi commented 2 years ago

I think the Chebyshev code is ready for a PR, but I'm running into some issues with the tests. Running pytest -n auto tests/ gets to about 70-75% before one of the worker nodes crashes and is replaced. I can't see progress numbers after that, but the tests keep running for a while longer before either stalling indefinitely or crashing. This happens both with my changes and with a clean checkout of the JAX main branch. I'm working on an 8-year-old laptop, so the hardware might be the problem. Any suggestions?

hawkinsp commented 2 years ago

I think the best thing you can do at this point is narrow it down to a single test that fails, and then share instructions to reproduce and I can take a look. Crashing is never supposed to happen, so that's a bug.

gderossi commented 2 years ago

Running each test individually worked fine for both the main branch and my modified branch, so I think the crashes and stalling were hardware issues on my end. I did get one failure in pjit_test.py with both versions of the library, though: it reported "Some donated buffers were not usable" and "Donation is not implemented for cpu". Could this be a bug? I would have expected this test to be skipped on CPU if donation isn't supported there, but perhaps not.

hawkinsp commented 2 years ago

If you have a way for me to reproduce the crash, even using many tests, I'd like to try. JAX should never crash.

The buffer donation warning sounds like there's a test that should have been disabled but wasn't. Which test was it?

gderossi commented 2 years ago

I've been unable to reproduce the crash, but the test that is failing is testLowerDonateArgnumsAvailable in pjit_test.py.
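If a test depends on buffer donation, one way to keep it from failing on a backend without donation support is to skip it explicitly. The sketch below is a standalone unittest case, not JAX's actual testLowerDonateArgnumsAvailable; it assumes, based on the messages reported above, that donation is unavailable on CPU for the JAX version under test (this may differ between versions), and it uses jax.default_backend() to decide whether to skip. The donation itself is requested through jax.jit's donate_argnums parameter.

```python
import unittest

import jax
import jax.numpy as jnp


class DonationSketchTest(unittest.TestCase):
    def test_donate_argnums(self):
        # Assumption: donation is treated as unsupported on CPU, matching the
        # "Donation is not implemented for cpu" message reported in this thread.
        if jax.default_backend() == "cpu":
            self.skipTest("buffer donation assumed unsupported on CPU")
        # donate_argnums=0 allows XLA to reuse x's buffer for the output.
        f = jax.jit(lambda x: x + 1.0, donate_argnums=0)
        x = jnp.arange(4.0)
        y = f(x)
        self.assertTrue(bool(jnp.allclose(y, jnp.array([1.0, 2.0, 3.0, 4.0]))))
```

The case can be run with python -m unittest or pytest; on a CPU-only machine it is reported as skipped rather than failed.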

f0uriest commented 8 months ago

After #11903 was closed, I finally got around to implementing my own package for working with polynomials in JAX: https://github.com/f0uriest/orthax

It mostly replicates the API of numpy.polynomial, and I plan on extending it beyond that.