jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Sparse reshape throws error when `n_dense>0` and some target dimension has size 1 #24795

Open cherrywoods opened 4 days ago

cherrywoods commented 4 days ago

Description

Reshape for sparse BCOO arrays fails if the target shape contains dimensions of size 1 and there is at least one dense dimension.

from jax.experimental import sparse

# 2x2 sparse identity with one dense dimension (n_dense=1)
sp_id = sparse.eye(2, n_dense=1)
# Fails because the target shape contains size-1 dimensions
sp_id.reshape((1, 2, 1, 2))

Stack trace:

  File "/home/david/.miniconda3/envs/formalax/lib/python3.12/site-packages/jax/experimental/sparse/transform.py", line 451, in wrapped
    result = eval_sparse(jaxpr, consts, spvalues_flat, spenv)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/david/.miniconda3/envs/formalax/lib/python3.12/site-packages/jax/experimental/sparse/transform.py", line 428, in eval_sparse
    out = sparse_rules_bcoo[prim](spenv, *invals, **eqn.params)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/david/.miniconda3/envs/formalax/lib/python3.12/site-packages/jax/experimental/sparse/transform.py", line 539, in _sparse_rule
    result = sparse_op(*spvalues_to_arrays(spenv, spvalues), **kwds)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/david/.miniconda3/envs/formalax/lib/python3.12/site-packages/jax/experimental/sparse/bcoo.py", line 1858, in bcoo_reshape
    data = lax.reshape(
           ^^^^^^^^^^^^
  File "/home/david/.miniconda3/envs/formalax/lib/python3.12/site-packages/jax/_src/lax/lax.py", line 919, in reshape
    return reshape_p.bind(
           ^^^^^^^^^^^^^^^
  File "/home/david/.miniconda3/envs/formalax/lib/python3.12/site-packages/jax/_src/core.py", line 416, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/david/.miniconda3/envs/formalax/lib/python3.12/site-packages/jax/_src/core.py", line 420, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/david/.miniconda3/envs/formalax/lib/python3.12/site-packages/jax/_src/core.py", line 921, in process_primitive
    return primitive.impl(*tracers, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/david/.miniconda3/envs/formalax/lib/python3.12/site-packages/jax/_src/dispatch.py", line 87, in apply_primitive
    outs = fun(*args)
           ^^^^^^^^^^
TypeError: reshape total size must be unchanged, got new_sizes (1, 2) (of total size 2) for shape (2, 2) (of total size 4).
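For comparison, the same reshape appears to go through when there are no dense dimensions, so the failure seems specific to n_dense > 0 (a sketch of the default case, matching the title of this issue):

from jax.experimental import sparse

# Default n_dense=0: the same target shape appears to work.
sp_id_no_dense = sparse.eye(2)
print(sp_id_no_dense.reshape((1, 2, 1, 2)).shape)  # (1, 2, 1, 2)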

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.35
jaxlib: 0.4.35
numpy:  2.0.1
python: 3.12.4 | packaged by Anaconda, Inc. | (main, Jun 18 2024, 15:12:24) [GCC 11.2.0]
device info: cpu-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='abc', release='6.8.0-48-generic', version='#48-Ubuntu SMP PREEMPT_DYNAMIC Fri Sep 27 14:04:52 UTC 2024', machine='x86_64')

$ nvidia-smi
Fri Nov  8 18:47:48 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.183.01             Driver Version: 535.183.01   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce GT 1030         Off | 00000000:1C:00.0  On |                  N/A |
| 35%   41C    P0              N/A /  30W |    589MiB /  2048MiB |     30%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A      3509      G   /usr/lib/xorg/Xorg                          163MiB |
|    0   N/A  N/A      3774      G   /usr/bin/gnome-shell                         90MiB |
|    0   N/A  N/A      3956      G   ...irefox/5187/usr/lib/firefox/firefox      120MiB |
|    0   N/A  N/A      3998      G   ...usr/lib/thunderbird/thunderbird-bin        7MiB |
|    0   N/A  N/A      4177      G   ...esktop-client/214/usr/bin/nextcloud        0MiB |
|    0   N/A  N/A      4641      G   /usr/libexec/xdg-desktop-portal-gnome       123MiB |
|    0   N/A  N/A      5609      G   ...yOnDemand --variations-seed-version       14MiB |
|    0   N/A  N/A      9050      G   /usr/bin/nautilus                            31MiB |
|    0   N/A  N/A     25941      G   /usr/bin/gnome-calendar                      16MiB |
|    0   N/A  N/A     75544      G   /usr/bin/gnome-system-monitor                11MiB |
|    0   N/A  N/A     75930      G   ...erProcess --variations-seed-version        1MiB |
+---------------------------------------------------------------------------------------+

Note:

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
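In the meantime, a dense round trip sidesteps the sparse reshape rule entirely (a sketch, at the cost of materializing the array; BCOO.todense and BCOO.fromdense are the standard conversion APIs):

# Workaround sketch: densify, reshape, then re-sparsify with the same layout.
dense = sp_id.todense()
sp_reshaped = sparse.BCOO.fromdense(dense.reshape(1, 2, 1, 2), n_dense=1)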
cherrywoods commented 3 days ago

I came up with a hotfix, but I don't think what I implemented is the generally desirable behaviour. My fix is to insert these lines:

while i1 > 0 and new_sizes[i1 - 1] == 1:
    i1 -= 1
while i2 > 0 and new_sizes[i2 - 1] == 1:
    i2 -= 1

after https://github.com/jax-ml/jax/blob/87ce0cbb00c8a31a0266e2b10809a184b989a2cf/jax/experimental/sparse/bcoo.py#L1864

This moves all dimensions of size one to the next dimension kind (batch -> sparse, sparse -> dense). It works in my case, but I figure there could be other cases where this is precisely the wrong thing to do.
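To make the boundary shift concrete, here is a standalone sketch (split_new_sizes is a hypothetical helper; I am assuming i1 and i2 come from a cumulative-product search over new_sizes, which may not match bcoo_reshape exactly):

import numpy as np

def split_new_sizes(new_sizes, batch_size, sparse_size):
    # Assumed reconstruction: locate the ends of the batch and sparse
    # parts of new_sizes via cumulative products.
    cuml = np.cumprod(new_sizes)
    i1 = int(cuml.searchsorted(batch_size, side='right'))
    i2 = int(cuml.searchsorted(batch_size * sparse_size, side='right'))
    # Hotfix: hand trailing size-1 dimensions to the next dimension kind.
    while i1 > 0 and new_sizes[i1 - 1] == 1:
        i1 -= 1
    while i2 > 0 and new_sizes[i2 - 1] == 1:
        i2 -= 1
    return new_sizes[:i1], new_sizes[i1:i2], new_sizes[i2:]

# Failing example: shape (2, 2) with n_batch=0, n_sparse=1, n_dense=1,
# reshaped to (1, 2, 1, 2):
print(split_new_sizes((1, 2, 1, 2), batch_size=1, sparse_size=2))
# ((), (1, 2), (1, 2)) -- the size-1 dims land in the sparse and dense
# parts, and the dense part now has total size 2, matching the data.

Without the shift, the size-1 dimensions stay attached to the preceding dimension kind, which appears to be what produces the size mismatch in lax.reshape above.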