jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Getting different results from CPU vs. CUDA backend #22382

Open dionhaefner opened 3 months ago

dionhaefner commented 3 months ago

Description

I noticed that in Veros we sometimes observe solver divergences on GPU, but not on CPU. I've isolated the problem to a kernel that gives different results depending on the backend, reproducibly across several different machines.

Reproducer:

```python
import sys
import pickle
from dataclasses import dataclass
import jax
import jax.numpy as npx

def update(a, idx, value):
    # Functional stand-in for a[idx] = value (JAX arrays are immutable)
    return a.at[idx].set(value)

class At:
    # Lets callers write at[:-1, :, :] to build a plain index tuple
    def __getitem__(self, item):
        return item

at = At()

def calc_topo_kernel(vs):
    land_mask = vs.kbot > 0
    ks = npx.arange(vs.maskT.shape[2])[npx.newaxis, npx.newaxis, :]

    vs.maskT = update(vs.maskT, at[...], npx.logical_and(land_mask[..., npx.newaxis], vs.kbot[..., npx.newaxis] - 1 <= ks))
    # maskU / maskV start from the new maskT and AND each cell with its neighbour along axis 0 / axis 1
    vs.maskU = update(vs.maskT, at[:-1, :, :], npx.logical_and(vs.maskT[:-1, :, :], vs.maskT[1:, :, :]))
    vs.maskV = update(vs.maskT, at[:, :-1], npx.logical_and(vs.maskT[:, :-1], vs.maskT[:, 1:]))

    return dict(
        maskT=vs.maskT,
        maskU=vs.maskU,
        maskV=vs.maskV,
    )

# Inputs dumped from a Veros run (see the attached jaxbug.zip)
with open("jaxbug.pkl", "rb") as f:
    statedict = pickle.load(f)

@dataclass
class Variables:
    maskT: npx.ndarray
    maskU: npx.ndarray
    maskV: npx.ndarray
    kbot: npx.ndarray

# Register the dataclass as a pytree: all four fields are data, none are static
jax.tree_util.register_dataclass(Variables, Variables.__dataclass_fields__.keys(), [])

vs = Variables(**{k: v for k, v in statedict.items() if k in Variables.__dataclass_fields__})

# Same kernel, compiled once per backend
kernel_cpu = jax.jit(calc_topo_kernel, backend="cpu")
kernel_gpu = jax.jit(calc_topo_kernel, backend="cuda")

state_cpu = kernel_cpu(vs)
state_gpu = jax.device_get(kernel_gpu(vs))

# The masks only hold 0/1 values, so any mismatch shows up as a max difference of 1.0
passed = True
for key in state_cpu.keys():
    if not npx.allclose(state_cpu[key], state_gpu[key]):
        passed = False
        print(key)
        print(npx.abs(state_cpu[key].astype("float32") - state_gpu[key].astype("float32")).max())

if passed:
    print("All good!")
else:
    print("eek")
    sys.exit(1)
```

This prints:

```
$ python jaxbug.py
maskU
1.0
eek
```
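For boolean masks, a maximum absolute difference of 1.0 only says that at least one entry flipped. A small helper such as the hypothetical `describe_mismatch` below (not part of JAX or Veros) reports how many entries differ and at which indices, which may help show whether the `maskU` mismatch is scattered or confined to particular cells. It is a sketch assuming both inputs convert to NumPy arrays of the same shape.

```python
import numpy as np


def describe_mismatch(name, a, b, max_report=5):
    """Print how many entries of `a` and `b` differ and list the first few positions."""
    a = np.asarray(a)
    b = np.asarray(b)
    if a.shape != b.shape:
        raise ValueError(f"{name}: shape mismatch {a.shape} vs {b.shape}")
    diff_idx = np.argwhere(a != b)  # one row of indices per differing entry
    print(f"{name}: {len(diff_idx)} of {a.size} entries differ")
    for idx in diff_idx[:max_report]:
        pos = tuple(int(i) for i in idx)
        print(f"  at {pos}: {a[pos]} vs {b[pos]}")
    return len(diff_idx)


# Tiny self-check: exactly one entry differs (position (0, 1)).
n_diff = describe_mismatch("demo", [[True, False]], [[True, True]])
```

In the reproducer, the two `print` calls inside the comparison loop could be replaced by `describe_mismatch(key, state_cpu[key], state_gpu[key])`.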

Here's the pickle file containing the inputs to the function:

jaxbug.zip
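Comparing the two backends against each other does not by itself say which one is wrong. A minimal sketch, assuming `kbot` is an integer array of shape `(nx, ny)` and `nz` equals `maskT.shape[2]`, recomputes the three masks in plain NumPy with the same operations as `calc_topo_kernel` (the name `reference_masks` is hypothetical). It returns boolean arrays, whereas the kernel's outputs keep the dtype of the pickled `maskT`, so cast before comparing if that dtype is not bool.

```python
import numpy as np


def reference_masks(kbot, nz):
    """Recompute maskT/maskU/maskV from kbot with plain NumPy (no JAX, no XLA)."""
    kbot = np.asarray(kbot)
    ks = np.arange(nz)[np.newaxis, np.newaxis, :]
    kb = kbot[..., np.newaxis]
    # maskT[i, j, k] is True where kbot[i, j] > 0 and k >= kbot[i, j] - 1
    mask_t = (kb > 0) & (kb - 1 <= ks)
    # maskU: copy of maskT, AND-ed with the next cell along axis 0 (last row unchanged)
    mask_u = mask_t.copy()
    mask_u[:-1, :, :] = mask_t[:-1, :, :] & mask_t[1:, :, :]
    # maskV: copy of maskT, AND-ed with the next cell along axis 1 (last column unchanged)
    mask_v = mask_t.copy()
    mask_v[:, :-1] = mask_t[:, :-1] & mask_t[:, 1:]
    return {"maskT": mask_t, "maskU": mask_u, "maskV": mask_v}


# Small example: cell (1, 0) has kbot == 0, so maskU at (0, 0) becomes all False.
example = reference_masks(np.array([[1, 1], [0, 1]]), nz=2)
```

With `numpy` imported as `np` in the reproducer, computing `ref = reference_masks(vs.kbot, vs.maskT.shape[2])` and checking `np.array_equal(ref[key], np.asarray(state_gpu[key]).astype(bool))` (and likewise for `state_cpu`) for each key would show which backend departs from the reference.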

System info (python version, jaxlib version, accelerator, etc.)

```
jax:    0.4.30
jaxlib: 0.4.30
numpy:  1.26.0
python: 3.10.12 (main, Mar 22 2024, 16:50:05) [GCC 11.4.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='Ahab', release='6.5.0-41-generic', version='#41~22.04.2-Ubuntu SMP PREEMPT_DYNAMIC Mon Jun  3 11:32:55 UTC 2', machine='x86_64')
```

```
$ nvidia-smi
Wed Jul 10 23:42:34 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.183.01             Driver Version: 535.183.01   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 3060        Off | 00000000:01:00.0  On |                  N/A |
|  0%   52C    P2              40W / 170W |    990MiB / 12288MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A     36922      C   python                                      104MiB |
+---------------------------------------------------------------------------------------+
```
vfdev-5 commented 3 months ago

@dionhaefner I'm not sure what exactly the problem with the jitted function is here, ~but I think we can rewrite the calc_topo_kernel function more clearly, and then the check also passes:~

dionhaefner commented 3 months ago

That's because your rewritten function isn't the same as the original, and you're executing it twice on the same device.

The issue is that this function gives wrong results on GPU but not on CPU.
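One way to make sure a comparison really exercises two different devices is to commit the inputs to each device explicitly and check where the outputs end up. The sketch below does that for whatever platforms are available (CPU is always included). `one_device_per_platform` and `run_on_each_platform` are hypothetical helpers, and the toy `shifted_and` function only mimics the shape of the `maskU` update; it is not claimed to trigger the discrepancy.

```python
import jax
import jax.numpy as jnp
import numpy as np


def one_device_per_platform():
    """Map platform name -> first device, always including the CPU."""
    picked = {}
    for device in jax.devices("cpu") + jax.devices():
        picked.setdefault(device.platform, device)
    return picked


def run_on_each_platform(fn, *args):
    """Run jax.jit(fn) once per platform, with inputs committed to that platform's device."""
    jitted = jax.jit(fn)
    results = {}
    for platform, device in one_device_per_platform().items():
        placed = jax.device_put(args, device)  # commits every leaf of `args` to `device`
        out = jitted(*placed)
        # Committed inputs make jit run on `device`; verify the outputs live there too.
        for leaf in jax.tree_util.tree_leaves(out):
            assert leaf.devices() == {device}, (platform, leaf.devices())
        results[platform] = jax.device_get(out)
    return results


def shifted_and(m):
    # Same pattern as the maskU update: overwrite rows [:-1] with m[:-1] & m[1:].
    return m.at[:-1].set(jnp.logical_and(m[:-1], m[1:]))


mask = np.array([[True, False], [True, True], [False, True]])
results = run_on_each_platform(shifted_and, mask)
for platform, value in results.items():
    print(platform, np.array_equal(value, results["cpu"]))
```

Because `Variables` is registered as a pytree, `run_on_each_platform(calc_topo_kernel, vs)` should likewise place the whole state on each device in turn.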

vfdev-5 commented 3 months ago

Yes, sorry, I got the ops in the function wrong. The non-jitted version gives matching GPU and CPU results, but I can confirm that the jitted version shows a discrepancy.
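`calc_topo_kernel` rebinds `vs.maskT`, `vs.maskU` and `vs.maskV`, so an eager call modifies the `Variables` object it is given. For an eager-versus-jit comparison on one device, a pure variant that takes arrays and returns new ones keeps the two runs fully independent. The sketch below uses the hypothetical names `topo_masks` and `eager_vs_jit`. It builds `maskT` directly from `kbot`, so it returns boolean arrays and may not lower to exactly the same XLA program as the original kernel; whether it reproduces the GPU discrepancy is untested.

```python
import jax
import jax.numpy as jnp
import numpy as np


def topo_masks(kbot, nz):
    """Pure version of the calc_topo_kernel ops: arrays in, new arrays out."""
    ks = jnp.arange(nz)[jnp.newaxis, jnp.newaxis, :]
    kb = kbot[..., jnp.newaxis]
    mask_t = jnp.logical_and(kb > 0, kb - 1 <= ks)
    mask_u = mask_t.at[:-1, :, :].set(jnp.logical_and(mask_t[:-1, :, :], mask_t[1:, :, :]))
    mask_v = mask_t.at[:, :-1].set(jnp.logical_and(mask_t[:, :-1], mask_t[:, 1:]))
    return {"maskT": mask_t, "maskU": mask_u, "maskV": mask_v}


# nz determines an array shape, so it must be static under jit.
topo_masks_jit = jax.jit(topo_masks, static_argnums=1)


def eager_vs_jit(kbot, nz):
    """Return {mask name: True if eager and jitted results agree exactly}."""
    kbot = jnp.asarray(kbot)
    eager = topo_masks(kbot, nz)  # op-by-op dispatch
    jitted = topo_masks_jit(kbot, nz)  # one compiled XLA program
    return {k: bool(np.array_equal(eager[k], jitted[k])) for k in eager}


agreement = eager_vs_jit(np.array([[1, 1], [0, 1]]), 2)
print(agreement)
```

With the pickled inputs, `eager_vs_jit(vs.kbot, vs.maskT.shape[2])` would run both paths on the default device, which is the GPU when one is available.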