oliverdutton closed this issue 1 year ago.
There seems to be an expectation on the JAX Python side about buffer result types that we don't meet here. I'll also ask Matt what is expected.
Thanks for raising this. I think there should be an easy fix; I can take it.
(But the JAX bits being used here are super unpolished and feature coverage is still minimal!)
I think https://github.com/google/jax/pull/14986 should fix this. To be honest, though, I've paged out a lot of context on this! So if something else breaks, just let us know.
Perfect, thank you
By the way, it's near the top of my to-do list to update dynamic shapes to be compatible with both `JAX_JIT_PJIT_API_MERGE=1` and `JAX_ARRAY=1`. It shouldn't be "hard", but it's nontrivial because it'll require a big context switch and some time. It never rises to "urgent" like other things because we don't have any dynamic shapes users (or at least I thought we didn't until this issue was opened!).
I completely understand; dynamic shapes are very experimental. I've been poking around.
I was looking at use cases such as popping foldcomp on GPU by combining it with nerfax, but compilation times keep it from being competitive because there are slightly different-length arrays everywhere.
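To make the compile-time problem concrete, here is a minimal sketch (not code from foldcomp, nerfax, or this thread): `jax.jit` traces and compiles once per distinct input shape, so many slightly different lengths mean many compiles. Padding inputs up to a few fixed bucket sizes and carrying a mask is one common workaround; it is a swapped-in technique, not something proposed above. The names `bucket_size`, `pad_to_bucket`, `centroid_exact`, `centroid_masked` and the bucket sizes are hypothetical, chosen for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Counts how many times each jitted function is traced (i.e. compiled).
traces = {"exact": 0, "padded": 0}

def bucket_size(n, buckets=(64, 128, 256, 512)):
    # Hypothetical bucket sizes: round n up to the smallest bucket that fits.
    for b in buckets:
        if n <= b:
            return b
    raise ValueError(f"length {n} exceeds largest bucket")

def pad_to_bucket(x):
    # Zero-pad along axis 0 and return a boolean mask of the real rows.
    n = x.shape[0]
    b = bucket_size(n)
    padded = np.zeros((b,) + x.shape[1:], dtype=x.dtype)
    padded[:n] = x
    mask = np.arange(b) < n
    return padded, mask

@jax.jit
def centroid_exact(coords):
    traces["exact"] += 1  # Python side effect: runs only while tracing
    return jnp.mean(coords, axis=0)

@jax.jit
def centroid_masked(coords, mask):
    traces["padded"] += 1  # Python side effect: runs only while tracing
    w = mask.astype(coords.dtype)[:, None]
    return jnp.sum(coords * w, axis=0) / jnp.sum(w)

# Five slightly different lengths, all of which fit in the 128 bucket.
lengths = [100, 101, 102, 103, 120]
rng = np.random.default_rng(0)
results = []
for n in lengths:
    coords = rng.random((n, 3), dtype=np.float32)
    exact = centroid_exact(coords)
    padded, mask = pad_to_bucket(coords)
    masked = centroid_masked(padded, mask)
    results.append((np.asarray(exact), np.asarray(masked)))

print(traces)  # {'exact': 5, 'padded': 1}
```

The unpadded version is compiled once per length, while the padded version is compiled once per bucket; the trade-off is some wasted work on padding rows and the need to mask every reduction.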
And maybe jax-md could benefit from it in the neighbor list update.
But these are primarily interest projects; hopefully I'll find a meaty business application eventually.
Thanks for all the magic of JAX!
Thanks for the explanation, and the kind words!
> I was looking at use cases popping foldcomp on GPU by combining with nerfax but compilation times kills it from being competitive due to slightly different length arrays everywhere.
Wow, this is very interesting. Any chance you could share some representative programs or toy examples, showing what you want to do, or where the compile times are killing you? Maybe we can help!
(This would also be interesting on the IREE side: we have a WIP pass to dedupe some kernels into dynamic-dim variants to reduce compilation times. We've been focused a bit more on the AOT case, but still.)
Cool, in the next few days I'll open a separate discussion and tag you in it, with a clear set of the code I'm working on compiling.
Looks like the merge fixes those tests, so I'm closing the issue.
I am trying to use the (very convenient) option of IREE as a JAX backend. When I run the tests, they seem to fail whenever the output shape is dynamic. I'm guessing this test actually works and I'm missing something. The issue is presumably JAX forcibly turning the result into a JAX array.
Below runs just one of the tests. What am I doing wrong?
The following reproduces the issue in Colab (with jax v0.4.4):