Following up on this conversation on Twitter, here's a simple example of remat producing non-identical results on CPU.
```python
import functools

import numpy as np
import jax
import jax.numpy as jnp


def hp_softmax(x):
    x = x.astype(np.float32)
    return jax.nn.softmax(x)


def loss_fn(x, remat_sm):
    x = x + 1
    fn = hp_softmax
    if remat_sm:
        fn = jax.remat(fn)
    sm = fn(x)
    return sm[0]


x = np.random.random((1024,)).astype(jnp.bfloat16)
jgf = jax.jit(jax.grad(loss_fn), static_argnums=(1,))
nr = jgf(x, False)
r = jgf(x, True)
np.sum(nr != r)
# 486
```
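Beyond the count of differing elements, one could also check the size of the largest per-element gap; this is just a suggested follow-up, not a measurement I ran for this report:

```python
# Hypothetical follow-up check: magnitude of the largest per-element
# difference between the remat and non-remat gradients.
print(np.max(np.abs(np.asarray(nr) - np.asarray(r))))
```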
N.B. that commenting out either `x.astype(np.float32)` or `x = x + 1` results in bitwise-identical behavior, as does keeping `x` in high precision for the input (either f32 or f64).
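For concreteness, here's a sketch of the f32-input variant just described (names like `loss_fn_f32` are hypothetical, and the expected-zero comment restates the observation above rather than a fresh measurement):

```python
import numpy as np
import jax

# Sketch of the f32-input variant: no bfloat16 input, no astype inside the
# softmax wrapper. Per the note above, the remat and non-remat gradients
# should come out bitwise identical in this configuration.
def loss_fn_f32(x, remat_sm):
    x = x + 1
    fn = jax.nn.softmax
    if remat_sm:
        fn = jax.remat(fn)
    return fn(x)[0]

x32 = np.random.random((1024,)).astype(np.float32)
jgf32 = jax.jit(jax.grad(loss_fn_f32), static_argnums=(1,))
np.sum(jgf32(x32, False) != jgf32(x32, True))  # expected: 0
```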
Without having fully parsed the compiled XLA, I assume the issue is that the forward pass picks up some extra precision by fusing `x = x + 1` with `x = x.astype(np.float32)`, while the XLA::Cond gadget prevents that fusion on the rematerialized forward pass.
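One way to check that hypothesis (a suggestion, not something I've done here) would be to dump the compiled HLO for both paths and compare where the add/convert ops land relative to the fusions:

```python
# Sketch: dump the compiled HLO for the non-remat and remat paths, then diff
# them (by eye or with difflib) to see how x + 1 and the convert-to-f32 get
# fused in each case, and where the remat conditional sits.
jitted = jax.jit(jax.grad(loss_fn), static_argnums=(1,))
print(jitted.lower(x, False).compile().as_text())
print(jitted.lower(x, True).compile().as_text())
```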
For what it's worth, I'm not currently hitting any bug that I can trace back to this behavior.
Thanks, this is a super useful data point / test case!
Let's set it at P3 for now until someone becomes blocked by this kind of issue. If/when that happens, we'll have a bit more info, thanks to this issue!