thomas-rkk opened this issue 2 months ago
Hi, thanks for the report! The Rotation functionality has some implementation issues, and it is part of the package that we've identified (retroactively) as out of scope for JAX (see https://jax.readthedocs.io/en/latest/jep/18137-numpy-scipy-scope.html#scipy-spatial). At some point in the future it will probably be removed.
My hope is that ongoing efforts to make SciPy compatible with the Python array API standard will eventually let JAX users replace these tools with the SciPy rotation code itself, although that's not yet possible.
In the meantime, is this an issue that you can work around?
Hey, thanks for answering and sorry for taking so long to get back to you. We can definitely work around this issue.
It's funny that you mention the array API. From what I understand of the SciPy issue on the matter (https://github.com/scipy/scipy/issues/18286), they hope to "dispatch" this kind of operation to libraries such as JAX when SciPy's own implementation is written in C/C++/Cython/Fortran. That seems quite different from your vision.
I think you're misreading that issue: while it's true that some operations will be dispatched to other libraries, my reading is that this is limited to special functions that cannot be efficiently implemented in terms of the array API standard. Rotation does not fall into this category: it is an object-oriented API around operations that are easily expressible in terms of the standard, and SciPy is actively rewriting such APIs in terms of the array API. This is tracked in https://github.com/scipy/scipy/issues/18867.
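To illustrate the point that Rotation wraps operations expressible in the array API standard, here is a minimal sketch of quaternion composition written only with indexing, arithmetic and `stack`. It is not SciPy's or JAX's actual implementation: the function name `quat_multiply`, the `xp` namespace parameter and the scalar-last (x, y, z, w) layout (the default layout of SciPy's `Rotation.as_quat`) are assumptions for illustration.

```python
import numpy as np


def quat_multiply(p, q, xp=np):
    """Hamilton product p * q for scalar-last quaternions (x, y, z, w).

    Hypothetical helper. It uses only indexing, arithmetic and xp.stack,
    so passing xp=jax.numpy instead of numpy should work the same way.
    Leading batch dimensions broadcast.
    """
    px, py, pz, pw = p[..., 0], p[..., 1], p[..., 2], p[..., 3]
    qx, qy, qz, qw = q[..., 0], q[..., 1], q[..., 2], q[..., 3]
    return xp.stack(
        [
            pw * qx + px * qw + py * qz - pz * qy,
            pw * qy - px * qz + py * qw + pz * qx,
            pw * qz + px * qy - py * qx + pz * qw,
            pw * qw - px * qx - py * qy - pz * qz,
        ],
        axis=-1,
    )


# A 90-degree rotation about z, composed with itself.
q90z = np.array([0.0, 0.0, np.sqrt(0.5), np.sqrt(0.5)])
print(quat_multiply(q90z, q90z))  # approximately [0, 0, 1, 0]: 180 degrees about z
```

Because nothing here depends on a particular array library, this is the kind of code that can be shared across NumPy, JAX and other array-API backends rather than dispatched to them.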
Description
Currently, JAX does not correctly concatenate instances of jax.scipy.spatial.transform.Rotation when both are single rotations. Code to reproduce:
Expected output:
Current output:
System info (python version, jaxlib version, accelerator, etc.)
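The concatenation problem described above can be framed in terms of the quaternion arrays that rotations are commonly stored as. The sketch below assumes a single rotation is a shape-(4,) array and a batch is shape (N, 4); `concatenate_quats` is a hypothetical helper, not part of JAX or SciPy, and it is not a claim about what causes the reported behavior. It only shows the shape bookkeeping involved: promoting single rotations to batches of one before concatenating yields one row per rotation.

```python
import numpy as np


def concatenate_quats(quats, xp=np):
    """Concatenate single (4,) and batched (N, 4) quaternion arrays into (M, 4).

    Hypothetical helper for illustration; not part of JAX or SciPy.
    """
    # Promote each single rotation to a batch of one, so that concatenating
    # along axis 0 adds rotations instead of splicing quaternion components
    # into one flat vector.
    batches = [q[None, :] if q.ndim == 1 else q for q in quats]
    return xp.concatenate(batches, axis=0)


identity = np.array([0.0, 0.0, 0.0, 1.0])
rot_z90 = np.array([0.0, 0.0, np.sqrt(0.5), np.sqrt(0.5)])

print(np.concatenate([identity, rot_z90]).shape)     # (8,): components spliced together
print(concatenate_quats([identity, rot_z90]).shape)  # (2, 4): two rotations
```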