This behavior should be the result of our functionalization pass. @alanwaketan to confirm the expected behavior. Either way, let's have a DLPack documentation/tutorial that goes through example use cases and fully explains the correct behavior @ysiraichi.
Thanks for the issue. I checked the buffer pointer in a few more places:
>>> # imports (assuming torch_xla.utils.dlpack provides to_dlpack/from_dlpack)
>>> import torch
>>> import torch_xla
>>> import torch_xla.core.xla_model as xm
>>> import torch_xla.utils.dlpack as xdlpack
>>>
>>> t0 = torch.arange(10, device=xm.xla_device())
>>> xm.mark_step(wait=True)
>>>
>>> capsule = xdlpack.to_dlpack(t0)
>>> t1 = xdlpack.from_dlpack(capsule)
>>> print(torch_xla._XLAC._unsafe_buffer_pointer(t0) == torch_xla._XLAC._unsafe_buffer_pointer(t1))
True
>>>
>>> t0[0] = 100
>>> xm.mark_step()
>>>
>>> print(torch_xla._XLAC._unsafe_buffer_pointer(t0) == torch_xla._XLAC._unsafe_buffer_pointer(t1))
True
>>> print(t0.eq(t1).all().item())
False
>>>
>>> print(torch_xla._XLAC._unsafe_buffer_pointer(t0) == torch_xla._XLAC._unsafe_buffer_pointer(t1))
False
Could you elaborate on "That's because even though functionalization emulates views and mutation, PyTorch/XLA doesn't really have the concept of views and can't mutate a given tensor."? Do you mean that when we do `t0[0] = 100`, the underlying PjRt buffer is not mutated, hence `t1` is not updated, even though `t0` and `t1` share the same storage? Let me also look into what torch_xla does when we do `t0[0] = 100`.
Yes, exactly. In summary, a functionalized lazy tensor is composed of:
Tensor(
  impl=FunctionalTensorWrapper(
    value=Tensor(
      impl=XLATensorImpl(
        tensor=XLATensor(handle or tensor_data or ir_value)
      )
    )
  )
)
Suppose `t0` and `t1` share the same storage using the DLPack API. Whenever an in-place operation is called, e.g. `t0.add_(1)`, the functionalization layer actually calls the functional variant (`XLANativeFunctions::add`), which generates a new `XLATensor`. Later, that is wrapped by a new `FunctionalTensorWrapper` (let's call it `temp`). In the end, the functionalization layer replaces the `FunctionalTensorWrapper::value` of `t0` with the one inside `temp`. Thus, `t0` ends up with the updated value, while `t1` remains with the old one.
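To make that swap concrete, here is a toy, pure-Python sketch of the mechanism; the class names only mimic the real C++ ones for illustration, and none of this is torch_xla API:

```python
# Toy model of the value swap described above -- NOT the real classes, just an analogy.
class ToyXLATensor:
    def __init__(self, data):
        self.data = data              # stands in for handle / tensor_data / ir_value

class ToyFunctionalWrapper:
    def __init__(self, value):
        self.value = value            # stands in for FunctionalTensorWrapper::value

shared = ToyXLATensor([0, 1, 2])
t0 = ToyFunctionalWrapper(shared)     # t0 and t1 "share storage": both wrappers
t1 = ToyFunctionalWrapper(shared)     # point at the same ToyXLATensor

# In-place op on t0: the functional variant produces a brand-new ToyXLATensor,
# and only t0's wrapper is re-pointed at it.
new_value = ToyXLATensor([x + 1 for x in shared.data])
t0.value = new_value

print(t0.value.data)                  # [1, 2, 3] -- t0 sees the update
print(t1.value.data)                  # [0, 1, 2] -- t1 still holds the old value
```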
Hmm. Not sure I get it. Could you explain a bit more?
That's a helper that lets us bridge information through the intermediate tensors created by functionalization for in-place ops.
When we do the in-place op `t0[0] = 100`, I see `XLANativeFunctions::_propagate_xla_data` invoked twice, in sequence. So it seems the helper is already being used?
Here's how I think we could use `propagate_xla_data` for solving this problem. Note that this is not a solution, but an initial idea. In summary, whenever it's called inside the dispatch of an in-place operation, we would need to:

- find the tensor (i.e. the one whose `alias_id == unique_id`) that holds the original shared buffer
- set the new `XLATensor` for that tensor

This, however, won't work. Once we call `torch._sync` on the original XLA tensor, we will run the in-place operation again, which might give incorrect results.
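A tiny numeric illustration of that failure mode (plain Python, nothing torch_xla-specific): if the alias root is already patched with the updated value and the recorded in-place op is then replayed at sync time, the op effectively runs twice.

```python
# Plain-Python illustration of the double-application problem described above.
original = [0, 1, 2]

updated = [x + 1 for x in original]        # the functional add(1) result
alias_root = updated                       # idea: point the alias root at the new value

replayed = [x + 1 for x in alias_root]     # the sync re-runs the in-place add(1)
print(replayed)                            # [2, 3, 4] -- +1 applied twice instead of once
```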
On another note, we could use this (`propagate_xla_data`) for warning the user that they are not really modifying the underlying storage. Basically, check whether the tensor we are calling the in-place operation on shares storage (again, with a new `XLATensor` flag).
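For illustration, here is a user-land sketch of such a check, using only `_unsafe_buffer_pointer` (seen earlier in this thread). The helper names are made up, and the real proposal is of course a flag on the C++ `XLATensor`, not user code:

```python
# Hypothetical user-land sketch of the proposed warning; not torch_xla API.
import warnings
import torch_xla

_dlpack_alias_ptrs = set()

def register_dlpack_alias(t):
    """Record a tensor that was exported to / imported from a DLPack capsule."""
    _dlpack_alias_ptrs.add(torch_xla._XLAC._unsafe_buffer_pointer(t))

def warn_if_dlpack_aliased(t):
    """Call before an in-place op: warn if `t` currently shares a DLPack'd buffer."""
    if torch_xla._XLAC._unsafe_buffer_pointer(t) in _dlpack_alias_ptrs:
        warnings.warn(
            "In-place update on a tensor whose buffer is shared through DLPack: "
            "the change will not be visible through the other alias.")
```

For example, after `register_dlpack_alias(t0)` and `register_dlpack_alias(t1)` in the snippet at the top of this thread, calling `warn_if_dlpack_aliased(t0)` right before `t0[0] = 100` would emit the warning.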
🐛 Bug
In the example below, we have 2 tensors: `t0` and `t1`. `t1` is created from a DLPack capsule generated from `t0`. So, we could say they share the same storage. However, after modifying `t0`, we see that this change is not reflected in `t1`. Furthermore, their buffer pointers are different.

This is actually expected. That's because even though functionalization emulates views and mutation, PyTorch/XLA doesn't really have the concept of views and can't mutate a given tensor.
That said, this could be unexpected behavior from the user's point of view. When using DLPack to alias (i) CUDA and (ii) XLA tensors, in-place operations on (i) do propagate to (ii), but not the other way around.
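Here is a sketch of that asymmetry, assuming an XLA:CUDA setup where `torch_xla.utils.dlpack.from_dlpack` can consume a capsule exported from a CUDA tensor; the expected outputs in the comments follow the behavior described in this issue, not a verified run:

```python
import torch
import torch.utils.dlpack as torch_dlpack
import torch_xla.core.xla_model as xm
import torch_xla.utils.dlpack as xdlpack   # assumed module, as in the snippet above

# (i) a CUDA tensor and (ii) an XLA tensor aliasing it through DLPack.
cuda_t = torch.arange(10, device="cuda")
xla_t = xdlpack.from_dlpack(torch_dlpack.to_dlpack(cuda_t))

cuda_t[0] = 100                 # in-place on (i)
torch.cuda.synchronize()
print(xla_t[0].item())          # expected: 100 -- the CUDA write is visible via the alias

xla_t[0] = -1                   # in-place on (ii)
xm.mark_step(wait=True)
print(cuda_t[0].item())         # expected: still 100 -- the XLA-side update does not propagate
```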
I think that even if this is an expected limitation, it should be documented somewhere. Or, even better, we should warn the user if they try to use an in-place operation on a DLPack-created XLA tensor (e.g. by having a flag `XLATensor::dlpack_created`).

Environment
cc @miladm @JackCaoG @vanbasten23 @lezcano