pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Zero-copy between CUDA and XLA #6971

Open · vanbasten23 opened this issue 5 months ago

vanbasten23 commented 5 months ago

This issue will be used to track the work for zero-copy between CUDA and XLA.

Inspired by

I implemented a POC at https://github.com/pytorch/xla/pull/6970.
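The goal of the POC is for the CUDA and XLA tensors to share one buffer instead of copying it. One way to reason about, and test, that property is aliasing: after a zero-copy handoff, a write through one handle is visible through the other, and after a copy it is not. The stdlib-only sketch below shows that check. `memoryview` stands in for a zero-copy view and `bytes` stands in for a copying transfer. It involves no GPU, DLPack, or torch_xla API, and `shares_memory` is a hypothetical helper.

```python
def shares_memory(src: bytearray, make_handle) -> bool:
    """Hypothetical helper: True if a write to `src` is visible through make_handle(src)."""
    handle = make_handle(src)
    original = src[0]
    src[0] = (original + 1) % 256   # write through the source buffer
    visible = handle[0] == src[0]   # does the other handle observe the write?
    src[0] = original               # undo the write
    return visible

# memoryview aliases the bytearray (zero-copy); bytes(...) makes an independent copy.
zero_copy = shares_memory(bytearray(b"\x01\x02\x03"), memoryview)
copied = shares_memory(bytearray(b"\x01\x02\x03"), bytes)
print(zero_copy, copied)  # True False
```

A zero-copy test for the real tensors would apply the same idea: mutate the source tensor in place and check that the destination tensor reflects the change.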

Current status:

  1. CUDA -> XLA
  2. XLA -> CUDA

The XLA -> CUDA direction currently fails with the following error:

Traceback (most recent call last):
  File "pytorch/xla/test/test_operations.py", line 2454, in test_aten_move_xla_to_cuda_zero_copy
    cuda_tensor = xla_tensor.cuda()
RuntimeError: tensor does not have a device

With gdb, I can see the following stack trace:

#0  0x00007f323098fc2e in __cxa_throw () from /usr/lib/x86_64-linux-gnu/libstdc++.so.6
#1  0x00007f3230808c00 in c10::detail::torchCheckFail(char const*, char const*, unsigned int, char const*) () from /usr/local/lib/python3.8/site-packages/torch/lib/libc10.so
#2  0x00007f31d8e1915f in at::TensorBase::options() const () from /usr/local/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#3  0x00007f31d951acdd in at::native::_to_copy(at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, std::optional<c10::MemoryFormat>)
    () from /usr/local/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#4  0x00007f31da37f23b in c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, std::optional<c10::MemoryFormat>), &at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeExplicitAutograd___to_copy>, at::Tensor, c10::guts::typelist::typelist<at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, std::optional<c10::MemoryFormat> > >, at::Tensor (at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, std::optional<c10::MemoryFormat>)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, std::optional<c10::MemoryFormat>) ()
   from /usr/local/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#5  0x00007f31d9aadc55 in at::_ops::_to_copy::redispatch(c10::DispatchKeySet, at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, std::optional<c10::MemoryFormat>) () from /usr/local/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so

With debug prints added, I get:

xw32, file=/ansible/pytorch/aten/src/ATen/DLConvertor.cpp, line=126function=getATenDevice: ctx.device_type=2
xw32, file=/ansible/pytorch/aten/src/ATen/DLConvertor.cpp, line=133function=getATenDevice: ctx.device_type=2, ctx.device_id=0
xw32, file=/ansible/pytorch/aten/src/ATen/DLConvertor.cpp, line=308function=fromDLPack: device=cuda:0
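The `device_type` and `device_id` values in these prints are DLPack fields, not PyTorch devices. In DLPack's `DLDeviceType` enum (dlpack.h), 2 is `kDLCUDA`. The capsule being imported therefore describes memory on CUDA device 0, which is consistent with `fromDLPack` reporting `cuda:0`. Below is a small stdlib-only decoder for prints in this format. It is a sketch: the regex assumes exactly the print format shown above, and the table covers only a subset of the enum.

```python
import re

# Subset of DLPack's DLDeviceType enum (values as defined in dlpack.h).
DL_DEVICE_TYPES = {
    1: "kDLCPU",
    2: "kDLCUDA",
    3: "kDLCUDAHost",
    4: "kDLOpenCL",
    7: "kDLVulkan",
    8: "kDLMetal",
    10: "kDLROCM",
    13: "kDLCUDAManaged",
}

# Matches "ctx.device_type=N" optionally followed by ", ctx.device_id=M".
_FIELDS = re.compile(r"ctx\.device_type=(\d+)(?:, ctx\.device_id=(\d+))?")

def decode_dl_device(log_line: str):
    """Return (DLPack device type name, device id or None) from a getATenDevice print."""
    m = _FIELDS.search(log_line)
    if m is None:
        raise ValueError("no ctx.device_type field in line")
    code = int(m.group(1))
    name = DL_DEVICE_TYPES.get(code, f"unknown({code})")
    dev_id = int(m.group(2)) if m.group(2) is not None else None
    return name, dev_id

line = ("xw32, file=/ansible/pytorch/aten/src/ATen/DLConvertor.cpp, "
        "line=133function=getATenDevice: ctx.device_type=2, ctx.device_id=0")
print(decode_dl_device(line))  # ('kDLCUDA', 0)
```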

Will look into it.

cc: @ysiraichi @JackCaoG @miladm

vanbasten23 commented 5 months ago

OK, for moving a CUDA tensor that contains a single value to XLA (case 1.2), I think I know what is happening, and for this case it might be fine to go through the CPU:

Running

XLA_FALLBACK_CUDA=1 PJRT_DEVICE=CUDA python pytorch/xla/test/test_operations.py TestGeneric.test_aten_move_scalar_cuda_to_xla_zero_copy

fails with a segfault (call stack: https://gist.github.com/vanbasten23/c1dd0f19ca7abcd52d46dbd35a26f643). The segfault happens when we calculate the hash of the CUDA tensor during std::memcpy. The DataCacheArena::DataCache lives on the host (CPU), so we are effectively doing a host-side memcpy from GPU memory, which I think is why it fails. I think it is OK to go through the CPU in this case because:

vanbasten23 commented 5 months ago

So I ran the BERT_pytorch model, and it fails with the following error when we move a CUDA tensor to the XLA device:

RuntimeError: torch_xla/csrc/runtime/pjrt_computation_client.cc:465 : from_dlpack got array with non-default layout with minor-to-major dimensions (2,0,1), expected (2,1,0)

The error comes from https://github.com/openxla/xla/blob/f3553ed43a40d462aefe359a7c6a7ef441b6188c/xla/python/dlpack.cc#L445-L450
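In XLA, a layout's minor-to-major list names dimensions from fastest-varying to slowest-varying in memory. The default for a rank-3 array is (2,1,0), which is row-major. A non-contiguous PyTorch tensor can carry strides that imply a different order, and that is what this check rejects. The sketch below derives the order from strides by sorting dimensions by stride. This is an assumption that mirrors the idea, not XLA's exact code, and both helper names are hypothetical. It shows how a permuted tensor yields exactly the (2,0,1) from the error. Which BERT_pytorch tensor actually triggered it is not known; the permute is only an illustration.

```python
def minor_to_major(strides):
    """Order dimensions from fastest- to slowest-varying (sketch).

    Assumption: sort by stride ascending; ties (e.g. from size-1 dims) are
    broken toward the default row-major order.
    """
    return tuple(sorted(range(len(strides)), key=lambda d: (strides[d], -d)))

def is_default_layout(strides):
    # XLA's default layout for rank r is (r-1, ..., 1, 0).
    return minor_to_major(strides) == tuple(reversed(range(len(strides))))

# A contiguous tensor of shape (2, 3, 4) has strides (12, 4, 1).
print(minor_to_major((12, 4, 1)))  # (2, 1, 0)
# Permuting a contiguous (3, 2, 4) tensor (strides (8, 4, 1)) with dims (1, 0, 2)
# gives shape (2, 3, 4) with strides (4, 8, 1).
print(minor_to_major((4, 8, 1)))   # (2, 0, 1)
```

In PyTorch terms, the second case corresponds to something like `torch.empty(3, 2, 4).permute(1, 0, 2)`, whose `stride()` is `(4, 8, 1)`.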

vanbasten23 commented 5 months ago

I figured it out. The error above only exists in IFRT; since we are using PJRT, we don't have this issue. I added a test for it.

So now I'm getting another error:

RuntimeError: torch_xla/csrc/runtime/pjrt_computation_client.cc:993 : Check failed: pjrt_device == pjrt_data->buffer->device()

Call stack: https://gist.github.com/vanbasten23/4ce60b9a44c43d4948fc29e7ac8b596a

vanbasten23 commented 5 months ago

I used CUDA_VISIBLE_DEVICES=1 to restrict the run to a single GPU and got an OOM error:

root@xiowei-gpu:/ansible/pytorch# CUDA_VISIBLE_DEVICES=1  XLA_FALLBACK_CUDA=1 python xla/benchmarks/experiment_runner.py --suite-name=torchbench --accelerator=cuda --progress-bar  --model-config=\{\"model_name\":\"BERT_pytorch\"\} --experiment-config=\{\"accelerator\":\"cuda\",\"xla\":\"PJRT\",\"xla_flags\":null,\"dynamo\":\"openxla\",\"test\":\"train\"\}   --repeat 1

  File "/ansible/pytorch/xla/torch_xla/core/dynamo_bridge.py", line 512, in optimized_mod
    result = _maybe_move_tensors_to_device(tuple(result), original_device)
  File "/ansible/pytorch/xla/torch_xla/core/dynamo_bridge.py", line 165, in _maybe_move_tensors_to_device
    moved_tensor = tensor.to(target_device)
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 158.00 MiB. GPU 0 has a total capacity of 15.77 GiB of which 79.88 MiB is free. Process 32460 has 15.69 GiB memory in use. Of the allocated memory 3.28 GiB is allocated by PyTorch, and 227.89 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
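Two details in this message can be checked mechanically. First, CUDA_VISIBLE_DEVICES=1 exposes only physical GPU 1 and renumbers it as ordinal 0, so the "GPU 0" in the message is physical GPU 1. Second, the byte counts show that PyTorch's caching allocator accounts for only about 3.5 GiB of the 15.69 GiB in use. Assuming process 32460 is the one that raised the error, the remaining ~12.2 GiB is held by that process outside PyTorch's allocator. That memory is presumably held by the XLA/PJRT runtime, but this is an assumption; the message does not say. Below is a stdlib-only sketch of both checks, where `physical_gpu` is a hypothetical helper.

```python
from typing import Optional

def physical_gpu(logical_index: int, visible: Optional[str]) -> int:
    """Map an in-process CUDA ordinal to a physical GPU index (hypothetical helper).

    `visible` is the value of CUDA_VISIBLE_DEVICES (None if unset). Assumes a
    comma-separated list of integer indices; UUID entries are not handled.
    """
    if visible is None:
        return logical_index  # unrestricted: ordinals match physical indices
    ids = [int(part) for part in visible.split(",") if part.strip()]
    if logical_index >= len(ids):
        raise IndexError(f"ordinal {logical_index} not visible (CUDA_VISIBLE_DEVICES={visible!r})")
    return ids[logical_index]

# Figures copied from the OOM message above.
MIB_PER_GIB = 1024
requested_mib = 158.00
free_mib = 79.88
process_in_use_gib = 15.69
torch_allocated_gib = 3.28
torch_reserved_unallocated_mib = 227.89

# Everything PyTorch's caching allocator holds: allocated plus reserved-but-unallocated.
torch_total_gib = torch_allocated_gib + torch_reserved_unallocated_mib / MIB_PER_GIB
# The rest of the process's usage is outside that allocator.
outside_torch_gib = process_in_use_gib - torch_total_gib

print("'GPU 0' is physical GPU", physical_gpu(0, "1"))
print("request fits in free memory:", requested_mib <= free_mib)
print(f"held by PyTorch's caching allocator: {torch_total_gib:.2f} GiB")
print(f"held by the process outside that allocator: {outside_torch_gib:.2f} GiB")
```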

vanbasten23 commented 5 months ago

Actually, I realized that the OOM above happened on my V100 machine. I ran the same script and code on my A100 machine, and it ran fine.
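Rerunning the benchmark on another machine means retyping the hand-escaped JSON arguments from the experiment_runner.py command above. A small helper can generate them instead. It is a hypothetical convenience, not part of the benchmark tooling, and it reproduces only flags already used in that command. `json.dumps` with compact separators emits the same JSON text as the original, including `null` for `xla_flags`. `shlex.join` then quotes each argument for the shell.

```python
import json
import shlex
from typing import Optional

def experiment_runner_cmd(model_name: str, visible_devices: Optional[str] = None) -> str:
    """Rebuild the torchbench experiment_runner.py invocation used above (hypothetical helper)."""
    compact = (",", ":")  # no spaces, matching the hand-escaped JSON in the original command
    model_config = {"model_name": model_name}
    experiment_config = {
        "accelerator": "cuda",
        "xla": "PJRT",
        "xla_flags": None,  # json.dumps writes this as null
        "dynamo": "openxla",
        "test": "train",
    }
    argv = [
        "python", "xla/benchmarks/experiment_runner.py",
        "--suite-name=torchbench",
        "--accelerator=cuda",
        "--progress-bar",
        "--model-config=" + json.dumps(model_config, separators=compact),
        "--experiment-config=" + json.dumps(experiment_config, separators=compact),
        "--repeat", "1",
    ]
    env = ["XLA_FALLBACK_CUDA=1"]
    if visible_devices is not None:
        env.insert(0, f"CUDA_VISIBLE_DEVICES={visible_devices}")
    # shlex.join single-quotes arguments containing braces or quotes, so the JSON survives the shell.
    return " ".join(env) + " " + shlex.join(argv)

print(experiment_runner_cmd("BERT_pytorch", visible_devices="1"))
```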