pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

[Dynamo SPMD Sharding]: Error - Input tensor is not an XLA tensor: XLAFloatType #7645

Open · huzama opened this issue 2 weeks ago

huzama commented 2 weeks ago

🐛 Bug Report

When I call torch.ops.xla.dynamo_mark_sharding inside a function compiled with torch.compile (backend="openxla"), I encounter the following error:

Input tensor is not an XLA tensor: XLAFloatType

The same code works fine when the torch.ops.xla.dynamo_mark_sharding call is removed.

To Reproduce

Here's a minimal code snippet to reproduce the bug:

import torch
import torch_xla
from torch_xla.core import xla_model as xm
import torch_xla.distributed.spmd as xs
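import torch_xla.runtime as xr
import numpy as np

xr.use_spmd()

# The original report calls a project-local helper, init_spmd(1), from
# src.utils. Stand-in (assumption: that helper builds a global mesh with an
# 'fsdp' axis): a 1-D mesh over all devices, registered as the global mesh.
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ('fsdp',))
xs.set_global_mesh(mesh)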

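# dynamo_mark_sharding takes the mesh as plain lists and strings
# (not an xs.Mesh object), so flatten the global mesh here.
spmd_mesh = xs.get_global_mesh()

device_ids = list(range(spmd_mesh.size()))
axis_names = str(tuple(spmd_mesh.shape().keys()))
mesh_shape = list(spmd_mesh.shape().values())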

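def important_fn(inputs):
    # Shard dim 0 of `inputs` across the 'fsdp' mesh axis and replicate dim 1.
    # This is the call that raises the error under torch.compile.
    torch.ops.xla.dynamo_mark_sharding(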
        inputs,
        device_ids=device_ids,
        mesh_shape=mesh_shape,
        axis_names=axis_names,
        partition_spec="('fsdp', None)",
    )
    inputs = inputs.to(torch.float32)
    return inputs.pow(2).mean(-1, keepdim=True)

compiled = torch.compile(important_fn, backend="openxla", fullgraph=True)

inputs = torch.randn((128, 4096), device=xm.xla_device())

compiled(inputs)
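The repro passes the mesh to torch.ops.xla.dynamo_mark_sharding as plain lists and strings (device_ids, mesh_shape, and stringified axis_names and partition_spec) rather than as an xs.Mesh object. The pure-Python sketch below illustrates that encoding with a hypothetical helper, flatten_mesh_args, which is not part of torch_xla, and checks that the strings parse back into tuples. It assumes the op parses these strings back on its side; this issue does not confirm how the op consumes them.

```python
import ast
from typing import List, Mapping, Optional, Tuple


# Hypothetical helper, not part of torch_xla: flattens a mesh description
# into the primitive arguments the repro passes to dynamo_mark_sharding.
def flatten_mesh_args(
    axis_sizes: Mapping[str, int],
    partition_spec: Tuple[Optional[str], ...],
) -> Tuple[List[int], List[int], str, str]:
    mesh_shape = list(axis_sizes.values())
    # Number of devices is the product of all mesh axis sizes.
    num_devices = 1
    for size in mesh_shape:
        num_devices *= size
    device_ids = list(range(num_devices))
    axis_names = str(tuple(axis_sizes.keys()))
    # Each non-None entry of the partition spec must name a mesh axis.
    for axis in partition_spec:
        if axis is not None and axis not in axis_sizes:
            raise ValueError(f"unknown mesh axis: {axis!r}")
    return device_ids, mesh_shape, axis_names, str(tuple(partition_spec))


device_ids, mesh_shape, axis_names, spec = flatten_mesh_args(
    {"fsdp": 4, "model": 1}, ("fsdp", None)
)
print(device_ids, mesh_shape, axis_names, spec)
# [0, 1, 2, 3] [4, 1] ('fsdp', 'model') ('fsdp', None)

# The stringified tuples parse back losslessly with ast.literal_eval.
assert ast.literal_eval(axis_names) == ("fsdp", "model")
assert ast.literal_eval(spec) == ("fsdp", None)
```

With a 1-D mesh such as {"fsdp": 4}, axis_names becomes "('fsdp',)", with the trailing comma that marks a one-element tuple.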

Stack Trace

*** Begin stack trace ***
    tsl::CurrentStackTrace()
    torch_xla::bridge::GetXlaTensor(at::Tensor const&)
    _PyObject_MakeTpCall
    _PyEval_EvalFrameDefault
    _PyFunction_Vectorcall
    ...
    // repeated function calls
    ...
    PyEval_EvalCode
    ...
    __libc_start_main
    _start
*** End stack trace ***

Environment

JackCaoG commented 2 weeks ago

@wonjoolee95 can you take a look? Should be something easy to fix.