google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

pallas simple ``pl.program_id()`` example not working #22817

Open ji8er opened 1 month ago

ji8er commented 1 month ago

Description

Code executed:

from functools import partial

import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp
import numpy as np

def iota_kernel(o_ref):
  i = pl.program_id(0)
  o_ref[i] = i

def iota(size: int):
  return pl.pallas_call(iota_kernel,
                        out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),
                        grid=(size,))()
iota(8)

Stack Trace:

---------------------------------------------------------------------------
MLIRError                                 Traceback (most recent call last)
google3/third_party/py/jax/_src/pallas/mosaic/lowering.py in lower_jaxpr_to_func(ctx, jaxpr, mosaic_grid_mapping, name, for_verification)
    681   try:
--> 682     body.func_op.verify()
    683   except Exception as e:

MLIRError: Verification failed:
error: "/swap[tree=PyTreeDef((CustomNode(NDIndexer[(PyTreeDef((*,)), (8,), ())], [*]),))]"(callsite("iota_kernel"("<ipython-input-18-1144262d6dee>":3:2) at callsite("iota"("<ipython-input-18-1144262d6dee>":8:9) at callsite("<module>"("<ipython-input-18-1144262d6dee>":11:0) at callsite("InteractiveShell.run_code"("third_party/py/IPython/v3_2_3/core/interactiveshell.py":3066:16) at callsite("InteractiveShell.run_ast_nodes"("third_party/py/IPython/v3_2_3/core/interactiveshell.py":3012:19) at callsite("InteractiveShell.run_cell"("third_party/py/IPython/v3_2_3/core/interactiveshell.py":2901:16) at callsite("IPythonKernel.do_execute"("third_party/py/IPython/v3_2_3/kernel/zmq/ipkernel.py":181:12) at callsite("Kernel.execute_request"("third_party/py/IPython/v3_2_3/kernel/zmq/kernelbase.py":361:24) at callsite("ColabKernel.execute_request"("research/colab/notebook/colab_kernel.py":240:4) at "Kernel.dispatch_shell"("third_party/py/IPython/v3_2_3/kernel/zmq/kernelbase.py":213:16))))))))))): 'vector.shape_cast' op operand #0 must be vector of any type values, but got 'i32'
 note: "/swap[tree=PyTreeDef((CustomNode(NDIndexer[(PyTreeDef((*,)), (8,), ())], [*]),))]"(callsite("iota_kernel"("<ipython-input-18-1144262d6dee>":3:2) at callsite("iota"("<ipython-input-18-1144262d6dee>":8:9) at callsite("<module>"("<ipython-input-18-1144262d6dee>":11:0) at callsite("InteractiveShell.run_code"("third_party/py/IPython/v3_2_3/core/interactiveshell.py":3066:16) at callsite("InteractiveShell.run_ast_nodes"("third_party/py/IPython/v3_2_3/core/interactiveshell.py":3012:19) at callsite("InteractiveShell.run_cell"("third_party/py/IPython/v3_2_3/core/interactiveshell.py":2901:16) at callsite("IPythonKernel.do_execute"("third_party/py/IPython/v3_2_3/kernel/zmq/ipkernel.py":181:12) at callsite("Kernel.execute_request"("third_party/py/IPython/v3_2_3/kernel/zmq/kernelbase.py":361:24) at callsite("ColabKernel.execute_request"("research/colab/notebook/colab_kernel.py":240:4) at "Kernel.dispatch_shell"("third_party/py/IPython/v3_2_3/kernel/zmq/kernelbase.py":213:16))))))))))): see current operation: %3 = "vector.shape_cast"(%arg0) : (i32) -> vector<1xi32>

The above exception was the direct cause of the following exception:

JaxStackTraceBeforeTransformation         Traceback (most recent call last)
<ipython-input-18-1144262d6dee> in <module>()
     10                         grid=(size,))()
---> 11 iota(8)

<ipython-input-18-1144262d6dee> in iota()
      7   y = jnp.arange(8)
----> 8   return pl.pallas_call(iota_kernel,
      9                         out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),

google3/third_party/py/jax/_src/pallas/pallas_call.py in wrapped()
   1105     index_args, rest_args = split_list(flat_args, [grid_mapping.num_index_operands])
-> 1106     out_flat = pallas_call_p.bind(
   1107         *dynamic_grid_bounds, *index_args, *rest_args,

JaxStackTraceBeforeTransformation: jax._src.pallas.mosaic.lowering.LoweringException: Body failed to verify: "func.func"() <{function_type = (i32, memref<8xi32, #tpu.memory_space<vmem>>) -> (), sym_name = "main"}> ({
^bb0(%arg0: i32, %arg1: memref<8xi32, #tpu.memory_space<vmem>>):
  %0 = "arith.index_cast"(%arg0) : (i32) -> index
  %1 = "vector.load"(%arg1, %0) : (memref<8xi32, #tpu.memory_space<vmem>>, index) -> vector<1xi32>
  %2 = "vector.shape_cast"(%1) : (vector<1xi32>) -> vector<i32>
  %3 = "vector.shape_cast"(%arg0) : (i32) -> vector<1xi32>
  "vector.store"(%3, %arg1, %0) : (vector<1xi32>, memref<8xi32, #tpu.memory_space<vmem>>, index) -> ()
  "func.return"() : () -> ()
}) : () -> ()
.
This is an internal error. Please report a bug at: https://github.com/google/jax/issues/new?assignees=sharadmv.

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

LoweringException                         Traceback (most recent call last)
<ipython-input-18-1144262d6dee> in <module>()
      9                         out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),
     10                         grid=(size,))()
---> 11 iota(8)

<ipython-input-18-1144262d6dee> in iota(size)
      6   x = jnp.arange(8)
      7   y = jnp.arange(8)
----> 8   return pl.pallas_call(iota_kernel,
      9                         out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),
     10                         grid=(size,))()

google3/third_party/py/jax/_src/pallas/pallas_call.py in _pallas_call_lowering(ctx, interpret, *in_nodes, **params)
    929     )
    930 
--> 931   return mlir.lower_per_platform(ctx, "pallas_call",
    932                                  dict(cpu=cpu_lowering,
    933                                       tpu=tpu_lowering,

google3/third_party/py/jax/_src/pallas/pallas_call.py in tpu_lowering(ctx, *in_nodes, **params)
    911     if mosaic_tpu_backend is None:
    912       raise _unsupported_lowering_error("tpu")
--> 913     return mosaic_tpu_backend.pallas_call_tpu_lowering_rule(
    914         ctx, *in_nodes, **params
    915     )

google3/third_party/py/jax/_src/pallas/mosaic/pallas_call_registration.py in pallas_call_tpu_lowering_rule(***failed resolving arguments***)
    108           dimension_semantics=dimension_semantics, mesh=mesh,
    109           for_verification=for_verification)
--> 110   mosaic_module, extra_args = lower_module(for_verification=False)
    111   if debug:
    112     print(mosaic_module)

google3/third_party/py/jax/_src/pallas/mosaic/pallas_call_registration.py in lower_module(for_verification)
    104     with mlir_ctx, ir.Location.unknown(mlir_ctx):
    105       dimension_semantics = mosaic_params.get("dimension_semantics", None)
--> 106       return lowering.lower_jaxpr_to_module(
    107           ctx, mlir_ctx, grid_mapping, jaxpr,
    108           dimension_semantics=dimension_semantics, mesh=mesh,

google3/third_party/py/jax/_src/pallas/mosaic/lowering.py in lower_jaxpr_to_module(lowering_context, ctx, grid_mapping, jaxpr, dimension_semantics, mesh, for_verification)
    501   m = ir.Module.create()
    502   sym_tab = ir.SymbolTable(m.operation)
--> 503   func_op = lower_jaxpr_to_func(
    504       ctx, jaxpr, mosaic_grid_mapping=mosaic_grid_mapping,
    505       name="main", for_verification=for_verification,

google3/third_party/py/jax/_src/pallas/mosaic/lowering.py in lower_jaxpr_to_func(ctx, jaxpr, mosaic_grid_mapping, name, for_verification)
    682     body.func_op.verify()
    683   except Exception as e:
--> 684     raise LoweringException(
    685         f"Body failed to verify: {body.func_op}.\nThis is an internal error."
    686         " Please report a bug at:"

LoweringException: Body failed to verify: "func.func"() <{function_type = (i32, memref<8xi32, #tpu.memory_space<vmem>>) -> (), sym_name = "main"}> ({
^bb0(%arg0: i32, %arg1: memref<8xi32, #tpu.memory_space<vmem>>):
  %0 = "arith.index_cast"(%arg0) : (i32) -> index
  %1 = "vector.load"(%arg1, %0) : (memref<8xi32, #tpu.memory_space<vmem>>, index) -> vector<1xi32>
  %2 = "vector.shape_cast"(%1) : (vector<1xi32>) -> vector<i32>
  %3 = "vector.shape_cast"(%arg0) : (i32) -> vector<1xi32>
  "vector.store"(%3, %arg1, %0) : (vector<1xi32>, memref<8xi32, #tpu.memory_space<vmem>>, index) -> ()
  "func.return"() : () -> ()
}) : () -> ()
.
This is an internal error. Please report a bug at: https://github.com/google/jax/issues/new?assignees=sharadmv.

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.32
jaxlib: 0.4.32
numpy:  1.26.3
python: 3.11.8 (stable, redacted, redacted) [Clang google3-trunk (84658fb82b67fc22ecba1560d0cddd09f9104178)]
jax.devices (8 total, 8 local): [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0) TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1) ... TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0) TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
process_count: 1
platform: uname_result(system='Linux', node='d00ef52addb6d6c9-526d6bea6d8.borgtask.google.com', release='5.10.0-smp-1102.57.0.0', version='#1 [v5.10.0-1102.57.0.0] SMP @1719966309', machine='x86_64')

justinjfu commented 1 month ago

Explanation of the error

The underlying error is the following message: 'vector.shape_cast' op operand #0 must be vector of any type values, but got 'i32'

This error happens because pl.program_id returns a scalar, which lives in a separate memory space from vectors (SMEM vs. VMEM, see https://jax.readthedocs.io/en/latest/pallas/tpu/pipelining.html#tpu-and-its-memory-spaces). Because o_ref is stored in VMEM by default, the store operation tries to cast the program_id to a length-1 vector (the line %3 = "vector.shape_cast"(%arg0) : (i32) -> vector<1xi32> in the dumped IR), but Mosaic's shape cast operation, vector.shape_cast, only translates vectors to other vectors, not scalars to vectors.

Ideally, Pallas ops should work regardless of whether their operands are in SMEM or VMEM, but this is not implemented yet for all cases. We're also working on improving the error messages, since they are currently hard to parse and require knowledge of the Mosaic compiler internals.
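As a quick off-TPU sanity check, the same kernel can be run with pallas_call's interpret=True flag, which evaluates the kernel with ordinary JAX ops instead of lowering it through Mosaic. This is a sketch that assumes a JAX build shipping jax.experimental.pallas; the wrapper name iota_interpret is hypothetical. If the explanation above is right, this should produce the expected iota on CPU, which suggests the kernel logic itself is fine and only the TPU lowering of the scalar store into VMEM fails.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl


def iota_kernel(o_ref):
  # Scalar index of the current grid step.
  i = pl.program_id(0)
  # Scalar store into the output ref: the pattern that fails to lower on TPU.
  o_ref[i] = i


def iota_interpret(size: int):
  # Hypothetical helper: interpret=True skips Mosaic entirely, so the
  # scalar-to-vector shape cast never has to be emitted.
  return pl.pallas_call(
      iota_kernel,
      out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),
      grid=(size,),
      interpret=True,
  )()


out = iota_interpret(8)
print(out)
```

Interpret mode only checks semantics; it says nothing about whether a given kernel will lower on real hardware.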

Temporary Solutions

There are a few ways to work around this while waiting for an upstream fix.

One solution is to place o_ref into SMEM as follows:

import jax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
import jax.numpy as jnp

def iota_kernel(o_ref):
  i = pl.program_id(0)
  o_ref[i] = i

def iota(size: int):
  grid_spec = pltpu.PrefetchScalarGridSpec(
      num_scalar_prefetch=0,
      in_specs=[],
      # Keep the whole output in SMEM so the scalar store needs no vector cast.
      out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM),
      grid=(size,),
  )
  return pl.pallas_call(iota_kernel,
                        out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),
                        grid_spec=grid_spec)()
iota(8)

A more awkward alternative is to give o_ref the shape (size, 1) and explicitly reshape program_id into a length-1 vector before storing it. Because the stored value is already a vector, the store operation no longer has to attempt the implicit scalar-to-vector shape cast.

import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp

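def iota_kernel(o_ref):
  i = pl.program_id(0)
  # Explicitly turn the scalar into a length-1 vector before the store.
  o_ref[i, :] = jnp.reshape(i, (1,))

def iota(size: int):
  return pl.pallas_call(iota_kernel,
                        out_shape=jax.ShapeDtypeStruct((size, 1), jnp.int32),
                        # debug=True prints the jaxpr and the lowered Mosaic module.
                        grid=(size,), debug=True)()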
iota(8)
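The reshape workaround can be checked the same way. The sketch below (hypothetical names iota_2d_kernel and iota_2d; it defaults to interpret=True so it runs without a TPU) stores a length-1 vector into row i of a (size, 1) output, then drops the trailing axis outside the kernel to recover a 1-D iota. Per the workaround above, passing interpret=False on a TPU should exercise the real Mosaic lowering.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl


def iota_2d_kernel(o_ref):
  i = pl.program_id(0)
  # Reshape the scalar grid index into a length-1 vector so the store
  # writes a vector value into row i of the (size, 1) output.
  o_ref[i, :] = jnp.reshape(i, (1,))


def iota_2d(size: int, interpret: bool = True):
  # interpret=True runs the kernel as ordinary JAX ops (no Mosaic lowering).
  return pl.pallas_call(
      iota_2d_kernel,
      out_shape=jax.ShapeDtypeStruct((size, 1), jnp.int32),
      grid=(size,),
      interpret=interpret,
  )()


out = iota_2d(8)
# Drop the trailing length-1 axis to get a 1-D array of shape (size,).
flat = out[:, 0]
print(flat)
```

Doing the final reshape outside the kernel keeps the in-kernel store vector-shaped, which is the whole point of this workaround.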

ji8er commented 1 month ago

Thanks for the detailed comment, @justinjfu!

The explicit separation of SMEM seems nice for now.