jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

[Pallas] When mixing basic indexing and integer array indexing, the axis corresponding to integer array indexing is unnecessarily moved to the front #22783

Status: Open. Opened by ayaka14732 2 months ago.

ayaka14732 commented 2 months ago

Description

I am testing in interpret mode.

Repro:

import functools

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
import numpy as np

x_shape = (16, 3)

x = jnp.arange(np.prod(x_shape), dtype=jnp.float32).reshape(x_shape)

a = jnp.array([1, 1, 1, 1, 1], dtype=jnp.int32)
y = x[::4, a]

@functools.partial(
    pl.pallas_call,
    out_shape=jax.ShapeDtypeStruct(y.shape, jnp.float32),
    interpret=True,
)
def kernel(x_ref, o_ref):
    o_ref[...] = x_ref[::4, a]

y_ = kernel(x)
np.testing.assert_array_equal(y_, y)

Expected behavior:

The line y_ = kernel(x) should run successfully and yield the same value as y.

Actual behavior:

Traceback (most recent call last):
  File "/home/ayx/development/jax/test.py", line 23, in <module>
    y_ = kernel(x)
         ^^^^^^^^^
  File "/home/ayx/development/jax/jax/_src/pallas/pallas_call.py", line 1085, in wrapped
    grid_mapping, jaxpr, consts = _trace_kernel_to_jaxpr(
                                  ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/development/jax/jax/_src/pallas/pallas_call.py", line 857, in _trace_kernel_to_jaxpr
    jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_kernel_fun,
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/development/jax/test.py", line 21, in kernel
    o_ref[...] = x_ref[::4, a]
    ~~~~~^^^^^
  File "/home/ayx/development/jax/jax/_src/numpy/array_methods.py", line 747, in op
    return getattr(self.aval, f"_{name}")(self, *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/development/jax/jax/_src/state/types.py", line 187, in _setitem
    return ref_set(tracer, idx, value)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/development/jax/jax/_src/state/primitives.py", line 124, in ref_set
    ref_swap(ref_or_view, idx, value, _function_name="ref_set")
  File "/home/ayx/development/jax/jax/_src/state/primitives.py", line 120, in ref_swap
    return swap_p.bind(ref, value, *flat_indexers, tree=tree)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/development/jax/jax/_src/state/primitives.py", line 188, in _swap_abstract_eval
    raise ValueError("Invalid shape for `swap`. "
ValueError: Invalid shape for `swap`. Ref shape: (4, 5). Expected shape: (4, 5). Value shape: (5, 4). Indices: (NDIndexer(indices=(Slice(start=0, size=4, stride=1), Slice(start=0, size=5, stride=1)), shape=(4, 5), int_indexer_shape=(), validate=False),). 
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

Explanation:

The correct shape of x_ref[::4, a] is (4, 5), which matches the shape of y computed with jax.numpy. Pallas instead computes the shape of the indexed value as (5, 4), so the store into o_ref (shape (4, 5)) fails the shape check in `swap`.
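For comparison, plain NumPy (whose advanced-indexing semantics jax.numpy follows) gives (4, 5) for this index, and the (5, 4) value shape in the error is exactly that expected result with the integer-array axis moved to the front. A quick check, using NumPy arrays in place of the jnp ones from the repro:

```python
import numpy as np

# Same data as the repro, built with NumPy instead of jax.numpy.
x = np.arange(16 * 3, dtype=np.float32).reshape(16, 3)
a = np.array([1, 1, 1, 1, 1], dtype=np.int32)

expected = x[::4, a]   # strided slice on axis 0, integer array on axis 1
print(expected.shape)  # (4, 5): the array-indexed axis stays at position 1

# The value shape Pallas reported, (5, 4), is what you get by moving the
# array-indexed axis to the front.
moved = np.moveaxis(expected, 1, 0)
print(moved.shape)     # (5, 4)

# The unstrided variant behaves the same way: (16, 5), not (5, 16).
print(x[:, a].shape)   # (16, 5)
```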

I have tested various indexing patterns and observed that when there is only one integer array index, the axis corresponding to it is always moved to the front unnecessarily. For example, in the above case, the axis of size 5 is moved to the front, which makes Pallas compute the shape as (5, 4) instead of (4, 5).

This may be related to https://github.com/google/jax/blob/5c9bb612a775ca23d311eef1aeac03dfe0828a62/jax/_src/state/indexing.py#L256-L257.
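For reference, NumPy only moves the broadcast integer-array dimensions to the front when the integer-array indices are separated by a slice; a single integer array (or several adjacent ones) replaces the indexed axes in place. Below is a minimal sketch of that placement rule as a hypothetical helper, expected_index_shape, which could be used to cross-check the shapes the Pallas indexer computes. It is not JAX's implementation, and it assumes the index tuple contains only slices and integer arrays (no scalar ints, None, Ellipsis, or boolean masks).

```python
import numpy as np


def expected_index_shape(shape, idx):
    """Hypothetical helper: shape of x[idx] under NumPy's indexing rules.

    Assumes idx is a tuple of slices and integer arrays only.
    """
    # Trailing axes that are not indexed behave like ':'.
    idx = tuple(idx) + (slice(None),) * (len(shape) - len(idx))
    sliced = []        # (axis, size) for each basic slice
    array_axes = []    # axes indexed by integer arrays
    array_shapes = []  # shapes of those integer arrays
    for axis, (dim, ix) in enumerate(zip(shape, idx)):
        if isinstance(ix, slice):
            sliced.append((axis, len(range(*ix.indices(dim)))))
        else:
            array_axes.append(axis)
            array_shapes.append(np.shape(ix))
    if not array_axes:
        return tuple(size for _, size in sliced)
    bcast = np.broadcast_shapes(*array_shapes)
    adjacent = array_axes[-1] - array_axes[0] + 1 == len(array_axes)
    if not adjacent:
        # Integer arrays separated by a slice: broadcast dims go first.
        return bcast + tuple(size for _, size in sliced)
    # A single integer array (or adjacent ones): broadcast dims replace
    # the indexed axes in place.
    first = array_axes[0]
    before = tuple(size for axis, size in sliced if axis < first)
    after = tuple(size for axis, size in sliced if axis > first)
    return before + bcast + after


a = np.array([1, 1, 1, 1, 1])
i = np.array([0, 1, 0, 1, 0])
cases = [
    ((16, 3), (slice(None, None, 4), a)),  # first repro
    ((16, 3), (slice(None), a)),           # second repro
    ((2, 3, 4), (i, slice(None), i)),      # separated arrays
    ((2, 3, 4), (slice(None), i, i)),      # adjacent arrays
]
for shape, idx in cases:
    # Cross-check the sketch against NumPy's own indexing.
    print(expected_index_shape(shape, idx), np.zeros(shape)[idx].shape)
```

Only the separated case moves the size-5 dimension to the front; both repros in this issue fall under the single-array case, where it should stay in place.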

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.31.dev20240729+6a7822a73
jaxlib: 0.4.30
numpy:  1.26.4
python: 3.12.4 (main, Jun 12 2024, 19:06:53) [GCC 13.2.0]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='ayx1', release='6.6.15-2rodete2-amd64', version='#1 SMP PREEMPT_DYNAMIC Debian 6.6.15-2rodete2 (2024-03-19)', machine='x86_64')

ayaka14732 commented 2 months ago

Better repro (without strided indexing):

import functools

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
import numpy as np

x_shape = (16, 3)

x = jnp.arange(np.prod(x_shape), dtype=jnp.float32).reshape(x_shape)

a = jnp.array([1, 1, 1, 1, 1], dtype=jnp.int32)
y = x[:, a]

@functools.partial(
    pl.pallas_call,
    out_shape=jax.ShapeDtypeStruct(y.shape, jnp.float32),
    interpret=True,
)
def kernel(x_ref, o_ref):
    o_ref[...] = x_ref[:, a]

y_ = kernel(x)
np.testing.assert_array_equal(y_, y)

Error:

Traceback (most recent call last):
  File "/home/ayx/development/jax/4.py", line 23, in <module>
    y_ = kernel(x)
         ^^^^^^^^^
  File "/home/ayx/development/jax/jax/_src/pallas/pallas_call.py", line 1129, in wrapped
    jaxpr = _trace_kernel_to_jaxpr(
            ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/development/jax/jax/_src/pallas/pallas_call.py", line 901, in _trace_kernel_to_jaxpr
    jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_kernel_fun,
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/development/jax/4.py", line 21, in kernel
    o_ref[...] = x_ref[:, a]
    ~~~~~^^^^^
  File "/home/ayx/development/jax/jax/_src/numpy/array_methods.py", line 749, in op
    return getattr(self.aval, f"_{name}")(self, *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/development/jax/jax/_src/state/types.py", line 187, in _setitem
    return ref_set(tracer, idx, value)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/development/jax/jax/_src/state/primitives.py", line 114, in ref_set
    ref_swap(ref_or_view, idx, value, _function_name="ref_set")
  File "/home/ayx/development/jax/jax/_src/state/primitives.py", line 110, in ref_swap
    return swap_p.bind(ref, value, *flat_indexers, tree=tree)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/development/jax/jax/_src/state/primitives.py", line 178, in _swap_abstract_eval
    raise ValueError("Invalid shape for `swap`. "
ValueError: Invalid shape for `swap`. Ref shape: (16, 5). Expected shape: (16, 5). Value shape: (5, 16). Indices: (NDIndexer(indices=(Slice(start=0, size=16, stride=1), Slice(start=0, size=5, stride=1)), shape=(16, 5), int_indexer_shape=(), validate=False),). 
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.