ray-project / ray

Ray is an AI compute engine. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
https://ray.io

[core][compiled graphs] Failing test: `test_torch_tensor_dag.py::test_torch_tensor_exceptions[static_shape=True, direct_return=True, overlap_gpu_communication=True]` #48747

Open · AndyUB opened this issue 1 day ago

AndyUB commented 1 day ago

What happened + What you expected to happen

The test `ray/python/ray/dag/tests/experimental/test_torch_tensor_dag.py::test_torch_tensor_exceptions[static_shape=True-direct_return=True-overlap_gpu_communication=True]` fails locally.

The bug is probably that, during the NCCL P2P send/recv, the buffer allocated on the receiver side is read before the actual data has arrived. Currently the receiver's buffer is zero-initialized, so the output is all zeros. When I changed the allocation function to allocate `torch.ones(...) * 100` instead, the actual output became `[100, ..., 100]`.
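
For illustration, here is a minimal sketch of the sentinel-value trick described above (the `allocate_recv_buffer` helper name is hypothetical; Ray's actual NCCL channel allocation hook may look different):

    import torch

    # Hypothetical helper: fill the receiver-side buffer with a sentinel value
    # instead of zeros, so that a premature read is easy to distinguish from a
    # tensor that legitimately contains zeros.
    def allocate_recv_buffer(shape, dtype, device="cuda"):
        return torch.ones(shape, dtype=dtype, device=device) * 100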

An interesting finding is that when the code executes faster, this test always fails; but when I added a ton of print statements for debugging, it runs more slowly and the test sometimes passes.

Since this test has overlap_gpu_communication=True, it is likely related to overlapping GPU communication with computation. My guess is that the actor reading the tensor did not properly wait for the recv event to finish.
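
To illustrate the suspected race, here is a standalone sketch (not Ray code) of the CUDA stream/event semantics involved: work recorded on a recv stream is only guaranteed to be visible to another stream after that stream waits on an event recorded on the recv stream.

    import torch

    recv_stream = torch.cuda.Stream()
    compute_stream = torch.cuda.Stream()

    buf = torch.zeros(10, dtype=torch.float16, device="cuda")
    recv_done = torch.cuda.Event()

    with torch.cuda.stream(recv_stream):
        # Stand-in for the NCCL recv that fills `buf`.
        buf.copy_(torch.full_like(buf, 2.0))
        recv_done.record(recv_stream)

    with torch.cuda.stream(compute_stream):
        # Without this wait, the read below may still observe the old zeros.
        compute_stream.wait_event(recv_done)
        value = buf[0].item()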

I checked out the commit that most recently modified the test (#47586), as well as the current HEAD of the ray-project:master branch, and the test failed in both cases.

Below is an example error message:

    @pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 4}], indirect=True)
    @pytest.mark.parametrize("static_shape", [False, True])
    @pytest.mark.parametrize("direct_return", [False, True])
    @pytest.mark.parametrize("overlap_gpu_communication", [False, True])
    def test_torch_tensor_exceptions(
        ray_start_regular, static_shape, direct_return, overlap_gpu_communication
    ):
        """
        Test exceptions being thrown by a NCCL sending task.
        """
        if not USE_GPU:
            pytest.skip("NCCL tests require GPUs")

        assert (
            sum(node["Resources"].get("GPU", 0) for node in ray.nodes()) > 1
        ), "This test requires at least 2 GPUs"

        actor_cls = TorchTensorWorker.options(num_gpus=1)

        sender = actor_cls.remote()
        receiver = actor_cls.remote()

        with InputNode() as inp:
            dag = sender.send_or_raise.bind(
                inp.shape, inp.dtype, inp.value, inp.raise_exception
            )
            dag = dag.with_type_hint(
                TorchTensorType(
                    _static_shape=static_shape,
                    _direct_return=direct_return,
                    transport="nccl",
                )
            )
            dag = receiver.recv.bind(dag)

        compiled_dag = dag.experimental_compile(
            _overlap_gpu_communication=overlap_gpu_communication
        )

        shape = (10,)
        dtype = torch.float16

        for i in range(3):
            i += 1

            ref = compiled_dag.execute(
                shape=shape,
                dtype=dtype,
                value=i,
                raise_exception=False,
            )
            result = ray.get(ref)
>           assert result == (i, shape, dtype)
E           assert (0.0, torch.S...torch.float16) == (2, (10,), torch.float16)
E             At index 0 diff: 0.0 != 2
E             Full diff:
E             - (2, (10,), torch.float16)
E             + (0.0, torch.Size([10]), torch.float16)

test_torch_tensor_dag.py:861: AssertionError

Versions / Dependencies

Latest version of Ray (current master HEAD). Python: 3.9.

Reproduction script

https://github.com/ray-project/ray/blob/master/python/ray/dag/tests/experimental/test_torch_tensor_dag.py#L813

@pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 4}], indirect=True)
@pytest.mark.parametrize("static_shape", [False, True])
@pytest.mark.parametrize("direct_return", [False, True])
@pytest.mark.parametrize("overlap_gpu_communication", [False, True])
def test_torch_tensor_exceptions(
    ray_start_regular, static_shape, direct_return, overlap_gpu_communication
):
    """
    Test exceptions being thrown by a NCCL sending task.
    """
    if not USE_GPU:
        pytest.skip("NCCL tests require GPUs")

    assert (
        sum(node["Resources"].get("GPU", 0) for node in ray.nodes()) > 1
    ), "This test requires at least 2 GPUs"

    actor_cls = TorchTensorWorker.options(num_gpus=1)

    sender = actor_cls.remote()
    receiver = actor_cls.remote()

    with InputNode() as inp:
        dag = sender.send_or_raise.bind(
            inp.shape, inp.dtype, inp.value, inp.raise_exception
        )
        dag = dag.with_type_hint(
            TorchTensorType(
                _static_shape=static_shape,
                _direct_return=direct_return,
                transport="nccl",
            )
        )
        dag = receiver.recv.bind(dag)

    compiled_dag = dag.experimental_compile(
        _overlap_gpu_communication=overlap_gpu_communication
    )

    shape = (10,)
    dtype = torch.float16

    for i in range(3):
        i += 1

        ref = compiled_dag.execute(
            shape=shape,
            dtype=dtype,
            value=i,
            raise_exception=False,
        )
        result = ray.get(ref)
        assert result == (i, shape, dtype)

    # Application level exceptions are thrown to the end ray.get
    ref = compiled_dag.execute(
        shape=shape,
        dtype=dtype,
        value=i,
        raise_exception=True,
    )
    if static_shape or direct_return:
        with pytest.raises(RayChannelError):
            # TODO(swang): Ideally return the RuntimeError thrown by the
            # application instead of a generic RayChannelError.
            ray.get(ref)

        with pytest.raises(RayChannelError):
            # If using static_shape=True or direct_return=True, the DAG is not
            # usable after application-level exceptions.
            ref = compiled_dag.execute(
                shape=shape,
                dtype=dtype,
                value=i,
                raise_exception=False,
            )
    else:
        with pytest.raises(RuntimeError):
            ray.get(ref)

        # DAG should still be usable after application-level exceptions.
        ref = compiled_dag.execute(
            shape=shape,
            dtype=dtype,
            value=i,
            raise_exception=False,
        )
        result = ray.get(ref)
        assert result == (i, shape, dtype)

Issue Severity

High: It blocks me from completing my task.

ruisearch42 commented 5 hours ago

Hi @AndyUB, Thanks for reporting!

This does sound like a bug in the overlap functionality. However, when I reran the test (`python/ray/dag/tests/experimental/test_torch_tensor_dag.py::test_torch_tensor_exceptions[True-True-True-ray_start_regular0]`) 100 times on a node with an L4 GPU, I wasn't able to reproduce it even once. Our CI is also a bit flaky right now (fix in progress), so it's not easy to check how often this test fails there.

In your experience, how often does this fail? Is there any way to make it more reproducible? You mentioned that "when the code executes faster, this test always fails"; how did you make the code run faster? On a better GPU?
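
For what it's worth, one simple way to estimate the local failure rate is to rerun just this parametrization in a loop (a sketch; it assumes pytest and the required GPUs are available locally):

    import subprocess
    import sys

    TEST_ID = (
        "python/ray/dag/tests/experimental/test_torch_tensor_dag.py::"
        "test_torch_tensor_exceptions[True-True-True-ray_start_regular0]"
    )

    failures = 0
    for _ in range(100):
        # Run each attempt in a fresh process so Ray state doesn't leak between runs.
        ret = subprocess.run([sys.executable, "-m", "pytest", "-q", TEST_ID])
        failures += ret.returncode != 0
    print(f"{failures}/100 runs failed")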

ruisearch42 commented 5 hours ago

btw, I think the issue is likely that in _compute() we only synchronize on the GPU recv stream but never block the CPU:

    def _compute(
        self,
        overlap_gpu_communication: bool,
        class_handle,
    ) -> bool:
        input_data = self.reset_and_wait_intermediate_future()
        ...

    def reset_and_wait_intermediate_future(self) -> Any:
        future = self._intermediate_future
        self._intermediate_future = None
        return future.wait()

class GPUFuture(DAGOperationFuture[Any]):
    def wait(self) -> Any:
        """
        Wait for the future on the current CUDA stream and return the result from
        the GPU operation. This operation does not block CPU.
        """
        import cupy as cp

        current_stream = cp.cuda.get_current_stream()
        current_stream.wait_event(self._event)
        return self._buf

And the receiver's _compute() operation runs the following method (TorchTensorWorker.recv), which directly retrieves the item, shape, and dtype from the GPU tensor without waiting:

class TorchTensorWorker:
    def recv(self, tensor):
        # Check that tensor got loaded to the correct device.
        assert tensor.device == self.device
        return (tensor[0].item(), tensor.shape, tensor.dtype)

To fix this issue, we will probably need to make the CPU synchronize on the recv stream in _compute(). cc: @stephanie-wang @rkooo567
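
A minimal sketch of that direction, assuming the change lands in the future's wait() (the `BlockingGPUFuture` name is hypothetical and this is not the actual Ray patch):

    import cupy as cp

    class BlockingGPUFuture:
        """Sketch: like GPUFuture.wait(), but also blocks the CPU."""

        def __init__(self, buf, event):
            self._buf = buf
            self._event = event

        def wait(self):
            # Order subsequent GPU work on the current stream after the recv event.
            cp.cuda.get_current_stream().wait_event(self._event)
            # Also block the host until the recv has actually completed, so that
            # CPU-side reads such as tensor.item() see the received data.
            self._event.synchronize()
            return self._buf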

stephanie-wang commented 2 hours ago

Re: test repro, you could try to insert a sleep on the recv stream before queuing the recv.
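
A hedged sketch of that suggestion (torch.cuda._sleep is a private PyTorch helper used in its own test suite; any long-running dummy kernel launched on the recv stream would serve the same purpose):

    import torch

    def delay_recv_stream(recv_stream: torch.cuda.Stream, cycles: int = 10_000_000):
        # Enqueue a busy-wait kernel on the recv stream right before the NCCL
        # recv is queued, to widen the window for the suspected race.
        with torch.cuda.stream(recv_stream):
            torch.cuda._sleep(cycles)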

> And the receiver's _compute() operation runs the following method (TorchTensorWorker.recv), which directly retrieves the item, shape, and dtype from the GPU tensor without waiting. To fix this issue, we will probably need to make the CPU synchronize on the recv stream in _compute().

Not sure that's the whole story; the read of the item requires a GPU->CPU copy and is supposed to get queued on the compute stream after syncing on the recv stream. It would be good to check that the read of the item is actually happening on the expected stream.
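
One hedged way to check that is to log the current stream inside the receiving task, as a local debugging tweak to the test's TorchTensorWorker (the __init__ below is only a stand-in for the real worker's device setup so the snippet stands alone):

    import torch

    class TorchTensorWorker:
        def __init__(self):
            # Stand-in for the test worker's device setup.
            self.device = torch.device("cuda", torch.cuda.current_device())

        def recv(self, tensor):
            # Debugging aid: confirm which stream the host-side read runs on.
            print("recv() runs on stream:", torch.cuda.current_stream())
            assert tensor.device == self.device
            return (tensor[0].item(), tensor.shape, tensor.dtype)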