Closed: zhixuan-lin closed this issue 2 months ago.
I found that I can work around this issue by constructing the constants inside the kernel with `jnp.arange(0, stop)` (not `np.arange`; also, the `jnp.arange` must start from zero and have a stride of one). However, the original issue still looks like a bug.
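A minimal sketch of the workaround, assuming a recent JAX with `jax.experimental.pallas` and using `interpret=True` so it runs without a GPU. The kernel, `N`, and `add_indices` are hypothetical names, not the original kernel. My understanding, which may vary across JAX versions, is that `jnp.arange(0, stop)` with the default step is traced as a `lax.iota` primitive rather than materialized as an array constant. That would explain why the start of zero and stride of one matter: other forms become a captured constant that goes through the hoisting path.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

N = 8  # hypothetical block size for the illustration

# Tracing check: jnp.arange(0, N) with the default step is staged as an
# `iota` primitive, so no array constant has to be captured.
# Closing over np.arange(N) instead would hand the kernel a constant array.
iota_jaxpr = jax.make_jaxpr(lambda: jnp.arange(0, N))()
print(iota_jaxpr)


def add_indices_kernel(x_ref, o_ref):
    # Workaround: build the indices inside the kernel instead of capturing
    # a NumPy array from the enclosing scope.
    indices = jnp.arange(0, N)
    o_ref[...] = x_ref[...] + indices.astype(x_ref.dtype)


@jax.jit
def add_indices(x):
    return pl.pallas_call(
        add_indices_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,  # runs on CPU for illustration; drop on a GPU
    )(x)


print(add_indices(jnp.zeros((N,), jnp.float32)))
```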
Thanks for the update @zhixuan-lin! It does look like a bug, yeah. I will look into fixing this.
I realized my fix is actually partial. In particular, it doesn't work when the kernel is vmapped. Will look into alternatives.
I think I have a similar issue (https://github.com/google/jax/issues/23301). Is there a better way of handling constants now?
Description
The following Pallas kernel causes `ValueError: safe_zip() argument 2 is shorter than argument 1`:
Detailed error log:
I traced the issue to `_hoist_consts_to_refs` adding the constants (the array `indices` in the code) to `jaxpr.invars`. Later, in `lower_jaxpr_to_triton_ir`, `jaxpr.invars` is zipped with `block_infos`. Since `block_infos` does not contain block information for the constants, `jaxpr.invars` and `block_infos` have different lengths, which causes the error.

Also, if I use `vmap` as in the following:

The error occurs earlier, in `_pallas_call_batching_rule`:

I do not know much about jaxprs, so I'm not sure what I should do here. Any pointers are appreciated. Thanks!
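A toy, pure-Python model of the length mismatch may make the failure easier to see. `safe_zip` here is a stand-in that mimics the error message, not JAX's internal implementation, and the invar and block-info names are hypothetical; the hoisted constant is assumed to be placed in front of the kernel's own refs.

```python
def safe_zip(*args):
    """Toy stand-in for a strict zip: refuses to silently truncate
    when the sequences have different lengths."""
    n = len(args[0])
    for i, arg in enumerate(args[1:], start=2):
        if len(arg) < n:
            raise ValueError(f"safe_zip() argument {i} is shorter than argument 1")
        if len(arg) > n:
            raise ValueError(f"safe_zip() argument {i} is longer than argument 1")
    return list(zip(*args))


# Hypothetical kernel with one input ref and one output ref; each has a block info.
kernel_invars = ["x_ref", "o_ref"]
block_infos = ["block_info(x)", "block_info(o)"]

# After hoisting, the captured constant (e.g. `indices`) becomes an extra invar
# (assumed here to be prepended), but no block info is added for it.
invars_after_hoisting = ["indices_ref"] + kernel_invars

print(safe_zip(kernel_invars, block_infos))  # lengths match, so this succeeds

try:
    safe_zip(invars_after_hoisting, block_infos)
except ValueError as err:
    print(err)  # safe_zip() argument 2 is shorter than argument 1
```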
System info (python version, jaxlib version, accelerator, etc.)
The system info is below. I've also tried 0.4.28 but got the same error.