NVIDIA / DALI

A GPU-accelerated library containing highly optimized building blocks and an execution engine for data processing to accelerate deep learning training and inference applications.
https://docs.nvidia.com/deeplearning/dali/user-guide/docs/index.html
Apache License 2.0

Data corruption with JAX plugin #5617

Open · kvablack opened 1 week ago

kvablack commented 1 week ago

Version

nvidia-dali-cuda120==1.40.0, jax==0.4.31

Describe the bug.

We recently discovered that when we used DALI, our training curves were mysteriously worse. We were able to fix this by adding a jax.block_until_ready() call after each train step, to foil JAX's asynchronous dispatch. I therefore hypothesized that there was some sort of data corruption going on, caused by a lack of synchronization between JAX and DALI.
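
For concreteness, this is roughly what that workaround looks like in our training loop (train_step, state, and dali_iterator below stand in for our own code, not DALI or JAX APIs):

import jax

def train_epoch(train_step, state, dali_iterator):
    # train_step, state, and dali_iterator are placeholders for our own
    # training code and the DALI-backed data iterator.
    for batch in dali_iterator:
        state = train_step(state, batch)
        # Foil JAX's asynchronous dispatch: block until all work dispatched for
        # this step (including the device copies of `batch`) has finished
        # before the next batch is pulled from DALI.
        state = jax.block_until_ready(state)
    return state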

If I understand correctly, when you request the next element, the JAX plugin roughly performs the following steps:

def get_next_element(pipe: dali.Pipeline):
    # gets the next element
    element = pipe.share_outputs()
    # copies to JAX memory
    element = [jax.dlpack.from_dlpack(x.as_tensor()._expose_dlpack_capsule(), copy=True) for x in element]
    # tells DALI that we are done with the output buffers
    pipe.release_outputs()
    # schedules the next fetch
    pipe.schedule_run()
    return element

I believe the problem is that jax.dlpack.from_dlpack and jnp.copy are both asynchronous. Therefore, pipe.release_outputs() is called before the copy actually occurs. When you request the next element, DALI can overwrite the output buffers before JAX is done reading from them.
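
To see the asynchrony in isolation, here is a small pure-JAX illustration (no DALI involved; the shapes are arbitrary): the copy is only enqueued, and control returns to Python long before the device has finished the work.

import time

import jax
import jax.numpy as jnp

x = jnp.ones((8192, 8192))
jax.block_until_ready(jnp.copy(x @ x))  # warm-up so compilation is not measured
t0 = time.perf_counter()
y = jnp.copy(x @ x)                     # returns almost immediately (work is enqueued)
t1 = time.perf_counter()
jax.block_until_ready(y)                # actually waits for the device to finish
t2 = time.perf_counter()
print(f"dispatch: {t1 - t0:.4f}s, completion: {t2 - t1:.4f}s")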

I was able to fix the problem by bypassing the DALI JAX plugin, and adding the following line to my code:

def get_next_element(pipe: dali.Pipeline):
    # gets the next element
    element = pipe.share_outputs()
    # copies to JAX memory
    element = [jax.dlpack.from_dlpack(x.as_tensor()._expose_dlpack_capsule(), copy=True) for x in element]
    # wait until the copy is done (THIS IS THE MISSING STEP)
    element = jax.block_until_ready(element)
    # tells DALI that we are done with the output buffers
    pipe.release_outputs()
    # schedules the next fetch
    pipe.schedule_run()
    return element

Of course, I suspect this forces JAX to flush the entire GPU pipeline, which is somewhat inefficient. However, I don't think there's any way around this without deeper integration between JAX and DALI (specifically, I think DALI needs to provide a DLPack deleter, although I'm no expert).
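
To make the "deeper integration" idea a bit more concrete, here is a purely conceptual sketch (the class and method names are made up, not DALI's API, and a real implementation would need the actual DLPack/ctypes plumbing): DALI would keep each output buffer alive until the consumer invokes the DLPack deleter, instead of relying on the caller to synchronize before release_outputs().

class OutputHandoff:
    """Conceptual sketch only: keeps DALI output buffers alive until every
    exported tensor's DLPack deleter has been called by the consumer
    (e.g. JAX/XLA)."""

    def __init__(self, buffers, release_outputs):
        self._buffers = list(buffers)   # DALI-owned GPU buffers
        self._pending = len(self._buffers)
        self._release_outputs = release_outputs

    def make_deleter(self, index):
        def deleter():
            # Called by the consumer once it no longer needs tensor `index`;
            # only after every tensor is released may DALI recycle the batch.
            self._buffers[index] = None
            self._pending -= 1
            if self._pending == 0:
                self._release_outputs()
        return deleter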

Minimum reproducible example

No response

Relevant log output

No response

Other/Misc.

No response

awolant commented 1 week ago

Hello @kvablack

Thank you for reporting this issue. Do you have a standalone script that reproduces it reliably? It would help a lot, as I have had no luck reproducing it so far.

As far as I remember, DALI should pass a DLPack deleter to JAX, to be called when the capsule is no longer needed. There might be some issue with it, as you pointed out.

kvablack commented 1 week ago

Here you go; this reproduces on an RTX 4090. You do need a couple of things to trigger the issue -- namely, large enough arrays and a fake "train step".

import jax
import jax.numpy as jnp
import numpy as np
from nvidia import dali
import tqdm

# this needs to be large to trigger the bug.
ARR_SIZE = 2**16

class ExternalSource:
    def __call__(self, sample_info: dali.types.SampleInfo):
        return [
            np.full((ARR_SIZE), fill_value=sample_info.idx_in_epoch, dtype=np.int32),
            np.full((ARR_SIZE), fill_value=sample_info.idx_in_epoch, dtype=np.int32),
        ]

def get_pipe():
    @dali.pipeline_def(
        batch_size=2,
        num_threads=1,
        prefetch_queue_depth=6,
        py_start_method="spawn",
    )
    def pipeline():
        outputs = dali.fn.external_source(
            source=ExternalSource(),
            num_outputs=2,
            batch=False,
            parallel=True,
        )

        outputs = [arr.gpu() for arr in outputs]

        return tuple(outputs)

    pipe = pipeline(device_id=0)
    pipe.build()
    pipe.schedule_run()
    return pipe

def get_next(pipe: dali.Pipeline):
    outputs = pipe.share_outputs()
    element = [jax.dlpack.from_dlpack(x.as_tensor()._expose_dlpack_capsule(), copy=True) for x in outputs]
    # UNCOMMENT THIS LINE TO MAKE THE ASSERTIONS PASS.
    # jax.block_until_ready(element)
    pipe.release_outputs()
    pipe.schedule_run()
    return element

@jax.jit
def f(x):
    # matmul to simulate the train step. interestingly, this is required to reproduce the bug.
    return [y @ y.T for y in x]

if __name__ == "__main__":
    pipe = get_pipe()
    batches = []
    for _ in tqdm.trange(10):
        batches.append(f(get_next(pipe)))

    for i in tqdm.trange(len(batches)):
        for elem in batches[i]:
            # `x` is what the ExternalSource should have returned
            x = jnp.broadcast_to(jnp.arange(elem.shape[0])[:, None] + i * elem.shape[0], (elem.shape[0], ARR_SIZE))
            # `y` is what the "train step" should have returned
            y = x @ x.T
            # check that it matches the batch that came out of the pipeline
            assert jnp.all(y == elem), f"Failed on batch {i}\n\nExpected:\n{y}\n\nGot:\n{elem}"