jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Potential memory issue on Apple Metal #19132

Closed: milutter closed this issue 6 months ago

milutter commented 10 months ago

Description

Evaluating a simple multi-layer perceptron (MLP) implemented in Flax on the same input data and parameters yields non-deterministic outputs on the Apple Metal device when the function is NOT jitted. When the function is jitted, the outputs of the MLP are deterministic and the problem disappears. I was able to verify that this problem is specific to Apple Metal: on a Linux system with an NVIDIA GPU, the problem does not occur (with the current jax version).

Empirically, the problem seems to occur more frequently when the batch dimension is not of the form 2**n and n > 10. For example, for batch dimensions of 2500 and 5000 the problem occurs frequently. Another empirical observation is that the wrong values are not random but repeat themselves: y[0, 0] is always one of m different numbers (empirically m ≈ 3-4), but it is random which of the m options one gets, which hints at a memory problem. A small check of this observation is sketched after the script below.

It is debatable whether this is a jax, flax, or Apple Metal plugin issue. I am happy to file this issue at a different location if preferred.
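
For context, a quick check of which backend and devices jax is actually using (a minimal sketch, independent of the repro script below):

import jax

# Print the active backend and the devices it exposes; with the jax-metal
# plugin installed this should list the Metal device.
print(jax.default_backend())
print(jax.devices())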

import jax
import numpy as np
from jax import numpy as jnp
from flax import linen as nn

class MLP(nn.Module):
    out_dims: int
    hidden_dims: int

    @nn.compact
    def __call__(self, x):
        h1 = nn.Dense(self.hidden_dims)(x)
        h2 = nn.Dense(self.hidden_dims)(h1)
        return nn.Dense(self.out_dims)(h2)

if __name__ == "__main__":
    key = jax.random.PRNGKey(42)

    n_dim = 2
    for n_samples in [512, 1024, 2048, 2500, 4096, 5000, 8192]:
        x = jax.random.uniform(key, shape=(n_samples, n_dim))

        network = MLP(hidden_dims=256, out_dims=2)
        params = network.init(key, x)

        # Toggle between the jitted and non-jitted forward pass; the issue
        # only appears in the non-jitted case.
        # forward_fn = jax.jit(network.apply)
        forward_fn = network.apply

        y0 = forward_fn(params, x)  # reference output
        n_error, max_error = [], []
        for i in range(20):
            yi = forward_fn(params, x)  # same params, same inputs
            error = jnp.sum(jnp.abs(yi - y0), axis=-1)
            n_error.append(jnp.sum(error > 1e-6).item())  # rows deviating from the reference
            max_error.append(error.max().item())

        print(f'\n# Samples = {n_samples}'
              f'\n# Error = {n_error}'
              f'\nmin/max error = {np.array(max_error).min():.2f} / {np.array(max_error).max():.2f}')
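
To illustrate the "repeating values" observation from the description, the per-batch-size loop above could additionally collect the distinct values seen at a single output element (a minimal sketch reusing forward_fn, params, and x from the script):

        # Collect the distinct values observed at y[0, 0] across repeated
        # evaluations of identical inputs; on Apple Metal this set can contain
        # more than one element when forward_fn is not jitted.
        distinct = set()
        for _ in range(100):
            distinct.add(float(forward_fn(params, x)[0, 0]))
        print(f'y[0, 0] took {len(distinct)} distinct values: {sorted(distinct)}')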

Output WITHOUT JIT => Non-Deterministic Output:

Metal device set to: Apple M3 Max

systemMemory: 64.00 GB
maxCacheSize: 24.00 GB

# Samples = 512
# Error = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
min/max error = 0.00 / 0.00

# Samples = 1024
# Error = [0, 409, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
min/max error = 0.00 / 2.51

# Samples = 2048
# Error = [0, 0, 0, 814, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
min/max error = 0.00 / 2.76

# Samples = 2500
# Error = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
min/max error = 0.00 / 0.00

# Samples = 4096
# Error = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
min/max error = 0.00 / 0.00

# Samples = 5000
# Error = [4163, 3681, 4229, 3648, 4218, 3692, 4244, 3745, 4160, 3654, 4423, 3629, 4152, 3674, 4225, 3613, 4172, 3658, 4193, 3630]
min/max error = 2.31 / 3.64

# Samples = 8192
# Error = [0, 256, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
min/max error = 0.00 / 1.42

Output WITH JIT => Everything works as expected

Metal device set to: Apple M3 Max

systemMemory: 64.00 GB
maxCacheSize: 24.00 GB

# Samples = 512
# Error = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
min/max error = 0.00 / 0.00

# Samples = 1024
# Error = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
min/max error = 0.00 / 0.00

# Samples = 2048
# Error = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
min/max error = 0.00 / 0.00

# Samples = 2500
# Error = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
min/max error = 0.00 / 0.00

# Samples = 4096
# Error = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
min/max error = 0.00 / 0.00

# Samples = 5000
# Error = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
min/max error = 0.00 / 0.00

# Samples = 8192
# Error = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
min/max error = 0.00 / 0.00

What jax/jaxlib version are you using?

0.4.11, 0.4.11

Which accelerator(s) are you using?

Apple Metal

Additional system info?

Python 3.10, NumPy 1.26.0, platform macOS / Darwin

NVIDIA GPU info

No response

rajasekharporeddy commented 6 months ago

Hi @milutter

I ran the reported code with jax-metal 0.0.6 on a MacBook Pro with an M1 Pro chip to check whether the issue persists. The code produces the same output with and without Just-In-Time (JIT) compilation.

(Screenshots of the outputs attached, dated 2024-04-23.)

Could you please verify with jax-metal 0.0.6 and confirm whether the issue still persists?
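
For reference, a quick way to confirm the installed versions before re-running the repro (a minimal sketch; it assumes the plugin was installed as the jax-metal pip package):

import importlib.metadata as md
import jax

# Report the jax and jax-metal versions plus the visible devices.
print('jax:', jax.__version__)
print('jax-metal:', md.version('jax-metal'))
print('devices:', jax.devices())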

Thank you.

jakevdp commented 6 months ago

Thanks for following up, @rajasekharporeddy. I'm going to close the issue.

milutter commented 6 months ago

I can confirm that after upgrading jax-metal, jax, and flax, this error seems to be gone.