hbq1 opened 1 month ago
The scan needs to be inside the remat. For instance:
```python
import jax
import jax.ad_checkpoint
from jax import numpy as jnp


@jax.jit
def apply(params, x):

  def step(y, i):
    y = jnp.sin(y)
    # Tag this intermediate so the offload policy below can match it by name.
    y = jax.ad_checkpoint.checkpoint_name(y, 'save_remat')
    y = jnp.sin(y)
    return y, ()

  # The scan is inside the remat: on the backward pass each step is
  # rematerialized, while the tagged value is offloaded to pinned host
  # memory instead of being recomputed.
  repeated_step = jax.remat(
      lambda p: jax.lax.scan(step, p, x),
      policy=jax.checkpoint_policies.save_and_offload_only_these_names(
          names_which_can_be_saved=(),
          names_which_can_be_offloaded=('save_remat',),
          offload_src='device',
          offload_dst='pinned_host',
      ),
  )
  y = jax.jit(jax.grad(lambda p: repeated_step(p)[0].sum()))(params)
  return y


params = jnp.ones([3])
x = jnp.ones([2, 3])
apply(params, x)
```
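As a side check (not part of the original report), the checkpointing guide linked below documents `jax.ad_checkpoint.print_saved_residuals`, which reports which intermediates are saved for the backward pass. A minimal sketch, assuming `step`, `params`, and `x` from the snippet above are available at module scope:

```python
# Sketch: confirm what the policy saves. Intermediates tagged via
# checkpoint_name are listed by name, so only 'save_remat' should appear as
# saved/offloaded; everything else is rematerialized.
def forward(p):
  y, _ = jax.lax.scan(step, p, x)  # same step/x as in the snippet above
  return y.sum()

forward_remat = jax.remat(
    forward,
    policy=jax.checkpoint_policies.save_and_offload_only_these_names(
        names_which_can_be_saved=(),
        names_which_can_be_offloaded=('save_remat',),
        offload_src='device',
        offload_dst='pinned_host',
    ),
)
jax.ad_checkpoint.print_saved_residuals(forward_remat, params)
```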
@jaro-sevcik your version of the code does a different thing: it rematerialises the whole loop instead of each step, which implements a different memory/compute trade-off. In this particular example, it results in a `jax.grad(jax.remat(...))` pattern, which is suboptimal (see the example in https://jax.readthedocs.io/en/latest/gradient-checkpointing.html#jax-checkpoint-fundamentals).
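To illustrate the difference, here is a minimal sketch with hypothetical `inner` and `outer` functions, following the pattern from the linked guide rather than code from this issue:

```python
import jax
from jax import numpy as jnp

def inner(p):  # hypothetical expensive sub-computation
  return jnp.sin(jnp.sin(p))

def outer(h):  # hypothetical cheap final stage
  return jnp.cos(h).sum()

p = jnp.ones([3])

# Suboptimal jax.grad(jax.remat(...)) pattern: only the inputs are saved,
# and the entire forward computation is replayed on the backward pass.
grads = jax.grad(jax.remat(lambda q: outer(inner(q))))(p)

# Usually better: remat a sub-computation inside the differentiated
# function, so only that block is replayed.
def loss(q):
  h = jax.remat(inner)(q)
  return outer(h)

grads = jax.grad(loss)(p)
```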