jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Rotation.concatenate does not work for two single rotations #23202

Status: Open. Opened by thomas-rkk 2 months ago.

thomas-rkk commented 2 months ago

Description

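Currently, JAX does not concatenate instances of jax.scipy.spatial.transform.Rotation correctly when both are single rotations. The two quaternions are joined into one flat 8-element array instead of a (2, 4) stack, and as_rotvec() then returns only the first rotation's vector. Code to reproduce: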

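```python
import jax.numpy as jnp
from jax.scipy.spatial.transform import Rotation as jRotation

# Scalar-last quaternions (x, y, z, w)
q1 = jnp.array([0.0, 0.0, 1.0, 0.0])  # 180 degrees about z
q2 = jnp.array([0.0, 0.0, 0.0, 1.0])  # identity

r1 = jRotation.from_quat(q1)  # single rotation
r2 = jRotation.from_quat(q2)  # single rotation

r3 = jRotation.concatenate([r1, r2])  # expected: a stack of two rotations

print(r3.as_quat())
print(r3.as_rotvec())
```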

Expected output:

```
[[0. 0. 1. 0.]
 [0. 0. 0. 1.]]
[[0.        0.        3.1415927]
 [0.        0.        0.       ]]
```

Current output:

```
[0. 0. 1. 0. 0. 0. 0. 1.]
[0.        0.        3.1415927]
```
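The flat output matches what you get by joining two shape-(4,) arrays along their only axis. Assuming `Rotation.concatenate` joins the underlying quaternion arrays with `jnp.concatenate` (inferred from the output above, not verified here against the JAX source), the shape arithmetic looks like this:

```python
import jax.numpy as jnp

q1 = jnp.array([0.0, 0.0, 1.0, 0.0])  # single rotation, shape (4,)
q2 = jnp.array([0.0, 0.0, 0.0, 1.0])  # single rotation, shape (4,)

# Joining two 1-D arrays along axis 0 gives one flat vector of shape (8,)
flat = jnp.concatenate([q1, q2])

# Promoting each quaternion to shape (1, 4) first gives two rows, shape (2, 4)
stacked = jnp.concatenate([jnp.atleast_2d(q1), jnp.atleast_2d(q2)])

print(flat.shape)     # (8,)
print(stacked.shape)  # (2, 4)
```

Promoting each quaternion with `jnp.atleast_2d` before joining is one way to get the expected `(2, 4)` layout.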

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.31
jaxlib: 0.4.31
numpy:  2.1.0
python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='development-1', release='6.5.0-35-generic', version='#35~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Tue May  7 09:00:52 UTC 2', machine='x86_64')

$ nvidia-smi
Thu Aug 22 23:42:23 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 555.42.02              Driver Version: 555.42.02      CUDA Version: 12.5     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 3060        Off |   00000000:01:00.0 Off |                  N/A |
|  0%   44C    P8             15W /  170W |    1012MiB /  12288MiB |     29%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A      1567      G   /usr/lib/xorg/Xorg                            667MiB |
|    0   N/A  N/A      2169      G   cinnamon                                       47MiB |
|    0   N/A  N/A      3937      G   /usr/lib/firefox/firefox                        0MiB |
|    0   N/A  N/A     21524      G   ...yOnDemand --variations-seed-version         91MiB |
|    0   N/A  N/A     75918      G   ...erProcess --variations-seed-version        134MiB |
|    0   N/A  N/A   1284178      G   ...96,262144 --variations-seed-version         19MiB |
+-----------------------------------------------------------------------------------------+
jakevdp commented 2 months ago

Hi, thanks for the report! The Rotation functionality has some implementation issues, and it is part of the package that we've (retroactively) identified as out of scope for JAX (see https://jax.readthedocs.io/en/latest/jep/18137-numpy-scipy-scope.html#scipy-spatial). At some point in the future it will probably be removed.

My hope is that ongoing efforts to make scipy compatible with the Python array API will eventually let JAX users call scipy's rotation code directly in place of these tools, although that's not yet possible.

In the meantime, is this an issue that you can work around?

thomas-rkk commented 1 month ago

Hey, thanks for answering and sorry for taking so long to get back to you. We can definitely work around this issue.
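A minimal sketch of one such workaround, assuming it is acceptable to rebuild the rotation from its quaternions: promote each quaternion array to at least 2-D, join the rows, and call `from_quat` on the result. The helper name `concatenate_rotations` is hypothetical and not part of JAX.

```python
import jax.numpy as jnp
from jax.scipy.spatial.transform import Rotation as jRotation


def concatenate_rotations(rotations):
    """Hypothetical helper: concatenate single and/or stacked rotations.

    A single rotation's quaternion (shape (4,)) is promoted to one row
    (shape (1, 4)) before all rows are joined along axis 0.
    """
    quats = [jnp.atleast_2d(r.as_quat()) for r in rotations]
    return jRotation.from_quat(jnp.concatenate(quats, axis=0))


r1 = jRotation.from_quat(jnp.array([0.0, 0.0, 1.0, 0.0]))  # 180 degrees about z
r2 = jRotation.from_quat(jnp.array([0.0, 0.0, 0.0, 1.0]))  # identity

r3 = concatenate_rotations([r1, r2])
print(r3.as_quat())    # two rows, shape (2, 4)
print(r3.as_rotvec())  # two rows, shape (2, 3)
```

Because `jnp.atleast_2d` leaves an already-stacked `(N, 4)` array unchanged, the same helper should also accept a mix of single and stacked rotations.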

It is funny that you mention the array API. From what I understand of the scipy issue on the matter (https://github.com/scipy/scipy/issues/18286), they hope to "dispatch" this kind of operation to libraries such as JAX when the SciPy implementation is written in C/C++/Cython/Fortran. This seems quite different from your vision.

jakevdp commented 1 month ago


I think you're misreading that issue: while it's true that some operations will be dispatched to other libraries, my read is that this is limited to special functions which cannot be efficiently implemented in terms of the array API standard. Rotation does not fall into this category: it is an object-oriented API around operations that are easily expressible in terms of the API standard, and scipy is hard at work rewriting such APIs in terms of the array API. This is tracked in https://github.com/scipy/scipy/issues/18867.
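As a rough illustration of that point (a sketch, not SciPy's or JAX's actual code), composing two rotations stored as scalar-last quaternions needs only indexing, elementwise arithmetic and `stack`, all of which the array API standard defines. The same function body can therefore run unchanged with NumPy or `jax.numpy` passed in as the namespace `xp`. The function name `compose_quats` is hypothetical.

```python
import numpy as np
import jax.numpy as jnp


def compose_quats(xp, p, q):
    """Hamilton product p * q for scalar-last quaternions (x, y, z, w).

    Uses only indexing, arithmetic and xp.stack, so any array namespace
    `xp` works. Accepts a single quaternion (4,) or a batch (N, 4).
    """
    px, py, pz, pw = p[..., 0], p[..., 1], p[..., 2], p[..., 3]
    qx, qy, qz, qw = q[..., 0], q[..., 1], q[..., 2], q[..., 3]
    x = pw * qx + px * qw + py * qz - pz * qy
    y = pw * qy - px * qz + py * qw + pz * qx
    z = pw * qz + px * qy - py * qx + pz * qw
    w = pw * qw - px * qx - py * qy - pz * qz
    return xp.stack([x, y, z, w], axis=-1)


half_turn_z = [0.0, 0.0, 1.0, 0.0]  # 180 degrees about z
identity = [0.0, 0.0, 0.0, 1.0]

for xp in (np, jnp):
    p, q = xp.asarray(half_turn_z), xp.asarray(identity)
    print(xp.__name__, compose_quats(xp, p, q), compose_quats(xp, p, p))
```

Composing the half turn with itself yields `[0, 0, 0, -1]`, which is the identity rotation, since `q` and `-q` describe the same rotation.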