pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Zero copy tensor conversion between xla:gpu and torch.cuda #4692

Open cicirori opened 1 year ago

cicirori commented 1 year ago

Currently, switching between lazy and eager execution can incur huge overhead even when both run on the same physical device. This comes mainly from IR graph execution and from converting the tensor's device type. The latter is not actually necessary; I think it exists for historical reasons (XRT), which you can see from the interface names TransferToServer/TransferFromServer. Even a transfer from a GPU to the same GPU has to be routed through the CPU.

I'm implementing a PoC so that xla_tensor.to('cuda') and cuda_tensor.to('xla') are actually zero copy. So far it can run an eager/lazy mixed MNIST.
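For illustration, this is the kind of mixed eager/lazy pattern the PoC targets (a rough sketch, not the PoC code; today both .to() calls below copy through the host, and turning them into aliases of the same GPU buffer is exactly the goal):

import torch
import torch.nn.functional as F
import torch_xla.core.xla_model as xm

device = xm.xla_device()            # xla:0, backed by the same physical GPU

# Lazy part: trace and execute an XLA graph.
w = torch.randn(128, 128, device=device)
h = torch.relu(w @ w)
xm.mark_step()                      # materialize h as a device buffer

# Hand the result to eager CUDA kernels. Today this is a
# device -> host -> device copy; the PoC would make it an alias.
h_cuda = h.to('cuda')
loss = F.mse_loss(h_cuda, torch.zeros_like(h_cuda))

# And the reverse direction: wrap an eager CUDA tensor as an XLA tensor.
g = torch.randn(128, 128, device='cuda')
g_xla = g.to(device)                # today: also copies via the host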

But there may be problems here: I used the _to_copy op even though no copy actually happens, and I wonder whether that will cause issues in the backward pass during training.

I am currently considering how to implement zero copy while ensuring correctness, and would like to know if the community has any relevant experience.

JackCaoG commented 1 year ago

Based on my understanding, what we need to do is make that GPU buffer somehow recognized by the PJRT runtime (as a PJRTBuffer). I am curious what your current approach is.

cicirori commented 1 year ago

Based on my understanding, what we need to do is make that GPU buffer somehow recognized by the PJRT runtime (as a PJRTBuffer). I am curious what your current approach is.

https://github.com/openxla/xla/blob/3c22aa8d716edfc4821b085b920534b4b01e9438/xla/python/dlpack.cc#L283
https://github.com/pytorch/pytorch/blob/d6dd67a2488c7e17fbf010eee805f1cb2d64ba28/aten/src/ATen/DLConvertor.cpp#L217

The DLPack implementations in XLA and PyTorch would be good references.
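As a sketch of what that could look like from Python: the torch.utils.dlpack half below is the real PyTorch API, while xla_from_dlpack / xla_to_dlpack are hypothetical stand-ins for a torch_xla binding over the dlpack.cc code linked above, which does not exist as such.

import torch
from torch.utils import dlpack as torch_dlpack

def cuda_to_xla_zero_copy(t_cuda, xla_from_dlpack):
    # xla_from_dlpack is a placeholder for a torch_xla binding over
    # DLPackManagedTensorToBuffer in xla/python/dlpack.cc.
    capsule = torch_dlpack.to_dlpack(t_cuda)   # export CUDA memory, no copy
    return xla_from_dlpack(capsule)            # wrap it as a PJRT buffer / XLA tensor

def xla_to_cuda_zero_copy(t_xla, xla_to_dlpack):
    # xla_to_dlpack is a placeholder for the reverse binding
    # (BufferToDLPackManagedTensor).
    capsule = xla_to_dlpack(t_xla)             # expose the PJRT buffer via DLPack
    return torch_dlpack.from_dlpack(capsule)   # CUDA tensor aliasing the same memory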

kevint324 commented 1 year ago

From the Toy Example:

import torch
def fn(x, y):
    a = torch.sin(x).cuda()
    b = torch.sin(y).cuda()
    return a + b
new_fn = torch.compile(fn, backend="inductor")
input_tensor = torch.randn(10000).to(device="cuda:0")
a = new_fn(input_tensor, input_tensor)

It seems the same issue has been addressed in the eager/dynamo mixed scenario with the inductor backend.

It seems it could handle zero-copy tensor conversion between inductor and torch.cuda?

When I changed the backend from "inductor" to "torchxla_trace_once", I got:

Traceback (most recent call last):
  File "xx.py", line 9, in <module>
    a = new_fn(input_tensor, input_tensor)
  File "/opt/conda/envs/python38/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 209, in _fn
    return fn(*args, **kwargs)
  File "xx.py", line 2, in fn
    def fn(x, y):
  File "/opt/conda/envs/python38/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 209, in _fn
    return fn(*args, **kwargs)
  File "/opt/conda/envs/python38/lib/python3.8/site-packages/torch/_dynamo/backends/torchxla.py", line 24, in fwd
    compiled_graph = bridge.extract_compiled_graph(model, args)
  File "/opt/conda/envs/python38/lib/python3.8/site-packages/torch_xla/core/dynamo_bridge.py", line 197, in extract_compiled_graph
    assert all(
AssertionError: All tensors should be on xla

I'm not reporting an issue, just suggesting that the relevant experience might reside in the dynamo bridge?
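For context, that assertion is the dynamo bridge insisting that every input already lives on the XLA device; the example only gets past it if the inputs (and the ops inside fn) stay on 'xla' rather than mixing in .cuda(). A minimal variant along those lines (a sketch; the registered backend name varies across torch/torch_xla versions):

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

def fn(x, y):
    # Keep the whole graph on the XLA device; calling .cuda() inside is
    # what trips "All tensors should be on xla".
    return torch.sin(x) + torch.sin(y)

new_fn = torch.compile(fn, backend="torchxla_trace_once")
input_tensor = torch.randn(10000, device=device)
a = new_fn(input_tensor, input_tensor)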

cicirori commented 1 year ago

from the Toy Example

import torch
def fn(x, y):
    a = torch.sin(x).cuda()
    b = torch.sin(y).cuda()
    return a + b
new_fn = torch.compile(fn, backend="inductor")
input_tensor = torch.randn(10000).to(device="cuda:0")
a = new_fn(input_tensor, input_tensor)

It seems the same issue has been addressed in the eager/dynamo mixed scenario with the inductor backend.

Seems it could handle Zero copy tensor conversion between inductor and torch.cuda ?

When I changed the backend from ="inductor" to ="torchxla_trace_once" I got:

Traceback (most recent call last):
  File "xx.py", line 9, in <module>
    a = new_fn(input_tensor, input_tensor)
  File "/opt/conda/envs/python38/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 209, in _fn
    return fn(*args, **kwargs)
  File "xx.py", line 2, in fn
    def fn(x, y):
  File "/opt/conda/envs/python38/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 209, in _fn
    return fn(*args, **kwargs)
  File "/opt/conda/envs/python38/lib/python3.8/site-packages/torch/_dynamo/backends/torchxla.py", line 24, in fwd
    compiled_graph = bridge.extract_compiled_graph(model, args)
  File "/opt/conda/envs/python38/lib/python3.8/site-packages/torch_xla/core/dynamo_bridge.py", line 197, in extract_compiled_graph
    assert all(
AssertionError: All tensors should be on xla

I'm not reporting a issue, just saying the relevant experience might reside in the dynamo bridge?

@kevint324 In fact, as I understand it, the dynamo inductor backend does not offer much relevant experience for zero copy between torch_xla and eager. Even with inductor, the device property of a tensor is unaffected, so a tensor at the boundary between dynamo and eager code can naturally be used in both scenarios. But in torch_xla, a tensor lives on its own device, distinct from cuda, and an XLA tensor is opaque, which makes it impossible to interact with CUDA tensors directly.
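To make the contrast concrete (a sketch; the last line is the host round trip being discussed):

import torch
import torch_xla.core.xla_model as xm

# With inductor, the compiled region's inputs and outputs are ordinary
# CUDA tensors, so eager code can keep using them directly.
def sin_fn(x):
    return torch.sin(x)
out_inductor = torch.compile(sin_fn, backend="inductor")(torch.randn(8, device="cuda"))
print(out_inductor.device)     # cuda:0 -- same device as eager

# With torch_xla, the result lives on a separate 'xla' device and its
# storage is an opaque PJRT buffer, not directly usable by CUDA kernels.
out_xla = torch.sin(torch.randn(8, device=xm.xla_device()))
print(out_xla.device)          # xla:0
out_cuda = out_xla.to("cuda")  # today: device -> host -> device copy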

If torch_xla were still using XRT, this would be a tricky problem. But XRT is now confirmed to be deprecated, so I think it's time to rethink the interaction between XLA and CUDA tensors.