ji8er opened this issue 1 month ago
The underlying error is the following message:

```
'vector.shape_cast' op operand #0 must be vector of any type values, but got 'i32'
```
This error happens because `pl.program_id` returns a scalar, which lives in a separate memory space from vectors (SMEM vs. VMEM, see https://jax.readthedocs.io/en/latest/pallas/tpu/pipelining.html#tpu-and-its-memory-spaces). Because `o_ref` is stored in VMEM by default, the compiler tries to cast the `program_id` to a length-1 vector on the store operation, but Mosaic's shape-cast operation `vector.shape_cast` is only designed to translate vectors to other vectors, not scalars to vectors.
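One way to confirm that the problem is in the Mosaic lowering rather than in the kernel logic is Pallas's interpret mode, which evaluates the kernel with ordinary JAX operations instead of compiling it through Mosaic. The sketch below assumes a JAX version where `pl.pallas_call` accepts `interpret=True`; `iota_interpret` is a hypothetical name, and the kernel is the scalar store `o_ref[i] = i` with `o_ref` left in its default memory space. Since nothing is lowered to `vector.shape_cast` here, it should run (on CPU as well) and fill the output with `0..size-1`.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def iota_kernel(o_ref):
    i = pl.program_id(0)
    # Scalar store into the (default-memory-space) output ref.
    o_ref[i] = i


def iota_interpret(size: int):
    # Hypothetical helper: same kernel, but run in interpret mode, which
    # evaluates it with plain JAX ops instead of lowering through Mosaic.
    return pl.pallas_call(
        iota_kernel,
        out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),
        grid=(size,),
        interpret=True,
    )()


out = iota_interpret(8)
print(out)
```

Interpret mode only checks the kernel's semantics; it does not model the TPU memory spaces, so a kernel that passes here can still hit the error above when compiled for TPU.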
Ideally, Pallas ops should work gracefully regardless of whether the inputs are in SMEM or VMEM, but we don't have this implemented for all cases yet. We're also working on improving the error messages, since these are currently quite difficult to parse and require knowledge of the underlying Mosaic compiler.

There are a few ways you can work around this while waiting for an upstream fix.
One solution is to place `o_ref` into SMEM as follows:
```python
import jax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
import jax.numpy as jnp


def iota_kernel(o_ref):
    i = pl.program_id(0)
    o_ref[i] = i  # scalar store into SMEM, so no vector shape cast is needed


def iota(size: int):
    grid_spec = pltpu.PrefetchScalarGridSpec(
        num_scalar_prefetch=0,
        in_specs=[],
        # Keep the whole output in SMEM so each grid step stores a scalar.
        out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM),
        grid=(size,),
    )
    return pl.pallas_call(iota_kernel,
                          out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),
                          grid_spec=grid_spec)()


iota(8)
```
You could also use the more awkward approach of giving `o_ref` the shape `(size, 1)` and reshaping explicitly. Explicitly reshaping `program_id` into a vector avoids having the store operation attempt the shape cast implicitly.
```python
import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp


def iota_kernel(o_ref):
    i = pl.program_id(0)
    # Reshape the scalar into a length-1 vector before storing one row.
    o_ref[i, :] = jnp.reshape(i, (1,))


def iota(size: int):
    return pl.pallas_call(iota_kernel,
                          out_shape=jax.ShapeDtypeStruct((size, 1), jnp.int32),
                          grid=(size,), debug=True)()


iota(8)
```
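Because this second workaround returns a `(size, 1)` array, callers will usually want to drop the trailing axis afterwards. A minimal sketch, assuming `interpret=True` so it runs without a TPU (`iota_column` is a hypothetical name; on a real TPU you would drop `interpret=True`):

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def iota_kernel(o_ref):
    i = pl.program_id(0)
    # Make the scalar a length-1 vector explicitly before the row store.
    o_ref[i, :] = jnp.reshape(i, (1,))


def iota_column(size: int):
    # Hypothetical wrapper around the (size, 1) workaround.
    out = pl.pallas_call(
        iota_kernel,
        out_shape=jax.ShapeDtypeStruct((size, 1), jnp.int32),
        grid=(size,),
        interpret=True,  # evaluate with plain JAX ops so the sketch runs off-TPU
    )()
    return out[:, 0]  # drop the trailing length-1 axis -> shape (size,)


flat = iota_column(8)
print(flat.shape, flat)
```

Slicing outside the kernel keeps the kernel itself in the vector-only form that Mosaic accepts, at the cost of one extra array op on the result.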
Thanks for the detailed comment, @justinjfu!
The explicit separation of SMEM seems nice for now.