huzama opened this issue 2 weeks ago
🐛 Bug Report

When using dynamo sharding inside `torch.compile`, I encounter the following error:

```
Input tensor is not an XLA tensor: XLAFloatType
```

It works fine without `torch.ops.xla.dynamo_mark_sharding` (a variant with that call removed is sketched after the repro below).
To Reproduce

Here's a minimal code snippet to reproduce the bug:
```python
import torch
import torch_xla
from torch_xla.core import xla_model as xm
import torch_xla.distributed.spmd as xs
from src.utils import init_spmd
import torch_xla.runtime as xr

xr.use_spmd()
init_spmd(1)  # Setup SPMD for your configuration

spmd_mesh = xs.get_global_mesh()
device_ids = list(range(spmd_mesh.size()))
axis_names = str(tuple(spmd_mesh.shape().keys()))
mesh_shape = list(spmd_mesh.shape().values())


def important_fn(inputs):
    torch.ops.xla.dynamo_mark_sharding(
        inputs,
        device_ids=device_ids,
        mesh_shape=mesh_shape,
        axis_names=axis_names,
        partition_spec="('fsdp', None)",
    )
    inputs = inputs.to(torch.float32)
    return inputs.pow(2).mean(-1, keepdim=True)


compiled = torch.compile(important_fn, backend="openxla", fullgraph=True)

inputs = torch.randn((128, 4096), device=xm.xla_device())
compiled(inputs)
```
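For reference, the same function works under `torch.compile` once the `dynamo_mark_sharding` call is removed, as noted above. A minimal sketch of that working variant, assuming the same imports and SPMD setup as the repro (the `important_fn_no_sharding` name is just for illustration):

```python
# Same imports and SPMD setup as the repro above; only the
# torch.ops.xla.dynamo_mark_sharding call has been removed.
def important_fn_no_sharding(inputs):
    inputs = inputs.to(torch.float32)
    return inputs.pow(2).mean(-1, keepdim=True)


compiled_ok = torch.compile(important_fn_no_sharding, backend="openxla", fullgraph=True)
out = compiled_ok(torch.randn((128, 4096), device=xm.xla_device()))  # completes without the XLA-tensor error
```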
Stack Trace

```
*** Begin stack trace ***
	tsl::CurrentStackTrace()
	torch_xla::bridge::GetXlaTensor(at::Tensor const&)
	_PyObject_MakeTpCall
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	... // repeated function calls ...
	PyEval_EvalCode
	...
	__libc_start_main
	_start
*** End stack trace ***
```
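The trace points at `torch_xla::bridge::GetXlaTensor`, which suggests a tensor reaching the sharding op inside the compiled graph is not an XLA tensor at that point. An untested sketch to check whether the op itself accepts these arguments eagerly (i.e. outside `torch.compile`), reusing the mesh setup from the repro above:

```python
# Untested sketch: invoke the op eagerly (no torch.compile) with the same
# arguments as the repro above, to check whether the failure is specific to
# the dynamo/openxla path. Assumes the SPMD/mesh setup (device_ids,
# mesh_shape, axis_names) from the snippet above has already been run.
import torch
import torch_xla.core.xla_model as xm

eager_inputs = torch.randn((128, 4096), device=xm.xla_device())
torch.ops.xla.dynamo_mark_sharding(
    eager_inputs,
    device_ids=device_ids,
    mesh_shape=mesh_shape,
    axis_names=axis_names,
    partition_spec="('fsdp', None)",
)
# If this call succeeds, the op handles these arguments fine and the error
# only appears once the function is traced by dynamo.
```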
@wonjoolee95 can you take a look? Should be something easy to fix.