pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

`xm.save()` should not call `_xla_sync_multi()` with `sync_xla_data=True` #8422

Open mcuiaws opened 3 days ago

mcuiaws commented 3 days ago

🐛 Bug

When xm.save() is called on a subset of tensors, it calls _xla_sync_multi(sync_xla_data=True), which eventually reaches XLAGraphExecutor::SyncTensorsGraph() with sync_ltc_data=true, as if it were a step marker. This causes XLAGraphExecutor::Compile() to perform buffer aliasing. However, according to the detailed comments in xla_graph_executor.cpp#L1336, buffer aliasing should only be performed at a step marker, when all live tensors are being synced. That is not the case for xm.save(), which syncs only the tensors being saved.

As a result, the graph executed by the xm.mark_step() that follows xm.save() has parameters that refer to donated buffers, which PJRT has already deleted. This triggers an XLA_CHECK() failure in PjRtData::GetHandle().
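To make the failure mode concrete, here is a toy model in plain Python. It is not torch_xla code: `Buffer`, `LazyTensor`, `lazy_op`, `lazy_inplace`, and `sync` are hypothetical names invented for illustration, and the model only assumes that pending graphs hold references to their input buffers and that aliasing deletes the donated input. It shows why donating during a partial sync leaves a later sync reading a deleted buffer, while donating when every live tensor is synced together is safe.

```python
class Buffer:
    """Toy stand-in for a device buffer that can be donated (deleted)."""

    def __init__(self, value):
        self.value = value
        self.deleted = False

    def get_handle(self):
        # Loosely mirrors the check described above: a deleted buffer
        # must not be used as a graph parameter.
        if self.deleted:
            raise RuntimeError("buffer was donated and deleted")
        return self.value


class LazyTensor:
    """Toy tensor: a materialized buffer plus an optional pending op."""

    def __init__(self, value=None):
        self.buffer = Buffer(value) if value is not None else None
        self.pending = None  # (fn, [input buffers]) until synced


def lazy_op(fn, *inputs):
    # Record the op; the pending graph keeps references to input buffers.
    out = LazyTensor()
    out.pending = (fn, [t.buffer for t in inputs])
    return out


def lazy_inplace(fn, tensor):
    # An in-place update reads the tensor's current buffer.
    tensor.pending = (fn, [tensor.buffer])


def sync(tensors, alias):
    todo = [t for t in tensors if t.pending is not None]
    # Run every pending op first, reading all inputs ...
    results = [fn(*(b.get_handle() for b in bufs))
               for fn, bufs in (t.pending for t in todo)]
    # ... then donate: an updated tensor's old input buffer is deleted.
    for t, value in zip(todo, results):
        old = t.buffer
        if alias and old is not None and any(b is old for b in t.pending[1]):
            old.deleted = True
        t.buffer = Buffer(value)
        t.pending = None


def run(partial_sync_first):
    t0, t1 = LazyTensor(1.0), LazyTensor(2.0)
    t2 = lazy_op(lambda a, b: a + b, t0, t1)  # needs the OLD value of t1
    lazy_inplace(lambda a: a + 1, t1)
    if partial_sync_first:
        sync([t1], alias=True)        # like xm.save() on a subset
    sync([t0, t1, t2], alias=True)    # like xm.mark_step()
    return t1.buffer.get_handle(), t2.buffer.get_handle()
```

In this model `run(False)` completes because t2 is computed in the same sync that donates t1's old buffer, whereas `run(True)` raises because t2's pending op still points at the buffer the partial sync deleted.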

To Reproduce

A small example that reproduces the crash:

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import tempfile

def main():
    device = xm.xla_device()
    t0 = torch.randn(100, 100, device=device)
    t1 = torch.randn(100, 100, device=device)
    xm.mark_step()

    t2 = t0 + t1
    t1.add_(1)

    with tempfile.NamedTemporaryFile(suffix=".pt") as tmpfile:
        # Save the updated value of t1. This causes its device buffer to be
        # donated, overwriting the old t1 value, which is still needed to
        # compute t2.
        # Adding a step marker right before xm.save() avoids the crash.
        if False:
            xm.mark_step()
        xm.save({ 't1' : t1, }, tmpfile.name)

    # CRASH
    xm.mark_step()

if __name__ == "__main__":
    main()
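Until the aliasing behavior is fixed, the workaround noted in the comments above (a step marker right before xm.save()) could be wrapped in a small helper. This is only a sketch: `save_after_step_marker` and its `xm` parameter are hypothetical names, not part of torch_xla, and the extra xm.mark_step() forces all pending work to execute, which may cost performance.

```python
def save_after_step_marker(data, path, xm=None):
    """Sync all live tensors with a step marker, then save.

    `xm` is a hypothetical injection point (handy for testing); by default
    the real torch_xla.core.xla_model module is used.
    """
    if xm is None:
        # Deferred import so this sketch can be loaded without torch_xla.
        import torch_xla.core.xla_model as xm
    # Step marker first: every live tensor is synced here, so the
    # partial sync inside xm.save() has nothing pending left to alias.
    xm.mark_step()
    xm.save(data, path)
```

In the reproducer, this would replace the direct call as `save_after_step_marker({'t1': t1}, tmpfile.name)`.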

Environment

Additional context

mcuiaws commented 3 days ago

Our PJRT device plugin marks a PJRT_Buffer as deleted if it is donated to an output buffer, which causes the above-mentioned XLA_CHECK() to fail. In case other device plugins do not behave this way, here is a patch that forces a crash in xla_graph_executor.cpp.