Flax output mismatch for multi-dimensional batch input on GPUs #15898

Open patel-zeel opened 1 year ago

patel-zeel commented 1 year ago



I am trying to run a simple MLP on an A100 GPU with multi-dimensional batch inputs of shape (b1, b2, input_dim) and outputs of shape (b1, b2, 1). The Flax outputs I get when passing the entire input of shape (b1, b2, input_dim) at once do not match the outputs I get when passing each single input of shape (1, 1, input_dim) iteratively. When I run the same code on a CPU, or run the equivalent PyTorch version, the outputs match exactly. Please see the minimal code example, Colab link, and outputs below:


```python
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # use only the first GPU
import numpy as np
from IPython.display import display  # notebook helper; use print() outside Jupyter/Colab

# Flax imports
import jax.random as jr
import jax.numpy as jnp
import flax.linen as nn

# PyTorch imports
import torch

# Common constants
b1, b2 = 2, 3
input_dim = 2
hidden_dim = 2
output_dim = 1
batch_shape = (b1, b2)

# Flax code
tiny_model = nn.Sequential([nn.Dense(hidden_dim), nn.Dense(output_dim)])
tiny_params = tiny_model.init(jr.PRNGKey(1234), jnp.ones((*batch_shape, input_dim)))

x = jr.normal(jr.PRNGKey(5678), (*batch_shape, input_dim))
# Forward pass on the whole (b1, b2, input_dim) batch at once
batch_out = tiny_model.apply(tiny_params, x)

# Forward pass on one (1, 1, input_dim) slice at a time
individual_out = np.zeros_like(batch_out)
for i in range(b1):
    for j in range(b2):
        individual_out[i:i+1, j:j+1, :] = tiny_model.apply(tiny_params, x[i:i+1, j:j+1, :])

# Exact (bitwise) equality check
print(f"Flax output match: {jnp.all(batch_out == individual_out)}")
display(batch_out.squeeze().tolist(), individual_out.squeeze().tolist())

# PyTorch code (note: nothing below is moved to a CUDA device,
# so the model and tensors stay on the CPU)
model = torch.nn.Sequential(torch.nn.Linear(input_dim, hidden_dim), torch.nn.Linear(hidden_dim, output_dim))
batch_out = model(torch.tensor(x.tolist()))

individual_out = torch.ones_like(batch_out)
for i in range(b1):
    for j in range(b2):
        individual_out[i, j, :] = model(torch.tensor(x[i:i+1, j:j+1, :].tolist())).squeeze()

print(f"PyTorch output match: {torch.all(batch_out == individual_out)}")
display(batch_out.squeeze().tolist(), individual_out.squeeze().tolist())
```
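For context on why the two Flax calls can differ at all: a Dense layer contracts only the last axis of its input, so the batched call amounts to one matrix product with b1 * b2 rows, while each individual call is a product with a single row. On a GPU these two shapes may be dispatched to different kernels, which is one plausible (unconfirmed) source of rounding differences. The sketch below uses plain JAX with a random `kernel` standing in for a Dense layer's weights (an assumption for illustration, not the Flax internals verbatim) and checks that the batched contraction equals the flattened 2-D product.

```python
import jax
import jax.numpy as jnp

b1, b2, input_dim, hidden_dim = 2, 3, 2, 2
kx, kw = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(kx, (b1, b2, input_dim))
kernel = jax.random.normal(kw, (input_dim, hidden_dim))  # stand-in for Dense weights

# Contract the last axis of x with the first axis of the kernel; all leading
# axes of x are treated as batch axes, as in a Dense layer.
y_batched = jax.lax.dot_general(x, kernel, (((x.ndim - 1,), (0,)), ((), ())))

# The same computation written as one 2-D matrix product over b1 * b2 rows.
y_flat = (x.reshape(b1 * b2, input_dim) @ kernel).reshape(b1, b2, hidden_dim)

print(y_batched.shape)  # (2, 3, 2)
print(jnp.allclose(y_batched, y_flat))
```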

What jax/jaxlib version are you using?

jax 0.4.8, jaxlib 0.4.7+cuda12.cudnn88

Which accelerator(s) are you using?

GPU (NVIDIA A100)

Additional system info

murphyk commented 1 year ago

I think this is just a numerical precision problem, not a bug. For example, change the print statement to

```python
print(f"Flax output match: {jnp.allclose(batch_out, individual_out, atol=1e-3)}")
```

and it prints True on GPU (and CPU). Similarly,

```python
m = jnp.max(batch_out - individual_out)
print(m)  # 0.000106453896
```
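A minimal NumPy sketch of a tolerance-based comparison along the lines of the check above. One caveat: `jnp.max(batch_out - individual_out)` returns the largest *signed* difference, so a larger negative deviation would go unreported; taking the absolute value first avoids that. The helper name `compare_outputs` and its default tolerances are hypothetical choices for illustration, not part of NumPy, JAX, or Flax. The first lines show why bitwise equality is a strong expectation even at a fixed float32 precision: float32 addition is not associative, so summing the same terms in a different order (as different kernels may do) can round differently.

```python
import numpy as np

# float32 addition is not associative: the same terms summed in a different
# order can round to different results.
a, b, c = np.float32(1e8), np.float32(1.0), np.float32(-1e8)
left = (a + b) + c   # 1e8 + 1 rounds back to 1e8 in float32, so this is 0.0
right = (a + c) + b  # the large terms cancel first, so this is 1.0
print(left, right)   # 0.0 1.0


def compare_outputs(batch_out, individual_out, rtol=1e-5, atol=1e-3):
    """Hypothetical helper: summarize how far two output arrays are apart."""
    batch_out = np.asarray(batch_out)
    individual_out = np.asarray(individual_out)
    diff = batch_out - individual_out
    return {
        "exact_match": bool(np.array_equal(batch_out, individual_out)),
        "max_signed_diff": float(np.max(diff)),
        "max_abs_diff": float(np.max(np.abs(diff))),
        "allclose": bool(np.allclose(batch_out, individual_out, rtol=rtol, atol=atol)),
    }


# A negative deviation is hidden by the signed maximum but not by the absolute one.
report = compare_outputs(np.array([1.0, 2.0, 3.0], dtype=np.float32),
                         np.array([1.0, 2.0, 3.1], dtype=np.float32))
print(report)
```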
patel-zeel commented 1 year ago

That's right, Dr. @murphyk. However, I would expect the outputs to match exactly, since the precision (float32) is the same in the batched and individual versions of the code. @mjsML may be able to share his views on this from a domain perspective.
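One way to probe the "precision is not changed" assumption is to request full float32 matmul precision explicitly and see whether the gap shrinks. This sketch assumes, without confirmation in this thread, that JAX's default float32 matmul precision on Ampere GPUs such as the A100 may use TF32, which keeps fewer mantissa bits. `flax.linen.Dense` accepts a `precision` argument, and the `jax.default_matmul_precision("float32")` context manager sets the default globally. To avoid a Flax dependency, the example uses a hypothetical two-layer `mlp` written directly with `jnp.matmul` and a hypothetical helper `max_batch_vs_single_diff`. Even at `jax.lax.Precision.HIGHEST`, bitwise equality is not guaranteed, because the batched and single-row calls may still run kernels with different accumulation order.

```python
import jax
import jax.numpy as jnp
import numpy as np


def mlp(params, x, precision=None):
    """Hypothetical two-layer MLP: (x @ W + b) twice, over the last axis of x."""
    h = jnp.matmul(x, params["w1"], precision=precision) + params["b1"]
    return jnp.matmul(h, params["w2"], precision=precision) + params["b2"]


k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
params = {
    "w1": jax.random.normal(k1, (2, 2)), "b1": jnp.zeros(2),
    "w2": jax.random.normal(k2, (2, 1)), "b2": jnp.zeros(1),
}
x = jax.random.normal(k3, (2, 3, 2))  # (b1, b2, input_dim)


def max_batch_vs_single_diff(precision):
    """Max |batched - one-at-a-time| output difference at a given matmul precision."""
    batch_out = np.asarray(mlp(params, x, precision))
    single_out = np.zeros_like(batch_out)
    for i in range(x.shape[0]):
        for j in range(x.shape[1]):
            single_out[i:i+1, j:j+1] = np.asarray(mlp(params, x[i:i+1, j:j+1], precision))
    return float(np.max(np.abs(batch_out - single_out)))


default_diff = max_batch_vs_single_diff(None)                     # backend default
highest_diff = max_batch_vs_single_diff(jax.lax.Precision.HIGHEST)  # full float32
print(default_diff, highest_diff)
```

On a CPU both numbers are expected to be tiny; the comparison is only informative when run on the GPU in question.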