jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Flax output mismatch for multi-dimensional batch input on GPUs #15898

Open patel-zeel opened 1 year ago

patel-zeel commented 1 year ago

Description

Hi,

I am trying to run a simple MLP on an A100 GPU with multi-dimensional batch inputs of shape (b1, b2, input_dim) and outputs of shape (b1, b2, 1). The Flax outputs from passing the entire input of shape (b1, b2, input_dim) at once versus passing single inputs of shape (1, 1, input_dim) iteratively do not match. When I run the same code on CPU, or run the equivalent PyTorch version, the outputs match exactly. Please see the minimal code example, Colab link, and outputs in the issue below:

https://github.com/google/flax/issues/3084

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import numpy as np

# Flax imports
import jax.random as jr
import jax.numpy as jnp
import flax.linen as nn

# PyTorch imports
import torch

# Common constants
b1, b2 = 2, 3
input_dim = 2
hidden_dim = 2
output_dim = 1
batch_shape = (b1, b2)

# Flax code
tiny_model = nn.Sequential([nn.Dense(hidden_dim), nn.Dense(output_dim)])
tiny_params = tiny_model.init(jr.PRNGKey(1234), jnp.ones((*batch_shape, input_dim)))

x = jr.normal(jr.PRNGKey(5678), (*batch_shape, input_dim))
batch_out = tiny_model.apply(tiny_params, x)

individual_out = np.zeros_like(batch_out)
for i in range(b1):
    for j in range(b2):
        individual_out[i:i+1, j:j+1, :] = tiny_model.apply(tiny_params, x[i:i+1, j:j+1, :])

print(f"Flax output match: {jnp.all(batch_out == individual_out)}")
display(batch_out.squeeze().tolist(), individual_out.squeeze().tolist())

# PyTorch code
torch.manual_seed(1234)
model = torch.nn.Sequential(torch.nn.Linear(input_dim, hidden_dim), torch.nn.Linear(hidden_dim, output_dim))
batch_out = model(torch.tensor(x.tolist()))

individual_out = torch.ones_like(batch_out)
for i in range(b1):
    for j in range(b2):
        individual_out[i, j, :] = model(torch.tensor(x[i:i+1, j:j+1, :].tolist())).squeeze()

print(f"PyTorch output match: {torch.all(batch_out == individual_out)}")
display(batch_out.squeeze().tolist(), individual_out.squeeze().tolist())

What jax/jaxlib version are you using?

jax 0.4.8, jaxlib 0.4.7+cuda12.cudnn88

Which accelerator(s) are you using?

GPU

Additional system info

No response

NVIDIA GPU info

Sat May  6 13:25:07 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA A100-SXM...  On   | 00000000:01:00.0 Off |                    0 |
| N/A   36C    P0    68W / 500W |      3MiB / 81920MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA A100-SXM...  On   | 00000000:41:00.0 Off |                    0 |
| N/A   35C    P0    75W / 500W |  74505MiB / 81920MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   2  NVIDIA A100-SXM...  On   | 00000000:81:00.0 Off |                    0 |
| N/A   33C    P0    63W / 500W |      3MiB / 81920MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   3  NVIDIA A100-SXM...  On   | 00000000:C1:00.0 Off |                    0 |
| N/A   34C    P0    73W / 500W |   9913MiB / 81920MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    1   N/A  N/A   1587948      C   ...vs/active_NILM/bin/python    74502MiB |
|    3   N/A  N/A   1568130      C   .../envs/torch_dt/bin/python     9910MiB |
+-----------------------------------------------------------------------------+
murphyk commented 1 year ago

I think this is just a numerical precision problem, not a bug. E.g., change the print statement to

print(f"Flax output match: {jnp.allclose(batch_out, individual_out, atol=1e-3)}")

and it says True on GPU (and CPU). Similarly,

m = jnp.max(batch_out - individual_out)
print(m) # 0.000106453896
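
To put that gap in context relative to the output scale, here is a rough check (reusing batch_out and individual_out from the Flax snippet above; not a definitive diagnosis):

# Compare the maximum gap against the output magnitude instead of eyeballing raw values.
abs_diff = jnp.max(jnp.abs(batch_out - individual_out))
rel_diff = abs_diff / jnp.max(jnp.abs(batch_out))
print(f"max abs diff: {abs_diff:.3e}, max rel diff: {rel_diff:.3e}")

A relative difference of that order points at floating-point effects in the GPU kernels rather than a batching or indexing bug.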
patel-zeel commented 1 year ago

That's right, Dr. @murphyk. However, I would expect the outputs to match exactly, given that the precision (float32) is unchanged between the batch and individual versions of the code. @mjsML may let us know his views on this from the domain perspective.
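
One thing I can try on my end (just a sketch, not a confirmed explanation) is forcing full-precision matmuls. On A100 GPUs, XLA may route float32 dots through reduced-precision tensor-core math by default, and the batched and per-example calls can also accumulate in different orders, so exact bitwise equality is not guaranteed even at nominal float32:

import jax

# Sketch: re-run the comparison with matmul precision forced to "highest"
# (full float32) to see whether the batched/iterative gap shrinks.
# Reuses tiny_model, tiny_params, and x from the code above.
with jax.default_matmul_precision("highest"):
    batch_out_hp = tiny_model.apply(tiny_params, x)
    single_out_hp = tiny_model.apply(tiny_params, x[0:1, 0:1, :])

print(jnp.max(jnp.abs(batch_out_hp[0:1, 0:1, :] - single_out_hp)))

If the difference collapses under "highest" precision, that would support the rounding explanation.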