google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

Flax output mismatch for multi-dimensional batch input on GPUs #3084

Open patel-zeel opened 1 year ago

patel-zeel commented 1 year ago

Hi,

I am trying to run a simple MLP on an A100 GPU with multi-dimensional batch inputs of shape (b1, b2, input_dim) and outputs of shape (b1, b2, 1). The Flax outputs from passing the entire (b1, b2, input_dim) input at once do not match the outputs from passing each (1, 1, input_dim) input iteratively. When I run the same example on CPU, or run the equivalent PyTorch version, the outputs match exactly. Please see the minimal code example, Colab link, and outputs below.

System information

Problem you have encountered:

Output mismatch while running the following minimal example on GPU:

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import numpy as np

# Flax imports
import jax.random as jr
import jax.numpy as jnp
import flax.linen as nn

# PyTorch imports
import torch

# Common constants
b1, b2 = 2, 3
input_dim = 2
hidden_dim = 2
output_dim = 1
batch_shape = (b1, b2)

# Flax code
tiny_model = nn.Sequential([nn.Dense(hidden_dim), nn.Dense(output_dim)])
tiny_params = tiny_model.init(jr.PRNGKey(1234), jnp.ones((*batch_shape, input_dim)))

x = jr.normal(jr.PRNGKey(5678), (*batch_shape, input_dim))
batch_out = tiny_model.apply(tiny_params, x)  # full batch in one call

# Same inputs, one (1, 1, input_dim) slice at a time
individual_out = np.zeros_like(batch_out)
for i in range(b1):
    for j in range(b2):
        individual_out[i:i+1, j:j+1, :] = tiny_model.apply(tiny_params, x[i:i+1, j:j+1, :])

print(f"Flax output match: {jnp.all(batch_out == individual_out)}")
display(batch_out.squeeze().tolist(), individual_out.squeeze().tolist())

# PyTorch code
torch.manual_seed(1234)
model = torch.nn.Sequential(torch.nn.Linear(input_dim, hidden_dim), torch.nn.Linear(hidden_dim, output_dim))
batch_out = model(torch.tensor(x.tolist()))  # full batch in one call

# Same inputs, one slice at a time
individual_out = torch.ones_like(batch_out)
for i in range(b1):
    for j in range(b2):
        individual_out[i, j, :] = model(torch.tensor(x[i:i+1, j:j+1, :].tolist())).squeeze()

print(f"PyTorch output match: {torch.all(batch_out == individual_out)}")
display(batch_out.squeeze().tolist(), individual_out.squeeze().tolist())

What you expected to happen:

The batched and per-example outputs should match, as they do on CPU and in the PyTorch version.

Steps to reproduce:

Here is a Colab link. Surprisingly, the mismatch did not show up on Google Colab for every random combination I tried, but I found one combination where it fails. However, the differences on my A100 GPU system are much larger than on the Colab GPU. Here are the outputs on Colab versus on my A100 GPU system.

Outputs on colab GPU

Flax output match: False
[[0.13540136814117432, 0.6433075666427612, 0.031430892646312714],
 [0.6323085427284241, -1.3527320623397827, -1.2486413717269897]]
[[0.13540136814117432, 0.6433075666427612, 0.031430892646312714],
 [0.6323085427284241, -1.3527320623397827, -1.2486412525177002]]
PyTorch output match: True
[[0.37359875440597534, 0.11582076549530029, 0.650824785232544],
 [0.17924970388412476, 1.6812427043914795, 1.4242538213729858]]
[[0.37359875440597534, 0.11582076549530029, 0.650824785232544],
 [0.17924970388412476, 1.6812427043914795, 1.4242538213729858]]

Outputs on colab CPU

WARNING:jax._src.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Flax output match: True
[[0.13540135324001312, 0.6433075070381165, 0.031430892646312714],
 [0.6323084235191345, -1.3527319431304932, -1.2486411333084106]]
[[0.13540135324001312, 0.6433075070381165, 0.031430892646312714],
 [0.6323084235191345, -1.3527319431304932, -1.2486411333084106]]
PyTorch output match: True
[[0.37359875440597534, 0.11582076549530029, 0.650824785232544],
 [0.17924970388412476, 1.6812427043914795, 1.4242538213729858]]
[[0.37359875440597534, 0.11582076549530029, 0.650824785232544],
 [0.17924970388412476, 1.6812427043914795, 1.4242538213729858]]

Output on A100 GPU system's GPU

Flax output match: False
[[0.13544875383377075, 0.6433806419372559, 0.031420525163412094],
 [0.6323312520980835, -1.3529661893844604, -1.2485347986221313]]
[[0.13540136814117432, 0.6433075666427612, 0.031430892646312714],
 [0.6323085427284241, -1.3527320623397827, -1.2486412525177002]]
PyTorch output match: True
[[0.37359875440597534, 0.11582076549530029, 0.650824785232544],
 [0.17924970388412476, 1.6812427043914795, 1.4242538213729858]]
[[0.37359875440597534, 0.11582076549530029, 0.650824785232544],
 [0.17924970388412476, 1.6812427043914795, 1.4242538213729858]]

Output on A100 GPU system's CPU

2023-05-05 14:44:18.123180: E external/xla/xla/stream_executor/cuda/cuda_driver.cc:268] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Flax output match: True
[[0.13540135324001312, 0.6433075070381165, 0.031430892646312714],
 [0.6323084235191345, -1.3527319431304932, -1.2486411333084106]]
[[0.13540135324001312, 0.6433075070381165, 0.031430892646312714],
 [0.6323084235191345, -1.3527319431304932, -1.2486411333084106]]
PyTorch output match: True
[[0.37359875440597534, 0.11582076549530029, 0.650824785232544],
 [0.17924970388412476, 1.6812427043914795, 1.4242538213729858]]
[[0.37359875440597534, 0.11582076549530029, 0.650824785232544],
 [0.17924970388412476, 1.6812427043914795, 1.4242538213729858]]
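
For reference, exact equality with == is a strict test for floating-point outputs. A minimal sketch of how the gap could be quantified instead, assuming batch_out and individual_out hold the Flax outputs from the snippet above; the tolerance values are illustrative:

# Quantify the mismatch rather than testing exact equality.
abs_diff = jnp.abs(batch_out - individual_out)
print("max abs diff:", jnp.max(abs_diff))
print("max rel diff:", jnp.max(abs_diff / jnp.abs(individual_out)))

# With float32-level tolerances, the ~1e-7 gap seen on the Colab GPU would pass,
# while the ~1e-4 gap seen on the A100 would fail.
print("allclose:", jnp.allclose(batch_out, individual_out, rtol=1e-5, atol=1e-6))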
cgarciae commented 1 year ago

Hey @patel-zeel, Flax doesn't have any platform-specific code. Please post this issue on the JAX repo; they might be better able to help you solve the problem.
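
One possible direction to investigate (an assumption, not something confirmed in this thread) is the default matmul precision on Ampere GPUs, where XLA may execute float32 matmuls in TF32. A minimal sketch of how that could be checked, using jax.default_matmul_precision and the precision argument that flax.linen.Dense forwards to its underlying dot; shapes and seeds mirror the reproduction above:

import jax
import jax.numpy as jnp
import jax.random as jr
import flax.linen as nn

# Force full float32 matmul precision and re-run the batched vs. per-example
# comparison. If the A100 mismatch shrinks to CPU/Colab levels, TF32 matmuls
# in XLA (not Flax itself) would be the likely cause.
model = nn.Sequential([
    nn.Dense(2, precision=jax.lax.Precision.HIGHEST),
    nn.Dense(1, precision=jax.lax.Precision.HIGHEST),
])
params = model.init(jr.PRNGKey(1234), jnp.ones((2, 3, 2)))
x = jr.normal(jr.PRNGKey(5678), (2, 3, 2))

with jax.default_matmul_precision("float32"):
    batch_out = model.apply(params, x)
    single_out = jnp.stack([
        model.apply(params, x[i:i+1, j:j+1, :]).squeeze()
        for i in range(2) for j in range(3)
    ]).reshape(batch_out.shape)

print("max abs diff with forced precision:", jnp.max(jnp.abs(batch_out - single_out)))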