gderossi opened this issue 2 years ago
In general, if something is in the NumPy API and it's useful, we welcome PRs! I don't know whether or not those particular APIs will be ergonomic or useful given the various constraints of JAX, but the best way to find that out would be to write a prototype, as you are doing. Let us know!
I think the Chebyshev code is ready for a PR, but I'm running into some issues with the tests. Running pytest -n auto tests/ gets to about 70-75% before one of the worker nodes crashes and is replaced. After that I can't see progress numbers, but the tests keep running for a while longer before either stalling indefinitely or crashing. This happens both with my additions and with a clean copy of the main JAX branch. I'm working on an 8-year-old laptop, so the hardware might be the problem. Any suggestions?
I think the best thing you can do at this point is narrow it down to a single test that fails, and then share instructions to reproduce and I can take a look. Crashing is never supposed to happen, so that's a bug.
Running all of the tests individually worked fine for both the main branch and my modified branch, so I think the crashes and stalling were hardware issues on my end. I did get one failure in pjit_test.py on both versions of the library, though: it reported "Some donated buffers were not usable" and "Donation is not implemented for cpu". Could this be a bug? Since donation isn't supported on CPU, I would have expected this test to be skipped there, but perhaps not.
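For context, buffer donation is requested through the donate_argnums argument of jax.jit. The sketch below uses a hypothetical add_one function (not code from pjit_test.py) that donates its input and records any warnings raised during the call. Whether a "Some donated buffers were not usable" warning actually appears depends on the backend and the JAX version, so treat the printed warnings as illustrative rather than expected output.

```python
import warnings

import jax
import jax.numpy as jnp


def add_one(x):
    # Hypothetical example function; not taken from pjit_test.py.
    return x + 1.0


# donate_argnums=0 tells XLA it may reuse the input buffer for the output.
# A backend that cannot honor this falls back to a copy and may warn.
add_one_donating = jax.jit(add_one, donate_argnums=0)

x = jnp.arange(4.0)
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    y = add_one_donating(x)
    y.block_until_ready()

# x should not be reused after the call: where donation is honored,
# its buffer has been handed over to y.
print("backend:", jax.default_backend())
print("warnings:", [str(w.message) for w in caught])
print("result:", y)
```

The numerical result is the same either way; donation only affects whether the input's memory can be recycled.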
If you have a way for me to reproduce the crash, even using many tests, I'd like to try. JAX should never crash.
The buffer donation warning sounds like there's a test that should have been disabled but wasn't. Which test was it?
I've been unable to reproduce the crash, but the test that is failing is testLowerDonateArgnumsAvailable in pjit_test.py.
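As a rough illustration of what "disabled on CPU" could look like, here is a sketch using plain unittest rather than JAX's internal test utilities. The DonationLoweringTest class and the donation_supported helper are hypothetical names, not the real testLowerDonateArgnumsAvailable, and the assumption that CPU lacks donation support simply mirrors the error message quoted above; newer JAX releases may behave differently.

```python
import unittest

import jax
import jax.numpy as jnp


def donation_supported() -> bool:
    # Hypothetical helper. Assumption for illustration: treat CPU as lacking
    # donation support, matching "Donation is not implemented for cpu".
    return jax.default_backend() != "cpu"


class DonationLoweringTest(unittest.TestCase):
    # Hypothetical test case; not the real testLowerDonateArgnumsAvailable.
    @unittest.skipIf(not donation_supported(), "buffer donation not implemented on CPU")
    def test_lower_with_donated_argument(self):
        f = jax.jit(lambda x: x * 2.0, donate_argnums=0)
        lowered = f.lower(jnp.ones(3))
        # Lowering should succeed and produce non-empty IR text.
        self.assertTrue(lowered.as_text())
```

On a CPU-only machine the test is reported as skipped instead of failing; on GPU or TPU it runs normally.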
After the closure of #11903 I finally got around to implementing my own package for working with polynomials in JAX: https://github.com/f0uriest/orthax
It mostly replicates the API of numpy.polynomial, and I plan on extending beyond that.
I am working on a project that uses Chebyshev polynomials and might use other polynomial types in the future, and I would like to leverage JIT compilation and automatic differentiation through JAX. NumPy supports Chebyshev and other polynomial types through its newer numpy.polynomial package API, but this package has not been implemented in JAX. There was some brief discussion in #70 about whether it was even desirable to implement this interface in JAX, and I was wondering whether there were any updates or a conclusion to that discussion.
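To make the jit-plus-autodiff motivation concrete, here is a minimal sketch of evaluating a Chebyshev series in JAX with Clenshaw's recurrence. The chebval name mirrors numpy.polynomial.chebyshev.chebval, but this is an illustrative implementation, not the one in the poster's branch or in NumPy. It assumes a 1-D array of floating-point coefficients ordered lowest degree first, which is the NumPy convention.

```python
import jax
import jax.numpy as jnp


@jax.jit
def chebval(x, c):
    """Evaluate the Chebyshev series sum_k c[k] * T_k(x) via Clenshaw's recurrence.

    Sketch only: c is a 1-D array of floating-point coefficients, lowest degree
    first (the same ordering numpy.polynomial.chebyshev uses).
    """
    c = jnp.asarray(c)
    x = jnp.asarray(x, dtype=c.dtype)

    # Clenshaw: b_k = c_k + 2x*b_{k+1} - b_{k+2}, run from the highest degree
    # down to k = 1. The carry holds (b_{k+1}, b_{k+2}).
    def step(carry, ck):
        b1, b2 = carry
        return (ck + 2.0 * x * b1 - b2, b1), None

    zeros = jnp.zeros_like(x)
    (b1, b2), _ = jax.lax.scan(step, (zeros, zeros), c[:0:-1])
    # The final step uses x rather than 2x because T_1(x) = x * T_0(x).
    return c[0] + x * b1 - b2


coeffs = jnp.array([1.0, 2.0, 3.0])  # 1*T0 + 2*T1 + 3*T2
value = chebval(0.5, coeffs)
slope = jax.grad(chebval)(0.5, coeffs)  # d/dx, differentiated through the jitted function
print(value, slope)
```

Using lax.scan keeps the loop inside a single compiled computation, and because the coefficient count is a static shape, the same function also works on arrays of x values.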
I have already started to implement the Chebyshev convenience class and convenience functions (chebadd, chebmul, etc.) in a branch for internal use, but merging them into the main branch would be great for long-term support and ease of access.
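As an illustration of what jittable versions of two of the convenience functions mentioned above might look like, here is a minimal sketch of chebadd and chebmul in jax.numpy. These are not the functions from the poster's branch, and chebmul here uses the product identity T_i * T_j = (T_{i+j} + T_{|i-j|}) / 2 rather than whatever NumPy does internally. Unlike NumPy, this sketch does not trim trailing zero coefficients, since trimming would make output shapes data-dependent.

```python
import jax
import jax.numpy as jnp


def chebadd(c1, c2):
    """Add two Chebyshev series given as coefficient arrays (lowest degree first)."""
    c1, c2 = jnp.asarray(c1), jnp.asarray(c2)
    n = max(c1.shape[0], c2.shape[0])
    # Zero-pad the shorter series so the coefficients line up by degree.
    return jnp.pad(c1, (0, n - c1.shape[0])) + jnp.pad(c2, (0, n - c2.shape[0]))


def chebmul(c1, c2):
    """Multiply two Chebyshev series using T_i * T_j = (T_{i+j} + T_{|i-j|}) / 2."""
    c1, c2 = jnp.asarray(c1), jnp.asarray(c2)
    m, n = c1.shape[0], c2.shape[0]
    i = jnp.arange(m)[:, None]
    j = jnp.arange(n)[None, :]
    # Each coefficient pair contributes half its product to T_{i+j} and half to T_{|i-j|}.
    half = 0.5 * c1[:, None] * c2[None, :]
    out = jnp.zeros(m + n - 1, dtype=half.dtype)
    # .at[...].add accumulates contributions that land on the same degree.
    out = out.at[(i + j).ravel()].add(half.ravel())
    out = out.at[jnp.abs(i - j).ravel()].add(half.ravel())
    return out


# Output shapes depend only on the input lengths, so both functions can be jitted.
chebmul_jit = jax.jit(chebmul)
product = chebmul_jit(jnp.array([1.0, 2.0, 3.0]), jnp.array([0.0, 1.0]))
total = chebadd(jnp.array([1.0, 2.0, 3.0]), jnp.array([0.0, 1.0]))
print(product, total)
```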
Thanks,