[Open] Lime-Cakes opened this issue 4 months ago
The matmul tutorial now covers scratch_shapes a bit (commit https://github.com/google/jax/commit/39ec5dacb44c01c130f1c2ab929ff50930a44235 ). In this context, 'scratch' refers to temporary storage space (like scratchpad memory). I believe it works by allocating a small chunk of VMEM that the kernel has access to, used much like the other `_ref` arguments.
It seems all kernel invocations get access to the same scratch memory. I'm unsure how that works when megacore (TPU v4) is involved.
Circling back on this: We've added a new TPUCompilerParams class which has an (incomplete) set of documented options https://github.com/google/jax/pull/23127. We're still auditing which flags are stable enough to surface to the public API (including the flag that you mentioned) but we will add them here over time.
Also to specifically answer your questions on scratch and megacore (we'll be adding this to the tutorials):
Scratch does work as a "temporary" scratchpad memory, as you mentioned; you could also use an output ref in the same way. For megacore configs, each core has its own VMEM but shares HBM, so it's generally safe to copy from HBM to VMEM, but you will need to be careful copying from VMEM back into HBM. (For remote communication, we also recommend issuing remote DMAs from a single core only.)
The main use for Pallas is TPU, but right now many options related to the TPU compiler are undocumented and require guesswork.
Documentation for `PrefetchScalarGridSpec` is incomplete, lacking information about scratch_shapes. jax.experimental.pallas.tpu has no documentation at all. pallas_call's compiler_params also lacks documentation beyond dimension_semantics. It seems possible to pass flags to the TPU compiler, as seen below, but there is no documentation of what the flags do and no list of available flags: https://github.com/google/jax/blob/a8b425cac50c842f66f36903dfb93fe6ad5a2a5b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py#L1071
With the compiler being closed source and no documentation available, creating efficient kernels for TPU is rather difficult.