google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Constants in Pallas kernel cause `ValueError: safe_zip() argument 2 is shorter than argument 1` #21557

Open · zhixuan-lin opened this issue 1 month ago

zhixuan-lin commented 1 month ago

Description

The following Pallas kernel causes `ValueError: safe_zip() argument 2 is shorter than argument 1`:

import jax
import numpy as np
import jax.numpy as jnp
from jax.experimental import pallas as pl

def kernel(src, dst):
    # A NumPy constant closed over by the kernel; Pallas hoists it into
    # the jaxpr's invars, which is what triggers the error below.
    indices = np.arange(4).reshape(2, 2)
    dst[indices] = src[indices]

@jax.jit
def func(src):
    func = pl.pallas_call(
        f=kernel,
        out_shape=jax.ShapeDtypeStruct(src.shape, src.dtype),
        in_specs=[
            pl.BlockSpec(lambda i: (0,), src.shape)
        ],
        grid=(1,)
    )
    dst = func(src)
    return dst

if __name__ == '__main__':
    src = jnp.zeros((32,))
    dst = func(src)
    dst.block_until_ready()

Detailed error log:

2024-05-31 10:28:47.756938: W external/xla/xla/service/gpu/nvptx_compiler.cc:760] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.5.40). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
Traceback (most recent call last):
  File "/home/mila/z/zhixuan.lin/code/benchmark/example.py", line 31, in <module>
    dst = func(src)
  File "/home/mila/z/zhixuan.lin/code/benchmark/example.py", line 25, in func
    dst = func(src)
  File "/home/mila/z/zhixuan.lin/.conda/envs/linear-rnn-jax/lib/python3.10/site-packages/jax/_src/pallas/pallas_call.py", line 589, in wrapped
    out_flat = pallas_call_p.bind(
jax._src.source_info_util.JaxStackTraceBeforeTransformation: ValueError: safe_zip() argument 2 is shorter than argument 1

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/mila/z/zhixuan.lin/code/benchmark/example.py", line 31, in <module>
    dst = func(src)
  File "/home/mila/z/zhixuan.lin/.conda/envs/linear-rnn-jax/lib/python3.10/site-packages/jax/_src/pallas/pallas_call.py", line 531, in _pallas_call_lowering
    return pallas_call_registration.pallas_call_lowering(
  File "/home/mila/z/zhixuan.lin/.conda/envs/linear-rnn-jax/lib/python3.10/site-packages/jax/_src/pallas/triton/pallas_call_registration.py", line 161, in pallas_call_lowering
    return _pallas_call_ttir_lowering(
  File "/home/mila/z/zhixuan.lin/.conda/envs/linear-rnn-jax/lib/python3.10/site-packages/jax/_src/pallas/triton/pallas_call_registration.py", line 73, in _pallas_call_ttir_lowering
    lowering_result = lowering.lower_jaxpr_to_triton_module(
  File "/home/mila/z/zhixuan.lin/.conda/envs/linear-rnn-jax/lib/python3.10/site-packages/jax/_src/pallas/triton/lowering.py", line 312, in lower_jaxpr_to_triton_module
    () = lower_jaxpr_to_triton_ir(ctx, jaxpr, block_infos, *entry.arguments)
  File "/home/mila/z/zhixuan.lin/.conda/envs/linear-rnn-jax/lib/python3.10/site-packages/jax/_src/pallas/triton/lowering.py", line 339, in lower_jaxpr_to_triton_ir
    for invar, block_info in zip(jaxpr.invars, block_infos):
ValueError: safe_zip() argument 2 is shorter than argument 1

I traced the issue to `_hoist_consts_to_refs`, which adds the constants (the array `indices` in the code above) to `jaxpr.invars`. Later, in `lower_jaxpr_to_triton_ir`, `jaxpr.invars` is zipped with `block_infos`. Since `block_infos` contains no block information for the constants, `jaxpr.invars` and `block_infos` have different lengths, which causes the error.
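To make the mismatch concrete, here is a minimal sketch. The list contents are hypothetical stand-ins; only the length check matters, and it assumes `jax._src.util.safe_zip` keeps the behavior shown in the traceback:

from jax._src.util import safe_zip

# Hypothetical stand-ins: after constant hoisting, the hoisted constant
# appears among the jaxpr invars but has no corresponding block info.
invars = ["hoisted_const", "src_ref", "dst_ref"]      # 3 entries
block_infos = ["src_block_info", "dst_block_info"]    # 2 entries

safe_zip(invars, block_infos)
# ValueError: safe_zip() argument 2 is shorter than argument 1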

The error also occurs if I use `vmap`, as in the following:

import jax
import numpy as np
import jax.numpy as jnp
from jax.experimental import pallas as pl

def kernel(src, dst):
    # Same NumPy constant as above, again hoisted into the jaxpr's invars.
    indices = np.arange(4).reshape(2, 2)
    dst[indices] = src[indices]

@jax.jit
@jax.vmap
def func(src):
    func = pl.pallas_call(
        f=kernel,
        out_shape=jax.ShapeDtypeStruct(src.shape, src.dtype),
        in_specs=[
            pl.BlockSpec(lambda i: (0,), src.shape)
        ],
        grid=(1,)
    )
    dst = func(src)
    return dst

if __name__ == '__main__':
    src = jnp.zeros((4, 32))
    dst = func(src)
    dst.block_until_ready()

In this case the error occurs earlier, in `_pallas_call_batching_rule`:

2024-05-31 10:38:52.699687: W external/xla/xla/service/gpu/nvptx_compiler.cc:760] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.5.40). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
Traceback (most recent call last):
  File "/home/mila/z/zhixuan.lin/code/benchmark/example.py", line 32, in <module>
    dst = func(src)
  File "/home/mila/z/zhixuan.lin/code/benchmark/example.py", line 26, in func
    dst = func(src)
  File "/home/mila/z/zhixuan.lin/.conda/envs/linear-rnn-jax/lib/python3.10/site-packages/jax/_src/pallas/pallas_call.py", line 589, in wrapped
    out_flat = pallas_call_p.bind(
jax._src.source_info_util.JaxStackTraceBeforeTransformation: ValueError: safe_map() argument 3 is shorter than argument 1

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/mila/z/zhixuan.lin/code/benchmark/example.py", line 32, in <module>
    dst = func(src)
  File "/home/mila/z/zhixuan.lin/code/benchmark/example.py", line 26, in func
    dst = func(src)
  File "/home/mila/z/zhixuan.lin/.conda/envs/linear-rnn-jax/lib/python3.10/site-packages/jax/_src/pallas/pallas_call.py", line 420, in _pallas_call_batching_rule
    batched_block_mappings = map(
ValueError: safe_map() argument 3 is shorter than argument 1
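The same kind of length check fires here. A minimal sketch with hypothetical stand-in lists, assuming `jax._src.util.safe_map` keeps the behavior shown in the traceback:

from jax._src.util import safe_map

# The batching rule maps several per-input sequences together; the
# hoisted constant is present in some of them but not in the block
# mappings, so one sequence comes up short.
safe_map(lambda *xs: xs, [1, 2, 3], [1, 2, 3], [1, 2])
# ValueError: safe_map() argument 3 is shorter than argument 1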

I do not know much about jaxprs, so I'm not sure what I should do here. Any pointers are appreciated. Thanks!

System info (python version, jaxlib version, accelerator, etc.)

Here is the system info. I've also tried 0.4.28 and got the same error.


>>> import jax; jax.print_environment_info()
jax:    0.4.27
jaxlib: 0.4.27
numpy:  1.26.4
python: 3.10.4 (main, Mar 31 2022, 08:41:55) [GCC 7.5.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='cn-g022.server.mila.quebec', release='5.15.0-101-generic', version='#111-Ubuntu SMP Tue Mar 5 20:16:58 UTC 2024', machine='x86_64')

$ nvidia-smi
Fri May 31 10:46:32 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.161.08             Driver Version: 535.161.08   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA A100-SXM4-80GB          On  | 00000000:41:00.0 Off |                    0 |
| N/A   28C    P0              87W / 500W |    424MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A   2029607      C   python                                      416MiB |
+---------------------------------------------------------------------------------------+
zhixuan-lin commented 4 weeks ago

I found that I can work around this issue by constructing the constants with `jnp.arange(0, stop)` inside the kernel (not `np.arange`; also, the `jnp.arange` must start from zero and have a stride of one). However, the original issue still looks like a bug.
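For concreteness, a minimal sketch of this workaround applied to the kernel from the first reproducer (the surrounding `pallas_call` setup is unchanged):

import jax.numpy as jnp

def kernel(src, dst):
    # Workaround: build the index constant with jnp.arange inside the
    # kernel rather than closing over a NumPy array. Per the note above,
    # the range must start at 0 and have stride 1.
    indices = jnp.arange(4).reshape(2, 2)
    dst[indices] = src[indices]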

superbobry commented 4 weeks ago

Thanks for the update @zhixuan-lin! It does look like a bug, yeah. I will look into fixing this.

superbobry commented 2 weeks ago

I realized my fix is actually only partial. In particular, it doesn't work when the kernel is vmapped. I'll look into alternatives.