google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

remat not bitwise identical #8702

Open trevorcai opened 2 years ago

trevorcai commented 2 years ago

Following up on this conversation on Twitter, here's a simple example of remat producing non-identical results on CPU.

import numpy as np
import jax
import jax.numpy as jnp

def hp_softmax(x):
  # Upcast to float32 so the softmax itself runs in high precision.
  x = x.astype(np.float32)
  return jax.nn.softmax(x)

def loss_fn(x, remat_sm):
  x = x + 1
  fn = hp_softmax
  if remat_sm:
    fn = jax.remat(fn)
  sm = fn(x)
  return sm[0]

x = np.random.random((1024,)).astype(jnp.bfloat16)
jgf = jax.jit(jax.grad(loss_fn), static_argnums=(1,))
nr = jgf(x, False)  # gradient without remat
r = jgf(x, True)    # gradient with remat
np.sum(nr != r)     # number of elements that differ bitwise
# 486

N.b. commenting out either x.astype(np.float32) or x = x + 1 results in bitwise-identical behavior, as does keeping the input x in high precision (either f32 or f64).
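
For reference, a minimal sketch of the high-precision-input control described above, reusing jgf and the definitions from the snippet (the expected count of 0 reflects the observation in this note, not a fresh measurement):

# Same setup as above, but keep the input in float32 so there is no
# bfloat16 -> float32 upcast inside loss_fn.
x32 = np.random.random((1024,)).astype(np.float32)
nr32 = jgf(x32, False)
r32 = jgf(x32, True)
np.sum(nr32 != r32)
# expected: 0 (bitwise identical), per the note above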

Without having fully parsed the compiled XLA, I assume the issue here is that the forward pass gets some nice excess precision through fusing x = x + 1 and x = x.astype(np.float32) while the XLA::Cond gadget prevents this on the rematerialized forward pass.
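
One way to probe that hypothesis, sketched here under the assumptions that your JAX version has the AOT lower()/compile() API and the prevent_cse option on jax.remat (loss_fn_no_barrier is just an illustrative name):

# Dump the post-optimization HLO for both paths: look for a fused add+convert
# in the plain forward pass vs. the remat gadget in the rematerialized one.
print(jgf.lower(x, False).compile().as_text())
print(jgf.lower(x, True).compile().as_text())

# A variant with the CSE-prevention gadget disabled, to test whether that
# gadget alone accounts for the difference.
def loss_fn_no_barrier(x):
  x = x + 1
  sm = jax.remat(hp_softmax, prevent_cse=False)(x)
  return sm[0]

r_nb = jax.jit(jax.grad(loss_fn_no_barrier))(x)
np.sum(nr != r_nb)  # if the gadget is the culprit, this may drop to 0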

To be clear, I am not currently hitting any bug that I can trace back to this behavior.

mattjj commented 2 years ago

Thanks, this is a super useful data point / test case!

Let's set it at P3 for now until someone becomes blocked by this kind of issue. If/when that happens, we'll have a bit more info, thanks to this issue!