jax-ml / jax

Offloading in grad(scan(remat(fn, policy=offload))) results in XlaRuntimeError #24115

Open hbq1 opened 1 month ago

hbq1 commented 1 month ago

Description

The following code

import jax
import jax.ad_checkpoint
from jax import numpy as jnp

@jax.jit
def apply(params, x):

  def step(y, i):
    y = jnp.sin(y)

    y = jax.ad_checkpoint.checkpoint_name(y, 'save_remat')

    y = jnp.sin(y)
    return y, ()

  step = jax.remat(
      step,
      policy=jax.checkpoint_policies.save_and_offload_only_these_names(
          names_which_can_be_saved=(),
          names_which_can_be_offloaded=('save_remat',),
          offload_src='device',
          offload_dst='pinned_host',
      ),
  )

  y = jax.grad(lambda p: jax.lax.scan(step, p, x)[0].sum())(params)
  return y

params = jnp.ones([3])
x = jnp.ones([2, 3])

apply(params, x)

reproduces the error

XlaRuntimeError: UNIMPLEMENTED: Performing sub-chunk copy is not supported in async dynamic slice yet.

Error encountered while compiling %dynamic-slice-start = ((f32[2,3]{1,0:T(2,128)S(5)}, s32[]{:T(256)}, s32[]{:T(256)}), f32[1,3]{1,0:T(2,128)}, u32[]{:S(2)}, s32[]) dynamic-slice-start(f32[2,3]{1,0:T(2,128)S(5)} %get-tuple-element.237, s32[]{:T(256)} %select.6, s32[]{:T(256)} %constant.4..sunk.1), dynamic_slice_sizes={1,3}, metadata={op_name="jit(apply)/jit(main)/transpose(jvp(while))/body/dynamic_slice" source_file="<ipython-input-1-a56f72d33e52>" source_line=27}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"0","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}]},"used_scoped_memory_configs":[]}.

Error encountered while compiling %while.7 = (s32[]{:T(256)}, f32[3]{0:T(256)}, f32[2,3]{1,0:T(2,128)S(5)}, f32[2,3]{1,0:T(2,128)}, s32[]{:T(256)}, /*index=5*/s32[]{:T(256)}, s32[]{:T(256)}) while((s32[]{:T(256)}, f32[3]{0:T(256)}, f32[2,3]{1,0:T(2,128)S(5)}, f32[2,3]{1,0:T(2,128)}, s32[]{:T(256)}, /*index=5*/s32[]{:T(256)}, s32[]{:T(256)}) %tuple.50), condition=%wide.wide.wide.wide.region_3.91.clone.clone.clone, body=%wide.wide.wide.wide.region_2.67.clone.clone.clone.

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.34
jaxlib: 0.4.34
numpy:  2.0.1
python: 3.11.8 (stable, redacted, redacted) [Clang google3-trunk (94c024adedcb53059c29d7c2d62982053b60e86a)]
jax.devices (8 total, 8 local): [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0) TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1) ... TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0) TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
process_count: 1
platform: uname_result(system='Linux', ..., release='5.10.0-smp-1104.53.0.0', version='#1 [v5.10.0-1104.53.0.0] SMP @1727505643', machine='x86_64')

jaro-sevcik commented 3 weeks ago

The scan needs to be inside the remat. For instance:

import jax
import jax.ad_checkpoint
from jax import numpy as jnp

@jax.jit
def apply(params, x):

  def step(y, i):
    y = jnp.sin(y)

    y = jax.ad_checkpoint.checkpoint_name(y, 'save_remat')

    y = jnp.sin(y)
    return y, ()

  repeated_step = jax.remat(
      lambda p: jax.lax.scan(step, p, x),
      policy=jax.checkpoint_policies.save_and_offload_only_these_names(
          names_which_can_be_saved=(),
          names_which_can_be_offloaded=('save_remat',),
          offload_src='device',
          offload_dst='pinned_host',
      ),
  )

  y = jax.jit(jax.grad(lambda p: repeated_step(p)[0].sum()))(params)
  return y

params = jnp.ones([3])
x = jnp.ones([2, 3])

apply(params, x)

hbq1 commented 3 weeks ago

@jaro-sevcik your version of the code does something different: it rematerialises the whole loop instead of each step, which is a different memory/compute trade-off.

In this particular example, it results in the jax.grad(jax.remat(...)) pattern, which is suboptimal (see the example in https://jax.readthedocs.io/en/latest/gradient-checkpointing.html#jax-checkpoint-fundamentals).
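
To make the difference between the two placements concrete, here is a minimal sketch that puts them side by side. It is an illustration only: it swaps the offloading policy for jax.checkpoint_policies.save_only_these_names so that it can run on any backend, which means it does not exercise the XlaRuntimeError reported above. The function names loss_remat_per_step and loss_remat_whole_loop are made up for this example.

```python
import jax
import jax.ad_checkpoint
from jax import numpy as jnp

# Assumption: a non-offloading policy stands in for
# save_and_offload_only_these_names so the sketch runs without pinned_host.
policy = jax.checkpoint_policies.save_only_these_names('save_remat')


def step(y, i):
  y = jnp.sin(y)
  y = jax.ad_checkpoint.checkpoint_name(y, 'save_remat')
  y = jnp.sin(y)
  return y, ()


def loss_remat_per_step(p, x):
  # remat wraps each scan step (the placement in the original report).
  return jax.lax.scan(jax.remat(step, policy=policy), p, x)[0].sum()


def loss_remat_whole_loop(p, x):
  # remat wraps the whole scan (the placement in the suggested workaround).
  loop = jax.remat(lambda q, xs: jax.lax.scan(step, q, xs), policy=policy)
  return loop(p, x)[0].sum()


params = jnp.ones([3])
x = jnp.ones([2, 3])

grad_per_step = jax.jit(jax.grad(loss_remat_per_step))(params, x)
grad_whole_loop = jax.jit(jax.grad(loss_remat_whole_loop))(params, x)
```

Both variants should produce the same gradient; they differ only in which intermediates are kept for the backward pass and which are recomputed. That is why moving the scan inside the remat changes the memory/compute trade-off rather than the result.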