pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Sharing tensor storage (with DLPack) results in unexpected behavior. #7304

Open ysiraichi opened 4 months ago

ysiraichi commented 4 months ago

🐛 Bug

Consider the following tensors:

>>> t0 = torch.arange(6).view(2, 3).to("xla")
# t0 = [[0, 1, 2], [3, 4, 5]]
>>> t1 = t0[1]
# t1 = [3, 4, 5]
>>> t2 = torch_xla.dlpack.from_xla_cuda_to_cuda(t0) # t2 is on CUDA
# t2 = [[0, 1, 2], [3, 4, 5]]

Then, consider the following sequence of in-place operations:

# (I)
>>> t2[1, 0] = 10  
# t0 = [[0, 1, 2], [10, 4, 5]]
# t1 = [3, 4, 5]
# t2 = [[0, 1, 2], [10, 4, 5]]

# (II)
>>> t1[0] *= 2  
# t0 = [[0, 1, 2], [6, 4, 5]]
# t1 = [6, 4, 5]
# t2 = [[0, 1, 2], [10, 4, 5]]

Due to #7198, we already expect that the effects of the in-place operation (II) will not be propagated back to the CUDA tensor t2. However, since t1 is a view of t0, I would expect (II) to use the values updated by (I). Instead, (II) clearly uses the value of t0 from before (I), unexpectedly resulting in t0[1, 0] == 6.

This problem happens because of how the functionalization layer applies updates to the base tensor of a view. Even though the functional implementation does try to synchronize (i.e. regenerate the view from its base), the regeneration doesn't happen because the functional wrapper of t0 doesn't know it has changed.
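For reference, here is a minimal CPU-only sketch (plain eager PyTorch, no XLA or DLPack involved) of the view semantics I would expect, where (II) sees the value written by (I):

>>> import torch
>>> t0 = torch.arange(6).view(2, 3)   # plain CPU tensor
>>> t1 = t0[1]                        # true view of the second row
>>> t0[1, 0] = 10                     # stands in for the write done through t2 in (I)
>>> t1[0] *= 2                        # (II) sees the updated value 10
# t0 = [[0, 1, 2], [20, 4, 5]]
# t1 = [20, 4, 5]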

Expected behavior

Given #7198, I think a reasonable result would be:

>>> t1[0] *= 2  
# t0 = [[0, 1, 2], [20, 4, 5]]
# t1 = [20, 4, 5]
# t2 = [[0, 1, 2], [10, 4, 5]]

i.e. run the in-place operation with the updated base values.

Environment

cc @miladm @JackCaoG @vanbasten23 @bdhirsh @lezcano

ManfeiBai commented 4 months ago

Hi @vanbasten23, is it ok to assign this ticket to you?

vanbasten23 commented 3 months ago

Thanks for reporting the issue. I did some investigation.

I think there is a typo in your repro:

# (I)
>>> t2[1, 0] = 10  
# t0 = [[0, 1, 2], [10, 4, 5]]
# t1 = [3, 4, 5]
# t2 = [[0, 1, 2], [10, 4, 5]]

After we do t2[1,0]=10, t1 should be [10, 4, 5]. Could you double check?

Second, regarding "since t1 is a view of t0": I'm not sure about that. When we do t1 = t0[1], the torch_xla operation XLANativeFunctions::select_copy is called. So t1 and t0 may not share the same buffer:

>>> import torch, torch_xla
>>> import torch_xla.debug.metrics as met
>>> import torch_xla.core.xla_model as xm
>>> t0 = torch.arange(6).view(2, 3).to("xla")
>>> t1 = t0[1]
>>> torch_xla._XLAC._unsafe_buffer_pointer(t0)
139902419206144
>>> torch_xla._XLAC._unsafe_buffer_pointer(t1)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: torch_xla/csrc/init_python_bindings.cpp:2681 : Could not get the buffer pointer for XLATensor with IR that's not DeviceData
>>> print(t1)
I0000 00:00:1721172456.872579   12480 cuda_dnn.cc:530] Loaded cuDNN version 8902
tensor([3, 4, 5], device='xla:0')
>>> torch_xla._XLAC._unsafe_buffer_pointer(t1)
139902419206400

It means t1 is an output of the XLA graph and t0 is the input. So if you modify t1 with t1[0] *= 2, it shouldn't propagate to t0. Wdyt? @ysiraichi
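One hedged way to double-check which path was taken is to look at the op counters right after the slice (the exact counter name is an assumption on my part; lowered ops are usually counted under an xla:: prefix):

>>> import torch, torch_xla
>>> import torch_xla.debug.metrics as met
>>> t0 = torch.arange(6).view(2, 3).to("xla")
>>> t1 = t0[1]
>>> [name for name in met.counter_names() if "select" in name]
# if a select/select_copy counter shows up here, t1 was produced by a traced
# graph op rather than by aliasing t0's buffer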

ysiraichi commented 3 months ago

After we do t2[1,0]=10, t1 should be [10, 4, 5]. Could you double check?

I'm pretty sure it's not: as described above, the functional wrapper of t0 doesn't know its buffer was modified by (I), so the view t1 is never regenerated and still shows the old values.

So if you modify t1 as t1[0] *= 2, it shouldn't propagate to t0.

I believe it should. It's the functionalization layer that guarantees mutations are propagated through aliasing relations. In other words, even though they don't actually share memory, the functionalization layer knows that, whenever an in-place operation takes place, it should update the base tensor (eagerly) and the other views (lazily).
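For what it's worth, here is a minimal CPU-only sketch of that guarantee using upstream torch.func.functionalize (not torch_xla's integration of it, so this is only an analogy): the in-place op on the view is reflected in the base even though, under functionalization, they don't share memory.

>>> import torch
>>> from torch.func import functionalize
>>> def f(base):
...     view = base[1]      # alias of base's second row
...     view[0] *= 2        # in-place op through the view
...     return base.clone(), view.clone()
...
>>> base = torch.arange(6).view(2, 3)
>>> out_base, out_view = functionalize(f)(base)
# out_base = [[0, 1, 2], [6, 4, 5]]  (the base sees the view's mutation)
# out_view = [6, 4, 5]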