jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

JAX/TF2 on multi-GPU silently corrupted computations #3802

Closed david-berthelot closed 3 years ago

david-berthelot commented 4 years ago

When using 1 or 2 GPUs there is no problem. When using 3 or more GPUs the issue manifests itself, resulting in silently corrupted computations. This was tested on 4 V100 GPUs. There is a workaround (see the commented-out line in the code below).

Code for reproduction:

# The bug only shows up on physical GPUs, and only when there are 3 or more GPUs.
# It silently corrupts the results with no warning or error message.

import jax
import jax.numpy as jn
import numpy as np
import tensorflow as tf

# The following is a workaround that prevents the issue:
# tf.config.experimental.set_visible_devices([], "GPU")

reproduce_bug = True
noop = jax.pmap(lambda x: x)
n = jax.device_count()

def check(delta):
    v = noop(jn.ones((n, 1)) * delta).sum()
    assert v == n * delta, (v, n * delta)

check(1)
if reproduce_bug:
    data = tf.data.Dataset.from_tensor_slices(dict(x=np.random.uniform(-1, 1, size=(16, 1))))

check(2)
# Output is
# Traceback (most recent call last):
#  File "<stdin>", line 1, in <module>
#  File "<stdin>", line 3, in check
# AssertionError: (DeviceArray(7., dtype=float32), 8)
hawkinsp commented 4 years ago

I'm having trouble reproducing this. Can you please share the versions you are using?

Do I need to do anything other than run the script (python repro.py)?

david-berthelot commented 4 years ago
hawkinsp commented 4 years ago

I'm still having trouble reproducing this. Does it reproduce every time for you? I did the following:

cat > t.py <<EOF
Here I pasted the code from the first GitHub comment, unmodified.
EOF

python3 t.py



The bug does not seem to reproduce for me...

Any hints on how I can reproduce this?
david-berthelot commented 4 years ago

Thanks for the quick response.

I found the main difference: the issue happens with the driver that was available at the time I created my VMs (NVIDIA-SMI 440.100, Driver Version 440.100, CUDA Version 10.2). Upgrading my VM to NVIDIA-SMI 450.51.05, Driver Version 450.51.05, CUDA Version 11.0 fixes the problem, or at least I couldn't reproduce it in 10 tries, whereas before I could reproduce it within 3 tries.

hawkinsp commented 4 years ago

I'm still having trouble reproducing this. My current attempt is an Ubuntu 18.04 VM in which I run:

sudo apt-key adv --fetch-keys http://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/7fa2af80.pub
echo deb http://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64 / | sudo tee -a /etc/apt/sources.list.d/cuda.list
echo deb https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64 / | sudo tee -a /etc/apt/sources.list.d/cuda.list
sudo apt update
sudo apt dist-upgrade
sudo apt install cuda-drivers-440 cuda-command-line-tools-10-1 cuda-cudart-10-1 libcublas10=10.1.0.105-1 cuda-cufft-10-1 cuda-cusparse-10-1 cuda-curand-10-1 cuda-cusolver-10-1 libcudnn7 python3-pip

pip3 install --upgrade pip
python3 -m pip install https://storage.googleapis.com/jax-releases/cuda101/jaxlib-0.1.51-cp36-none-manylinux2010_x86_64.whl
python3 -m pip install jax==0.1.72 tensorflow==2.1.0

cat > t.py <<EOF
Here I pasted the code from the first GitHub comment, unmodified.
EOF

python3 t.py

and still no error, even after trying a few times.

nvidia-smi reports:

$ nvidia-smi
Tue Jul 21 20:58:48 2020
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 440.100      Driver Version: 440.100      CUDA Version: 10.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  Tesla V100-SXM2...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   39C    P0    46W / 300W |      0MiB / 16160MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   1  Tesla V100-SXM2...  Off  | 00000000:00:05.0 Off |                    0 |
| N/A   37C    P0    40W / 300W |      0MiB / 16160MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   2  Tesla V100-SXM2...  Off  | 00000000:00:06.0 Off |                    0 |
| N/A   38C    P0    49W / 300W |      0MiB / 16160MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   3  Tesla V100-SXM2...  Off  | 00000000:00:07.0 Off |                    0 |
| N/A   39C    P0    57W / 300W |      0MiB / 16160MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+
david-berthelot commented 4 years ago

In my case, I had created an 8-GPU instance. Initially I thought the issue could be reproduced consistently, but it seems stochastic. What I found is that the following sequence reproduced it for me on one of the VMs:

CUDA_VISIBLE_DEVICES=0,1,2,3 python t.py
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python t.py
# This one fails
CUDA_VISIBLE_DEVICES=0,1,2,3 python t.py

If you still can't reproduce it past that point, I'm not sure what else to recommend, except maybe giving you access to the VM on which I can reproduce the issue.

hawkinsp commented 4 years ago

Thanks! I can now reproduce. It seems it's important to have more than 4 GPUs and then to run something using only a subset of them.

hawkinsp commented 4 years ago

I'm able to reproduce with

tf-nightly=2.4.0.dev20200722

i.e., with up-to-date TF and JAX.

hawkinsp commented 4 years ago

A reliable reproduction on my VM (set up as described above, with 8x V100) is to run:

CUDA_VISIBLE_DEVICES=0,1,2 TF_FORCE_GPU_ALLOW_GROWTH=true XLA_PYTHON_CLIENT_ALLOCATOR=platform python t.py
TF_FORCE_GPU_ALLOW_GROWTH=true XLA_PYTHON_CLIENT_ALLOCATOR=platform python t.py

This seems to trigger a CUDA_ERROR_ILLEGAL_ADDRESS with high probability. The faulting kernel is always broadcast_2 (a JAX-generated kernel), which is almost trivial:

.version 6.0
.target sm_70
.address_size 64

        // .globl       broadcast_2

.visible .entry broadcast_2(
        .param .u64 broadcast_2_param_0,
        .param .u64 broadcast_2_param_1
)
.reqntid 8, 1, 1
{
        .reg .b32       %r<3>;
        .reg .b64       %rd<7>;

        ld.param.u64    %rd1, [broadcast_2_param_0];
        ld.param.u64    %rd2, [broadcast_2_param_1];
        cvta.to.global.u64      %rd3, %rd2;
        cvta.to.global.u64      %rd4, %rd1;
        mov.u32         %r1, %tid.x;
        ld.global.nc.u32        %r2, [%rd3];
        mul.wide.u32    %rd5, %r1, 4;
        add.s64         %rd6, %rd4, %rd5;
        st.global.u32   [%rd6], %r2;
        ret;

}

and I believe it's the ld.global.nc.u32 that is faulting because the symptom from cuda-memcheck is:

========= Invalid __global__ read of size 4
=========     at 0x00000040 in broadcast_2
=========     by thread (7,0,0) in block (0,0,0)
=========     Address 0x7efc4e200000 is out of bounds
=========     Device Frame:broadcast_2 (broadcast_2 : 0x40)

and as far as I can see there's only that one global read in the kernel.

I caught the error with both cuda-memcheck and TF_CPP_VMODULE=kernel_thunk=100 enabled, which told me that the faulting address was one of the buffers that XLA passed to the kernel. So the kernel and code generation are likely not at fault; instead, something runtime-related is.

hawkinsp commented 4 years ago

Aha! I think I've figured this one out.

JAX and TF both use the stream_executor library inside the TensorFlow tree to interact with CUDA. stream_executor has an optimization where it caches the last GPU context it set in thread-local storage, and skips the call to cuCtxSetCurrent if it thinks the current CUDA context has not changed since last time it was set: https://github.com/tensorflow/tensorflow/blob/001ec7efbed18e9581e859513c5acc76e5aabbe9/tensorflow/stream_executor/cuda/cuda_driver.cc#L204

Inside Google we like linking our binaries statically (since you get hermetic artifacts, at the cost of requiring much more rebuilding), and if we are using JAX and TF together they both share a single copy of stream_executor. So they both use the same thread-local cache.

However, in open source, we distribute JAX and TF as separate Python plugins, each with its own copy of stream_executor, built with private symbol visibility. This means that rather than having one cache, we have two. They interact badly if you mix TF and JAX in the same binary: TF will change the current GPU, JAX will fail to notice, and vice versa. The net effect is that we end up running things on the wrong GPU!
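
For illustration, here is a minimal C++ sketch of the caching pattern described above. It is not the actual stream_executor code and the names are made up; it only shows why two private copies of the same thread-local cache interact badly:

#include <cuda.h>

// Each loaded copy of the library (JAX's and TF's) gets its own instance of
// this thread-local variable, so the two copies can disagree about which CUDA
// context is actually current.
thread_local CUcontext last_activated_context = nullptr;

void ActivateContext(CUcontext ctx) {
  // Optimization: skip the driver call if we believe ctx is already current.
  // If the other copy of the library changed the current context in the
  // meantime, this check passes incorrectly and subsequent work is issued to
  // the wrong GPU.
  if (ctx == last_activated_context) {
    return;
  }
  cuCtxSetCurrent(ctx);
  last_activated_context = ctx;
}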

I suspect the right fix is to flush the cache at the Python API boundary, which is certainly something I can do on the JAX side.

TF should probably do the same, given that this isn't a problem specific to JAX: it could be any GPU-using Python library that does this.
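
Continuing the sketch above, "flushing the cache at the Python API boundary" could look roughly like this; again, this is only an illustration of the idea, not the actual JAX or TF fix:

// Forget the cached context so that the next ActivateContext call
// unconditionally re-issues cuCtxSetCurrent, even if another copy of the
// library changed the current context behind our back.
void FlushContextCache() {
  last_activated_context = nullptr;
}

// Hypothetical hook: invoked whenever control crosses from Python into this
// copy of the runtime, before any GPU work is issued.
void OnEnterFromPython() {
  FlushContextCache();
}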

hawkinsp commented 4 years ago

Apparently there are two bugs here: an illegal memory access and a wrong output.

I fixed the issue described above in my client, but it seems there's something else also going on here. Here's a slightly different repro:

from absl import logging
import jax
import jax.numpy as jn
import numpy as np
import tensorflow as tf

# The following is a workaround that prevents the issue:
# tf.config.experimental.set_visible_devices([], "GPU")

reproduce_bug = True

if reproduce_bug:
  data = tf.data.Dataset.from_tensor_slices(dict(x=np.random.uniform(-1, 1, size=(16, 1))))

logging.error("JAX time!")
n = jax.device_count()
x = jn.array(3)
y = jax.device_put(x, jax.devices()[1])
print(y)

Running this code:

CUDA_VISIBLE_DEVICES=0,3 TF_FORCE_GPU_ALLOW_GROWTH=true XLA_PYTHON_CLIENT_ALLOCATOR=platform python ~/t.py

prints

0

when it should print 3.

It appears that it matters which pair of GPUs is used. On the GCP VM I'm using, the GPU topology is:

$ nvidia-smi topo -m
        GPU0    GPU1    GPU2    GPU3    GPU4    GPU5    GPU6    GPU7    CPU Affinity
GPU0     X      NV3     PHB     PHB     PHB     NV3     PHB     PHB     0-31
GPU1    NV3      X      PHB     NV3     PHB     PHB     PHB     PHB     0-31
GPU2    PHB     PHB      X      NV3     PHB     PHB     PHB     NV3     0-31
GPU3    PHB     NV3     NV3      X      PHB     PHB     PHB     PHB     0-31
GPU4    PHB     PHB     PHB     PHB      X      NV3     NV3     PHB     0-31
GPU5    NV3     PHB     PHB     PHB     NV3      X      PHB     PHB     0-31
GPU6    PHB     PHB     PHB     PHB     NV3     PHB      X      NV3     0-31
GPU7    PHB     PHB     NV3     PHB     PHB     PHB     NV3      X      0-31

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

It appears that the example succeeds iff the two chosen GPUs are connected by NVLink. For pairs of GPUs connected only by PCIe (PHB) the example fails.

I'm now speculating that this happens because TensorFlow changes something about the GPU peer-access configuration. That said, I'm fairly mystified as to why it would matter. My next step is to try reducing the test case further.

hawkinsp commented 4 years ago

Here's a reproduction that requires neither JAX nor TF, which makes it highly likely to be NVidia's bug (unless I'm holding the APIs wrong):

#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <vector>

#include <cuda.h>
#include <cuda_runtime.h>

void CheckCuda(CUresult result, const char* file, int line) {
  if (result == CUDA_SUCCESS) {
    return;
  }
  const char* name;
  cuGetErrorName(result, &name);
  const char* message;
  cuGetErrorString(result, &message);
  std::cerr << file << "(" << line << "): " << name << ", " << message << std::endl;
  exit(-1);
}

void CheckCuda(cudaError_t result, const char* file, int line) {
  if (result == cudaSuccess) {
    return;
  }
  const char* name = cudaGetErrorName(result);
  const char* message = cudaGetErrorString(result);
  std::cerr << file << "(" << line << "): " << name << ", " << message << std::endl;
  exit(-1);
}

#define CHECK_CUDA(result) CheckCuda(result, __FILE__, __LINE__)

int main(int argc, char** argv) {
  CHECK_CUDA(cuInit(0));
  int num_devices;
  CHECK_CUDA(cuDeviceGetCount(&num_devices));
  std::vector<CUdevice> devices(num_devices);
  for (int i = 0; i < num_devices; ++i) {
    CHECK_CUDA(cuDeviceGet(&devices[i], i));
  }
  std::vector<CUcontext> contexts(num_devices);
  for (int i = 0; i < num_devices; ++i) {
    CHECK_CUDA(cuDevicePrimaryCtxRetain(&contexts[i], devices[i]));
  }
  for (int i = 0; i < num_devices; ++i) {
    CHECK_CUDA(cuCtxSetCurrent(contexts[i]));
    for (int j = 0; j < num_devices; ++j) {
      if (i != j) {
        CHECK_CUDA(cuCtxEnablePeerAccess(contexts[j], /*flags=*/0));
      }
    }
  }
  CHECK_CUDA(cuCtxSetCurrent(contexts[0]));
  CUdeviceptr buf0;
  CHECK_CUDA(cuMemAlloc(&buf0, 4));
  std::int32_t input = 3;
  CHECK_CUDA(cuMemcpyHtoD(buf0, &input, 4));

  CHECK_CUDA(cuCtxSetCurrent(contexts[1]));
  CUdeviceptr buf1;
  CHECK_CUDA(cuMemAlloc(&buf1, 4));

  CHECK_CUDA(cuCtxSetCurrent(contexts[0]));
  CHECK_CUDA(cuMemcpyDtoD(buf1, buf0, 4));

  CHECK_CUDA(cuCtxSetCurrent(contexts[1]));
  std::int32_t output;
  CHECK_CUDA(cuMemcpyDtoH(&output, buf1, 4));
  if (output != input) {
    std::cerr << "Output does not match input: " << input << " vs " << output << std::endl;
    return -1;
  }
  std::cerr << "Success; output and input match." << std::endl;
  return 0;
}
$ c++ -O3 -Wall -std=c++11  repro.cc -o repro -I/usr/local/cuda-10.1/include -lcuda -lcudart -L/usr/local/cuda-10.1/lib64
$ ./repro
Success; output and input match.
$ CUDA_VISIBLE_DEVICES=0,3 ./repro
Output does not match input: 3 vs 0

I'll follow up with NVidia (edit: filed NVidia partners bug ~~1787563~~ 3101818)

hawkinsp commented 3 years ago

NVidia closed the bug because they were unable to commit engineering resources to look into it.

Closing because there's no action we can take here. We can't fix NVidia driver/runtime bugs.