Open muchanem opened 4 weeks ago
Bumping this with some more testing/further narrowing down. Here's a snippet that is much shorter and hopefully easier to understand. Both check_grads calls always fail (I suspect the tolerance is just too small). But importantly, including any ReLU (leaky or otherwise) makes the JIT'd max absolute difference a factor of 2 or more larger than the non-JIT'd one: in my last test, JIT had a max absolute difference of 0.10245637 vs. 0.04284701 without JIT (notably, without an activation they remain the same but out of tolerance).
import jax
import jax.numpy as jnp

latent_dim = 2**11
input_dim = 3072
initializer = jax.nn.initializers.he_uniform()
encoder = initializer(jax.random.key(0), (latent_dim, input_dim), jnp.float32)
decoder = encoder.T
def encode(x):
    codes = encoder @ x
    return jax.nn.relu(codes)
    #return leaky_offset_relu(codes, negative_slope=0., offset=1.96/jnp.sqrt(encoder.shape[0]))
def top_k_decode(top_k_indices, top_k_values):
    decoder_weights = (decoder / jnp.linalg.norm(decoder, axis=-1, keepdims=True)).T
    # top_k_indices is now 1D after vmap, so we don't need [:, :, None]
    selected_decoder_weights = decoder_weights[top_k_indices]
    # Adjust the sum operation to match the new shape
    return jnp.sum(top_k_values[:, None] * selected_decoder_weights, axis=0)
def fwd_pass(batch: jnp.ndarray):
    top_level_latent_codes = jax.vmap(encode)(batch)
    top_k_values, top_k_indices = jax.lax.top_k(top_level_latent_codes, 8)
    x_hat = jax.vmap(top_k_decode)(top_k_indices, top_k_values)
    return x_hat
def recon_loss(batch: jnp.ndarray):
    x_hat = fwd_pass(batch)
    return jnp.mean(jnp.sum(jnp.square(batch - x_hat), axis=-1))
example_batch = jax.random.normal(jax.random.key(42), (4096,3072))
example_batch = example_batch/jnp.linalg.norm(example_batch, axis=-1, keepdims=True)
from jax.test_util import check_grads
check_grads(jax.jit(recon_loss), args=(example_batch,), order=1)
check_grads(recon_loss, args=(example_batch,), order=1)
P.S. tested this on a TPU colab and jitting/adding ReLU doesn't make a difference--the grads are "out of tolerance" in all cases by the same amount
I think the issue here is that your gradients are very close to zero, so very small absolute deviations become large relative deviations:
grad_jit_result = jax.grad(jax.jit(recon_loss))(example_batch)
grad_result = jax.grad(recon_loss)(example_batch)
np.testing.assert_allclose(grad_jit_result, grad_result)
AssertionError:
Not equal to tolerance rtol=1e-07, atol=0
Mismatched elements: 2554069 / 12582912 (20.3%)
Max absolute difference: 7.275958e-12
Max relative difference: 0.3
x: array([[-6.417103e-06, 1.638319e-05, -2.042344e-06, ..., 2.303527e-05,
2.817694e-06, -6.996353e-06],
[-3.996819e-06, -8.903306e-06, 7.238759e-08, ..., -2.694251e-06,...
y: array([[-6.417103e-06, 1.638319e-05, -2.042344e-06, ..., 2.303527e-05,
2.817694e-06, -6.996353e-06],
[-3.996818e-06, -8.903306e-06, 7.238714e-08, ..., -2.694251e-06,...
The gradients match in an absolute sense to 1 part in 10^12, but in a relative sense only match to 1 part in 3, meaning the mismatched gradients are something like 3E-13 vs. 1E-12, so both basically zero. This large relative difference is likely causing check_grads to report an error, but it's not an error that will be numerically significant in the overall computation.
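To make the relative-vs-absolute point concrete, here is a minimal sketch in plain Python. The two values are hypothetical, picked only to sit at the magnitude discussed above, and math.isclose is used as a stand-in for np.testing.assert_allclose (the two combine their tolerances with slightly different formulas).

```python
import math

# Hypothetical near-zero gradient entries (not taken from the actual run).
a = 1.0e-12
b = 1.3e-12

abs_diff = abs(a - b)                      # about 3e-13: tiny in absolute terms
rel_diff = abs_diff / max(abs(a), abs(b))  # about 0.23: large in relative terms

# A purely relative check fails, even though both values are effectively zero.
strict = math.isclose(a, b, rel_tol=1e-7, abs_tol=0.0)

# Adding a small absolute tolerance makes the same comparison pass.
lenient = math.isclose(a, b, rel_tol=1e-7, abs_tol=1e-9)

print(f"abs diff: {abs_diff:.2e}, rel diff: {rel_diff:.2f}")
print(f"relative-only check passes: {strict}")
print(f"with abs_tol=1e-9 passes:   {lenient}")
```

The same idea should carry over to the comparison above by passing a nonzero atol to np.testing.assert_allclose, with the value chosen to suit the scale of the gradients.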
That makes sense, but I still encounter the bug where optimization works without JIT and totally fails with it (and this behavior doesn't happen on a TPU platform, which makes me think there's some XLA bug, not a code bug). I removed the leaky ReLU, since with a top-k activation it doesn't do anything except add a threshold. After that change, JIT'd training on GPU still fails, and removing the JIT (or turning off algsimp) causes it to succeed (e.g. at step 400 a JIT'd trainer still has a loss above 1, while an eager optimization is at ~0.8).
Hmm, that's strange indeed. In general we don't expect JIT-compiled versions of functions to have bitwise-identical outputs to the non-compiled versions: any time you rearrange or fuse floating point operations, you'll change the details of the numerics, but it should be to within normal floating-point precision.
Can you try running it again while setting jax_default_matmul_precision=highest?
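As a small illustration of the point about rearranged floating-point operations, here is a sketch in plain Python. It uses float64 scalars rather than float32 arrays on a GPU, so it only shows the kind of effect involved, not the magnitude you'd see under XLA.

```python
# Floating-point addition is not associative: grouping the same three
# numbers differently changes the last bits of the result.
left = (0.1 + 0.2) + 0.3
right = 0.1 + (0.2 + 0.3)

print(left == right)       # the two groupings are not bitwise identical
print(abs(left - right))   # but they agree to within normal float precision
```

A compiler that fuses or reorders operations is effectively free to pick a different grouping, which is why bitwise-identical outputs aren't expected.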
That didn't help. Here's the code as it is being run right now (and the first 1000 steps of training); sorry that it's no longer fully self-contained. Contrary to what I said earlier, substituting either a normal relu or no activation for the leaky_offset_relu fixes the issue (my other comment was testing a more complicated version of this model 🥲). There's no actual negative slope in how the leaky offset ReLU is called, so it is just a thresholded ReLU; I suspect that a decent number of the activations end up below the threshold at the beginning of training and that causes issues.
import jax
import jax.numpy as jnp
import equinox as eqx
import treescope
from jax.experimental import sparse
treescope.basic_interactive_setup()
from typing import Tuple
import optax
from jax.experimental import mesh_utils
from jax.sharding import PositionalSharding
jax.config.update('jax_default_matmul_precision', "highest")
def leaky_offset_relu(x, negative_slope=1e-2, offset=0):
    return jnp.where(x >= offset, x, negative_slope * x)
class Autoencoder(eqx.Module):
    encoder: jnp.ndarray
    decoder: jnp.ndarray
    bias: jnp.ndarray
    use_bias: bool

    def __init__(self, latent_dim: int, input_dim: int, use_bias: bool = True, key=None):
        initializer = jax.nn.initializers.he_uniform()
        self.encoder = initializer(key, (latent_dim, input_dim), jnp.float32)
        self.decoder = self.encoder.T
        self.bias = jnp.zeros(input_dim) if use_bias else None
        self.use_bias = use_bias

    def encode(self, x):
        x = x - self.bias if self.use_bias else x
        codes = self.encoder @ x
        #return codes
        return leaky_offset_relu(codes, negative_slope=0., offset=1.96/jnp.sqrt(self.encoder.shape[0]))

    def top_k_decode(self, top_k_indices, top_k_values):
        decoder_weights = self.get_decoder()
        # top_k_indices is now 1D after vmap, so we don't need [:, :, None]
        selected_decoder_weights = decoder_weights[:, top_k_indices]
        return selected_decoder_weights @ top_k_values + (self.bias if self.use_bias else 0)

    def get_decoder(self):
        return self.decoder / jnp.linalg.norm(self.decoder, axis=0, keepdims=True)

    def get_encoder(self):
        return self.encoder

    def __call__(self, x):
        z = self.encode(x)
        return self.decode(z)
def fwd_pass(model: Autoencoder, batch: jnp.ndarray):
    top_level_latent_codes = jax.vmap(model.encode)(batch)
    top_k_values, top_k_indices = jax.lax.top_k(top_level_latent_codes, 8)
    x_hat = jax.vmap(model.top_k_decode)(top_k_indices, top_k_values)
    return x_hat, top_level_latent_codes
def recon_loss(x_hat: jnp.ndarray, batch: jnp.ndarray):
    return jnp.mean(jnp.sum(jnp.square(batch - x_hat), axis=-1))
@eqx.filter_value_and_grad
def loss_fn(model: Autoencoder, batch: jnp.ndarray, l1_penalty: float, ortho_penalty: float) -> float:
    x_hat, top_level_latent_codes = fwd_pass(model, batch)
    return recon_loss(x_hat, batch)
def update_model(model, grads, opt_state, optimizer):
    updates, new_opt_state = optimizer.update(grads, opt_state)
    new_model = eqx.apply_updates(model, updates)
    return new_model, new_opt_state
@eqx.filter_jit
def train_step(model: Autoencoder, batch: jnp.ndarray, opt_state, l1_penalty: float, ortho_penalty: float, optimizer) -> Tuple[Autoencoder, optax.OptState, float]:
    loss, grads = loss_fn(model, batch, l1_penalty, ortho_penalty)
    model, opt_state = update_model(model, grads, opt_state, optimizer)
    return model, opt_state, loss
# omitted dataloader stuff
model = Autoencoder(
    input_dim=3072,
    latent_dim=2**12,
    use_bias=False,
    key=jax.random.key(0)
)
learning_rate = 1e-4
optimizer = optax.adam(learning_rate)
opt_state = optimizer.init(eqx.filter(model, eqx.is_array))
l1_penalty = 1e-3
ortho_penalty = 1e-2
num_steps = 10000
for step in range(num_steps):
    batch = next(train_loader)
    model, opt_state, loss = train_step(model, batch, opt_state, l1_penalty, ortho_penalty, optimizer)
    if step % 50 == 0:
        print(f"Step {step}, Loss: {loss:.4f}")
No JIT
Step 0, Loss: 0.9761
Step 50, Loss: 0.9737
Step 100, Loss: 0.9722
Step 150, Loss: 0.9705
Step 200, Loss: 0.9687
Step 250, Loss: 0.9664
Step 300, Loss: 0.9630
Step 350, Loss: 0.9590
Step 400, Loss: 0.9531
Step 450, Loss: 0.9460
Step 500, Loss: 0.9373
Step 550, Loss: 0.9309
Step 600, Loss: 0.9174
Step 650, Loss: 0.9082
Step 700, Loss: 0.8965
Step 750, Loss: 0.8895
Step 800, Loss: 0.8821
Step 850, Loss: 0.8753
Step 900, Loss: 0.8647
Step 950, Loss: 0.8614
Step 1000, Loss: 0.8555
JIT on
Step 0, Loss: 1.0394
Step 50, Loss: 1.0389
Step 100, Loss: 1.0383
Step 150, Loss: 1.0381
Step 200, Loss: 1.0375
Step 250, Loss: 1.0372
Step 300, Loss: 1.0372
Step 350, Loss: 1.0367
Step 400, Loss: 1.0362
Step 450, Loss: 1.0357
Step 500, Loss: 1.0356
Step 550, Loss: 1.0353
Step 600, Loss: 1.0350
Step 650, Loss: 1.0346
Step 700, Loss: 1.0342
Step 750, Loss: 1.0341
Step 800, Loss: 1.0339
Step 850, Loss: 1.0333
Step 900, Loss: 1.0333
Step 950, Loss: 1.0331
Step 1000, Loss: 1.0325
Hi, what is happening is that the layout of top_level_latent_codes changes which top_k algorithm is used. It looks to me like there are a lot of near-zero values in top_level_latent_codes, and the top_k algorithm is struggling to choose the top k consistently. If I add a layout constraint to your code (or return top_level_latent_codes as an aux output), then everything becomes consistent between JIT and no JIT.
def fwd_pass(model: Autoencoder, batch: jnp.ndarray):
    top_level_latent_codes = jax.vmap(model.encode)(batch)
    custom_dll = jax._src.layout.DeviceLocalLayout(major_to_minor=(0, 1)) # switch to (1, 0) to get 'jit_loss'
    s = jax.sharding.SingleDeviceSharding(jax.devices()[0])
    top_level_latent_codes = jax.lax.with_sharding_constraint(top_level_latent_codes, jax._src.layout.Layout(custom_dll, s))
    top_k_values, top_k_indices = jax.lax.top_k(top_level_latent_codes, 8) # Top-k gating
    x_hat = jax.vmap(model.top_k_decode)(top_k_indices, top_k_values)
    return x_hat, top_level_latent_codes, top_k_indices
While the layout does make them match, I suspect that this implies that the choice of the top_k in the case of ties or near zeros really matters and you should do something smarter here rather than just forcing it to choose a particular top_k algorithm via constraining the layouts.
Slight update on the usage of layout API: Use the experimental endpoint: from jax.experimental.layout import Layout, DeviceLocalLayout
This makes sense, but I'm still curious as to why this happens with leaky_offset_relu but not otherwise. The offset means that with leaky_offset_relu turned on, there are no nonzero values in top_level_latent_codes below ~.03. With no offset (i.e. regular ReLU or no activation), values go as low as 1E-5 but the bad training doesn't happen (sort of the opposite of what you'd expect if near zeros are the issue). There aren't any ties, nor any near ties (the mean distance between the activation values among the nonzero elements is ~0.01). The only other difference between the various activations is that there are only ~340 nonzero elements in the offset case and ~2048 without an offset. All these numbers are from step 10 of optimization, but the same is true at initialization.
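For reference, the ~.03 floor follows directly from the offset expression in the training code. Below is a scalar stand-in for leaky_offset_relu in plain Python (the real one uses jnp.where on arrays), assuming latent_dim=2**12 as in the model above; the sample inputs are made up for illustration.

```python
import math

def leaky_offset_relu(x, negative_slope=1e-2, offset=0.0):
    # Scalar version of the jnp.where-based function in the training code.
    return x if x >= offset else negative_slope * x

latent_dim = 2**12
offset = 1.96 / math.sqrt(latent_dim)  # 1.96 / 64 = 0.030625

# Hypothetical pre-activation values straddling the threshold.
samples = [-0.5, 0.0, 0.01, 0.03, 0.031, 0.2]
out = [leaky_offset_relu(x, negative_slope=0.0, offset=offset) for x in samples]

print(f"offset = {offset}")
print(out)  # everything below the offset is zeroed; the rest passes through
```

With negative_slope=0. this is a hard threshold: any code below the offset becomes exactly zero, which is where the large block of exact zeros in each row comes from.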
You can debug these things by adding an aux output and printing the intermediate results: @functools.partial(eqx.filter_value_and_grad, has_aux=True)
I think the simple explanation is that top_k_indices with relu + jit are ordered like so:
Array([[1793, 1338, 0, ..., 3, 4, 5],
[1074, 0, 1, ..., 4, 5, 6],
[ 0, 1, 2, ..., 5, 6, 7],
...,
[ 0, 1, 2, ..., 5, 6, 7],
[ 0, 1, 2, ..., 5, 6, 7],
[ 591, 1829, 0, ..., 3, 4, 5]], dtype=int32)
And everything else (nojit relu, jit norelu, nojit norelu) is more random like so:
Array([[ 985, 1273, 579, ..., 386, 1756, 390],
[1213, 402, 1652, ..., 59, 1665, 597],
[1275, 1131, 1073, ..., 1615, 1034, 2030],
...,
[ 570, 1460, 101, ..., 0, 1, 2],
[1546, 1519, 1748, ..., 1644, 902, 1482],
[1059, 430, 1778, ..., 732, 1034, 344]], dtype=int32)
Printing the intermediate values for leaky_relu, there are a lot of zeros in the results. Anyways, looking more closely, it might actually be a bug in how XLA does top-k. I'll look into that.
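For intuition about index patterns like [1793, 1338, 0, ..., 3, 4, 5], here is a plain-Python sketch of a top-k that breaks ties in favor of the lowest index. The helper name and the toy row are hypothetical, and this is not how jax.lax.top_k or XLA implement the operation; it only shows what the output looks like when a row has fewer than k nonzero entries and the remaining slots are filled from tied zeros.

```python
def top_k_lowest_index_first(values, k):
    """Return (top-k values, their indices), preferring lower indices on ties."""
    order = sorted(range(len(values)), key=lambda i: (-values[i], i))
    top = order[:k]
    return [values[i] for i in top], top

# Toy row: only two nonzero activations, everything else exactly zero.
row = [0.0] * 16
row[11] = 0.9
row[7] = 0.4

vals, idx = top_k_lowest_index_first(row, 8)
print(idx)   # the two nonzero positions first, then tied zeros from index 0 up
print(vals)
```

Here the first two indices are the genuinely active units and the rest are arbitrary zero-valued slots. That matches the shape of the relu + jit output above, but it does not by itself explain why so few entries would be nonzero in that configuration.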
Thanks! Another thing I don't understand is why major-minor ordering helps regardless of batch size? Even if the batch dimension is larger than the latent dimension, your layout fix works--but isn't the batch dimension in that case major and latent minor? Is it just that the vmap means that the batch dimension will always be treated as the minor dimension since the ops are vectorized across it? (these runs were before I was adding in the aux output)
Was making modifications to code today and noticed that adding the line codes += 0 re-introduces the buggy top-k optimization
This should be fixed with: https://github.com/openxla/xla/commit/474036b950cbb934f09f5485633997fb379e6364 . This should show up in the nightlies soon: https://jax.readthedocs.io/en/latest/installation.html#jax-nightly-installation .
Description
The below snippet produces drastically different results depending on whether the loss function is JIT'd or not compiled. This results in a bad optimization where the version with the JIT'd loss function never converges. The error is probably in the calculation of the gradients; JIT-compiling any of the individual components doesn't cause the issue. Additionally, this bug only appears on systems with GPUs; in a TPU environment the issue doesn't appear. The bug is unrelated to my use of Equinox; it was originally observed in Flax. The flag --xla_disable_hlo_passes=algsimp will prevent the bug from appearing. I'm unsure what the JAX equivalent of --tf_xla_parallel_checking is, which I'd like to use to compare the XLA-HLO computational graphs and narrow down the bug further. My guess would be something to do with the interaction between the top-k gating in the decoder and the JIT'd grad (a version of this without a sparsity-aware decoder didn't have this bug).
System info (python version, jaxlib version, accelerator, etc.)
This bug has been observed on the latest jax version with an A100 as well. This is the colab environment where this isolated example was created