Open patel-zeel opened 1 year ago
I think this is just a numerical precision problem, not a bug, e.g. change the print statement to
print(f"Flax output match: {jnp.allclose(batch_out, individual_out, atol=1e-3)}")
and it says True on GPU (and CPU). Similarly,
m = jnp.max(batch_out - individual_out)
print(m) # 0.000106453896
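For what it's worth, a slightly fuller check (a sketch against the arrays from the Colab, assuming they are still named batch_out and individual_out) would use the absolute difference, since the raw difference can be negative:

import jax.numpy as jnp

abs_diff = jnp.abs(batch_out - individual_out)
print(jnp.max(abs_diff))                                      # largest absolute discrepancy
print(jnp.max(abs_diff / (jnp.abs(individual_out) + 1e-12)))  # largest relative discrepancy
print(jnp.allclose(batch_out, individual_out, atol=1e-3))     # True within a float32-level tolerance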
That's right, Dr. @murphyk. However, I would expect the outputs to match exactly, since the precision (float32) is the same in both the batch and individual versions of the code. @mjsML may let us know his views on this from the domain perspective.
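To illustrate the point (my own sketch, not the code from the Colab): even at a fixed float32 precision, a batched matmul and a loop of per-example matmuls can dispatch to different GPU kernels with different reduction orders, so bit-exact agreement is not guaranteed:

import jax
import jax.numpy as jnp

x = jax.random.normal(jax.random.PRNGKey(0), (64, 256))    # a batch of inputs
w = jax.random.normal(jax.random.PRNGKey(1), (256, 128))   # a weight matrix

batched = x @ w                                             # one large GEMM
rowwise = jnp.stack([x[i] @ w for i in range(x.shape[0])])  # one matvec per example

print(jnp.max(jnp.abs(batched - rowwise)))  # often a small nonzero value on GPU
print(jnp.allclose(batched, rowwise))       # may be False at default tolerances

Whether the difference is exactly zero depends on the backend and on which kernels XLA selects, which would be consistent with the CPU run matching exactly while the A100 run does not.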
Description
Hi,
I am trying to run a simple MLP on an A100 GPU with multi-dimensional batch inputs of shape (b1, b2, input_dim) and outputs of shape (b1, b2, 1). The Flax outputs from passing the entire (b1, b2, input_dim) input at once do not match the outputs from passing a single (1, 1, input_dim) input at a time. When I run the same code on CPU, or run the equivalent PyTorch version, the outputs match exactly. Please see the minimal code example, Colab link, and outputs in the issue below:
https://github.com/google/flax/issues/3084
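For readers without access to the Colab, here is a minimal sketch of the kind of comparison being made; the module sizes and batch shapes below are placeholders, and the actual script is in the linked issue:

import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(64)(x)
        x = nn.relu(x)
        return nn.Dense(1)(x)

b1, b2, input_dim = 8, 16, 4
x = jax.random.normal(jax.random.PRNGKey(0), (b1, b2, input_dim))

model = MLP()
params = model.init(jax.random.PRNGKey(1), x)

batch_out = model.apply(params, x)  # full (b1, b2, input_dim) batch in one call

# one (1, 1, input_dim) slice at a time, reassembled to (b1, b2, 1)
individual_out = jnp.stack([
    jnp.stack([model.apply(params, x[i:i + 1, j:j + 1])[0, 0] for j in range(b2)])
    for i in range(b1)
])

print(jnp.max(jnp.abs(batch_out - individual_out)))
print(jnp.allclose(batch_out, individual_out))  # may be False on GPU, True on CPU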
What jax/jaxlib version are you using?
jax 0.4.8, jaxlib 0.4.7+cuda12.cudnn88
Which accelerator(s) are you using?
GPU
Additional system info
No response
NVIDIA GPU info