A GPU-accelerated library containing highly optimized building blocks and an execution engine for data processing to accelerate deep learning training and inference applications.
Apache License 2.0

Data corruption with JAX plugin #5617

Open kvablack opened 1 week ago

kvablack commented 1 week ago


nvidia-dali-cuda120==1.40.0, jax==0.4.31

Describe the bug.

We recently discovered that our training curves were mysteriously worse when we used DALI. We were able to fix it by adding a jax.block_until_ready() call after each train step, which defeats JAX's asynchronous dispatch. I therefore hypothesized that some sort of data corruption was occurring, caused by a lack of synchronization between JAX and DALI.

If I understand correctly, when you request the next element, the JAX plugin roughly does the following steps:

import jax
from nvidia import dali

def get_next_element(pipe: dali.Pipeline):
    # gets the next element (the buffers still belong to DALI)
    element = pipe.share_outputs()
    # copies to JAX memory (dispatched asynchronously)
    element = [jax.dlpack.from_dlpack(x.as_tensor()._expose_dlpack_capsule(), copy=True) for x in element]
    # tells DALI that we are done with the output buffers
    pipe.release_outputs()
    # schedules the next fetch
    pipe.schedule_run()
    return element

I believe the problem is that jax.dlpack.from_dlpack and jnp.copy are both asynchronous, so pipe.release_outputs() can be called before the copy has actually completed. When you request the next element, DALI can then overwrite the output buffers before JAX has finished reading from them.
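A toy model of the suspected race may make the ordering problem concrete. In this sketch a single reused Python list stands in for DALI's output buffers, and a deferred work queue (the hypothetical `ToyStream`) stands in for JAX's asynchronous dispatch. Nothing here calls DALI or JAX, and real GPU execution is more complex (the producer's writes are asynchronous too), so treat it only as an illustration of "release before the copy has run".

```python
from typing import Callable, Dict, List


class ToyStream:
    """Hypothetical stand-in for asynchronous dispatch: enqueued work runs only on synchronize()."""

    def __init__(self) -> None:
        self._pending: List[Callable[[], None]] = []

    def enqueue(self, fn: Callable[[], List[int]]) -> Dict[str, List[int]]:
        # Return a handle immediately; its value is filled in later, like a not-yet-ready array.
        handle: Dict[str, List[int]] = {}
        self._pending.append(lambda: handle.update(value=fn()))
        return handle

    def synchronize(self) -> None:
        # Run all deferred work in FIFO order (the toy analogue of block_until_ready).
        while self._pending:
            self._pending.pop(0)()


def run(sync_before_release: bool, num_batches: int = 3, size: int = 4) -> List[List[int]]:
    stream = ToyStream()
    buffer = [0] * size  # one output buffer that the producer reuses every iteration
    handles = []
    for i in range(num_batches):
        buffer[:] = [i] * size                         # producer writes batch i in place
        handle = stream.enqueue(lambda: list(buffer))  # the "copy" is dispatched, not done
        if sync_before_release:
            stream.synchronize()                       # wait for the copy before releasing
        handles.append(handle)
        # "release_outputs": from here on the next iteration may overwrite `buffer`.
    stream.synchronize()
    return [h["value"] for h in handles]


print(run(sync_before_release=False))  # [[2, 2, 2, 2], [2, 2, 2, 2], [2, 2, 2, 2]]
print(run(sync_before_release=True))   # [[0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2]]
```

Without the synchronization step, every deferred copy reads the buffer only after the last batch has overwritten it, which matches the symptom of batches silently containing the wrong data.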

I was able to fix the problem by bypassing the DALI JAX plugin and adding the following line to my code:

import jax
from nvidia import dali

def get_next_element(pipe: dali.Pipeline):
    # gets the next element (the buffers still belong to DALI)
    element = pipe.share_outputs()
    # copies to JAX memory (dispatched asynchronously)
    element = [jax.dlpack.from_dlpack(x.as_tensor()._expose_dlpack_capsule(), copy=True) for x in element]
    # wait until the copy is done (THIS IS THE MISSING STEP)
    element = jax.block_until_ready(element)
    # tells DALI that we are done with the output buffers
    pipe.release_outputs()
    # schedules the next fetch
    pipe.schedule_run()
    return element

I suspect this forces JAX to flush the entire GPU pipeline, which is somewhat inefficient. However, I don't see a way around it without deeper integration between JAX and DALI (specifically, I think DALI needs to provide a DLPack deleter, although I'm no expert).
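The deleter idea can be sketched as an ownership handoff: the producer reuses a buffer only after the consumer calls a callback that returns it. The pure-Python toy below is a hypothetical model, not DALI's or JAX's actual mechanism; `BufferPool`, `make_deleter`, and `run` are illustrative names. The important detail is that the deleter is queued after the deferred copy, so the buffer cannot be reused until the copy has run. Whether JAX's DLPack import delays the deleter until its device copy completes is not established in this thread; if the deleter fired as soon as the copy was dispatched, the same race would remain.

```python
from collections import deque
from typing import Callable, Deque, List


class BufferPool:
    """Hypothetical producer-side pool: a buffer is reused only after its deleter has run."""

    def __init__(self, num_buffers: int, size: int) -> None:
        self._free: Deque[List[int]] = deque([0] * size for _ in range(num_buffers))

    def has_free(self) -> bool:
        return bool(self._free)

    def acquire(self) -> List[int]:
        return self._free.popleft()

    def make_deleter(self, buf: List[int]) -> Callable[[], None]:
        # Plays the role of a DLPack deleter: calling it hands the buffer back to the producer.
        def deleter() -> None:
            self._free.append(buf)

        return deleter


def run(num_batches: int = 5, num_buffers: int = 2, size: int = 3) -> List[List[int]]:
    pool = BufferPool(num_buffers, size)
    stream: Deque[Callable[[], None]] = deque()  # deferred consumer work, run in FIFO order
    copies: List[List[int]] = []
    for i in range(num_batches):
        # The producer waits only when every buffer is still owned by the consumer;
        # draining queued work models waiting for the consumer's copies to finish.
        while not pool.has_free():
            stream.popleft()()
        buf = pool.acquire()
        buf[:] = [i] * size
        out: List[int] = []
        copies.append(out)
        stream.append(lambda b=buf, o=out: o.extend(b))  # deferred copy of batch i
        stream.append(pool.make_deleter(buf))             # release ordered after the copy
    while stream:
        stream.popleft()()
    return copies


print(run())  # [[0, 0, 0], [1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4]]
```

In this toy model, with more than one buffer the producer can run ahead of the consumer and only stalls when all buffers are still in use, rather than synchronizing after every batch.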

Minimum reproducible example

No response

Relevant log output

No response


awolant commented 1 week ago

Hello @kvablack,

Thank you for reporting this issue. Do you have a standalone script that reproduces it reliably? It would help a lot, as I have had no luck reproducing it so far.

As far as I remember, DALI should pass a DLPack deleter to JAX, to be called when the capsule is no longer needed. There might be an issue with it, as you pointed out.

kvablack commented 1 week ago

Here you go; it reproduces on a 4090. You need a couple of things to trigger the issue: large enough arrays and a fake "train step".

import jax
import jax.numpy as jnp
import numpy as np
from nvidia import dali
import tqdm

# this needs to be large to trigger the bug.
ARR_SIZE = 2**16

class ExternalSource:
    def __call__(self, sample_info: dali.types.SampleInfo):
        # every element of the sample with index k is k, so mixed-up data is easy to spot
        return [
            np.full((ARR_SIZE), fill_value=sample_info.idx_in_epoch, dtype=np.int32),
            np.full((ARR_SIZE), fill_value=sample_info.idx_in_epoch, dtype=np.int32),
        ]

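def get_pipe():
    # The pipeline_def arguments were lost from this copy of the report; batch_size=8 and
    # num_threads=1 are assumed values (the check below works for any batch size).
    @dali.pipeline_def(batch_size=8, num_threads=1)
    def pipeline():
        outputs = dali.fn.external_source(
            source=ExternalSource(), num_outputs=2, batch=False
        )

        outputs = [arr.gpu() for arr in outputs]

        return tuple(outputs)

    pipe = pipeline(device_id=0)
    pipe.build()
    # schedule the first run so the first share_outputs() call has data to return
    pipe.schedule_run()
    return pipe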

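def get_next(pipe: dali.Pipeline):
    outputs = pipe.share_outputs()
    element = [jax.dlpack.from_dlpack(x.as_tensor()._expose_dlpack_capsule(), copy=True) for x in outputs]
    # uncommenting this line is the workaround described above
    # jax.block_until_ready(element)
    # release the DALI buffers and schedule the next run
    pipe.release_outputs()
    pipe.schedule_run()
    return element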

def f(x):
    # matmul to simulate the train step. interestingly, this is required to reproduce the bug.
    return [y @ y.T for y in x]

if __name__ == "__main__":
    pipe = get_pipe()
    batches = []
    for _ in tqdm.trange(10):
        # fetch a batch and run the fake "train step" on it
        batches.append(f(get_next(pipe)))

    for i in tqdm.trange(len(batches)):
        for elem in batches[i]:
            # `x` is what the ExternalSource should have returned
            x = jnp.broadcast_to(jnp.arange(elem.shape[0])[:, None] + i * elem.shape[0], (elem.shape[0], ARR_SIZE))
            # `y` is what the "train step" should have returned
            y = x @ x.T
            # check that it matches the batch that came out of the pipeline
            assert jnp.all(y == elem), f"Failed on batch {i}\n\nExpected:\n{y}\n\nGot:\n{elem}"
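The final check relies on every row of `x` being constant: row `j` of batch `i` holds `v_j = i * B + j`, where `B = elem.shape[0]` is the batch size, so `(x @ x.T)[j, k] = ARR_SIZE * v_j * v_k`. The small pure-Python sketch below (helper names are hypothetical, not part of the script) cross-checks that closed form against an explicit product. Python integers do not overflow, so it says nothing about int32 range limits.

```python
from typing import List


def expected_batch(i: int, batch_size: int, arr_size: int) -> List[List[int]]:
    """Closed form of x @ x.T when row j of x is filled with i * batch_size + j."""
    v = [i * batch_size + j for j in range(batch_size)]
    return [[arr_size * vj * vk for vk in v] for vj in v]


def matmul_xxt(x: List[List[int]]) -> List[List[int]]:
    """Explicit x @ x.T: entry (r, s) is the dot product of rows r and s."""
    return [[sum(a * b for a, b in zip(row_r, row_s)) for row_s in x] for row_r in x]


# Small case: batch i=2, batch size 3, rows of length 4 -> row values 6, 7, 8.
i, batch_size, arr_size = 2, 3, 4
x = [[i * batch_size + j] * arr_size for j in range(batch_size)]
print(matmul_xxt(x))  # [[144, 168, 192], [168, 196, 224], [192, 224, 256]]
assert matmul_xxt(x) == expected_batch(i, batch_size, arr_size)
```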