When xm.save() is called on a subset of tensors, it calls down to _xla_sync_multi(sync_xla_data=True) which eventually calls down to XLAGraphExecutor::SyncTensorsGraph() with sync_ltc_data=true as if it's doing a step marker. This causes buffer aliasing to be performed by XLAGraphExecutor::Compile(). However, according to the detailed comments in xla_graph_executor.cpp#L1336, buffer aliasing should only be performed at a step marker when all live tensors are being sync'd. This is obviously not the case for xm.save().
As a result, the xm.mark_step() following xm.save() will have parameters which refer to donated buffers which are deleted by PJRT, triggering this XLA_CHECK() failure in PjRtData::GetHandle().
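The failure mode can be sketched without any XLA dependency. The toy model below is not torch_xla code: `Buffer`, `execute`, and `reproduce` are hypothetical names made up for illustration, and the model assumes only what is described above, namely that a donated input buffer is marked deleted and that reading a deleted buffer fails.

```python
class Buffer:
    """Hypothetical stand-in for a device buffer that can be donated."""

    def __init__(self, value):
        self.value = value
        self.deleted = False

    def get(self):
        # Loosely mirrors the check that fails in the report: a deleted
        # (donated) buffer can no longer be read.
        if self.deleted:
            raise RuntimeError("buffer was donated and deleted")
        return self.value


def execute(inputs, fn, donate=False):
    """Run fn over the input buffers; optionally donate the first input."""
    result = Buffer(fn(*[b.get() for b in inputs]))
    if donate:
        # Aliasing: the first input's storage is handed to the output,
        # so the input buffer is no longer valid afterwards.
        inputs[0].deleted = True
    return result


def reproduce(donate_on_partial_sync):
    # State after the first step marker: t0 and t1 are materialized.
    t0_buf, t1_buf = Buffer(1.0), Buffer(2.0)

    # Pending t2 = t0 + t1 still reads the ORIGINAL t1 buffer.
    pending_t2 = ([t0_buf, t1_buf], lambda a, b: a + b)

    # Partial sync of the updated t1 only (the xm.save() path in the report).
    t1_new = execute([t1_buf], lambda a: a + 1, donate=donate_on_partial_sync)

    # Next step marker: t2 is finally computed from its recorded inputs.
    inputs, fn = pending_t2
    t2 = execute(inputs, fn)
    return t1_new.get(), t2.get()
```

In this model, `reproduce(False)` completes normally, while `reproduce(True)` raises at the final step because t2 still depends on a buffer that the partial sync donated. Donation is only safe when every live consumer of the buffer is part of the same sync.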
To Reproduce
A small example to reproduce.
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import tempfile


def main():
    device = xm.xla_device()
    t0 = torch.randn(100, 100, device=device)
    t1 = torch.randn(100, 100, device=device)
    xm.mark_step()
    t2 = t0 + t1
    t1.add_(1)
    with tempfile.NamedTemporaryFile(suffix=".pt") as tmpfile:
        # Save the updated value of t1, which will cause its device buffer
        # to be donated, overwriting the old t1 value which is needed to
        # compute t2.
        # Adding a step marker right before xm.save() avoids the crash.
        if False:
            xm.mark_step()
        xm.save({'t1': t1}, tmpfile.name)

    # CRASH
    xm.mark_step()


if __name__ == "__main__":
    main()
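Until the aliasing behavior is fixed, the workaround noted in the comment above (a step marker right before xm.save()) could be wrapped in a small helper. This is a sketch only: `save_after_step_marker` is a hypothetical name, not part of torch_xla, and the optional `xm` parameter exists only so the helper can be exercised with a stand-in module.

```python
def save_after_step_marker(state, path, xm=None):
    """Sync all live tensors with a step marker, then save.

    Hypothetical helper; `xm` defaults to torch_xla.core.xla_model.
    """
    if xm is None:
        import torch_xla.core.xla_model as xm
    # After a full step marker no pending graph still reads the old
    # buffers, so the sync performed by xm.save() has nothing to break.
    xm.mark_step()
    xm.save(state, path)
```

In the example above this would replace the xm.save() call, for instance `save_after_step_marker({'t1': t1}, tmpfile.name)`. The trade-off is that the extra step marker cuts the graph at the save point, so the pending computation is executed earlier than it otherwise would be.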
Environment
Reproducible on XLA backend [CPU/TPU/CUDA]: neuron trn1
Our PJRT device plugin marks a PJRT_Buffer as deleted if it is donated to an output buffer, causing the above-mentioned XLA_CHECK() to fail. In case other device plugins do not have this behavior, here is a patch that forces a crash in xla_graph_executor.cpp.