pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

xm.send_cpu_data_to_device does not support 2D data on a 4D mesh #8012

Open fengyang0317 opened 1 week ago

fengyang0317 commented 1 week ago

🐛 Bug

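xm.send_cpu_data_to_device fails when a 2D CPU tensor is given an input_sharding defined on a 4D mesh, even though xs.mark_sharding applies the same mesh and partition spec to an on-device tensor as expected.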
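To make the failing case concrete: the mesh in the reproducer has four named axes, but the tensor has only two dimensions. Assuming the usual reading of a partition spec, ('dp', 'sp') shards dim 0 over 'dp' and dim 1 over 'sp', and the tensor is replicated over the unnamed axes 'fsdp' and 'tp'. The pure-Python sketch below (no torch_xla needed; `describe_sharding` is a hypothetical helper, not a torch_xla API) works out the per-device shard shape and replication factor under that assumption.

```python
from math import prod


def describe_sharding(tensor_shape, mesh_shape, axis_names, partition_spec):
    """Hypothetical helper: map each tensor dim to a mesh axis and report the
    per-device shard shape plus which mesh axes the shard is replicated over."""
    if len(partition_spec) != len(tensor_shape):
        raise ValueError("partition_spec needs one entry per tensor dimension")
    axis_size = dict(zip(axis_names, mesh_shape))
    shard_shape = []
    for dim_size, axis in zip(tensor_shape, partition_spec):
        parts = 1 if axis is None else axis_size[axis]
        if dim_size % parts != 0:
            raise ValueError(f"dim of size {dim_size} not divisible by {parts}")
        shard_shape.append(dim_size // parts)
    # Mesh axes not named in the spec: every shard is copied across them.
    replicated = [a for a in axis_names if a not in partition_spec]
    copies = prod(axis_size[a] for a in replicated)
    return tuple(shard_shape), replicated, copies


shard, replicated, copies = describe_sharding(
    tensor_shape=(8, 64),
    mesh_shape=(2, 1, 2, 2),
    axis_names=("dp", "fsdp", "tp", "sp"),
    partition_spec=("dp", "sp"),
)
print(shard, replicated, copies)  # (4, 32) ['fsdp', 'tp'] 2
```

Four distinct [4, 32] shards, each held by 2 devices, account for all 8 devices in the mesh.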

To Reproduce

https://colab.research.google.com/drive/1URZVd3q0LUeZ8anzrkoJkehOxLdbG1k8?usp=sharing

import numpy as np
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
from torch_xla import runtime as xr
from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding

xr.use_spmd()

num_devices = xr.global_runtime_device_count()
device = xm.xla_device()
device_ids = np.arange(num_devices)
# 4D logical mesh: 2 * 1 * 2 * 2 = 8 devices, so this needs an 8-device runtime
mesh = xs.Mesh(device_ids, (2, 1, 2, 2), ('dp', 'fsdp', 'tp', 'sp'))

# Path 1: move the tensor to the device, then shard it with mark_sharding
xt = torch.zeros([8, 64]).to(device)
xs.mark_sharding(xt, mesh, ('dp', 'sp'))
print(torch_xla._XLAC._get_xla_sharding_spec(xt))
print(visualize_tensor_sharding(xt))  # This is the desired sharding

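# Path 2: shard during the CPU-to-device copy with the same mesh and
# partition spec; this is the call that fails.
# send_cpu_data_to_device returns a list, so [0] takes the single tensor.
xt = torch.zeros([8, 64])
yt = xm.send_cpu_data_to_device(
    xt, device, input_sharding=xs.ShardingSpec(mesh, ('dp', 'sp')))[0]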
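For comparison with what visualize_tensor_sharding shows for the xs.mark_sharding call above, the sketch below enumerates which block of the [8, 64] tensor each device should hold under the desired sharding. It assumes device ids fill the mesh in row-major order over the shape (2, 1, 2, 2), as np.arange(num_devices) reshaped to that shape would; `device_blocks` is a hypothetical helper, not part of torch_xla.

```python
from itertools import product


def device_blocks(tensor_shape, mesh_shape, axis_names, partition_spec):
    """Hypothetical helper: return {device_id: (row_slice, col_slice)} for a
    2D tensor whose partition_spec names one mesh axis per dimension,
    assuming device ids fill the mesh in row-major order."""
    axis_index = {name: i for i, name in enumerate(axis_names)}
    blocks = {}
    # itertools.product varies the last index fastest, so the enumerate
    # counter equals the row-major flat index of each mesh coordinate.
    for device_id, coord in enumerate(product(*(range(n) for n in mesh_shape))):
        slices = []
        for dim_size, axis in zip(tensor_shape, partition_spec):
            i = axis_index[axis]
            step = dim_size // mesh_shape[i]
            start = coord[i] * step
            slices.append(slice(start, start + step))
        blocks[device_id] = tuple(slices)
    return blocks


blocks = device_blocks((8, 64), (2, 1, 2, 2), ("dp", "fsdp", "tp", "sp"), ("dp", "sp"))
for dev, (rows, cols) in blocks.items():
    print(f"device {dev}: rows {rows.start}:{rows.stop}, cols {cols.start}:{cols.stop}")
# Devices 0 and 2 both hold rows 0:4, cols 0:32 (they differ only along 'tp').
```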

Steps to reproduce the behavior:

  1. Run the Colab notebook linked above.

Expected behavior

The script runs without error, and yt ends up with the same sharding that xs.mark_sharding gives xt.
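One way to pin down "the same sharding" is the device grouping of a partially replicated tiling, the layout XLA expresses with `last_tile_dim_replicate`: tensor dimensions are tiled over the named mesh axes, and devices that differ only along the unnamed axes hold the same tile. The sketch below derives that grouping for the reproducer's mesh in pure Python. It assumes row-major device ids, and `partial_tile_assignment` is a hypothetical helper, not torch_xla's implementation; it is only meant as a reference point when reading the spec printed by `_get_xla_sharding_spec`.

```python
from itertools import product
from math import prod


def partial_tile_assignment(mesh_shape, axis_names, partition_spec):
    """Hypothetical helper: map each tile coordinate (one index per sharded
    mesh axis) to the list of device ids holding that tile, assuming device
    ids fill the mesh in row-major order."""
    sharded = [axis_names.index(a) for a in partition_spec]
    replicated = [i for i in range(len(mesh_shape)) if i not in sharded]
    # Row-major strides, e.g. (2, 1, 2, 2) -> [4, 4, 2, 1].
    strides = [prod(mesh_shape[i + 1:]) for i in range(len(mesh_shape))]

    def device_id(coord):
        return sum(c * s for c, s in zip(coord, strides))

    grid = {}
    for tile in product(*(range(mesh_shape[i]) for i in sharded)):
        group = []
        # Devices that differ only along replicated axes hold the same tile.
        for rep in product(*(range(mesh_shape[i]) for i in replicated)):
            coord = [0] * len(mesh_shape)
            for i, c in zip(sharded, tile):
                coord[i] = c
            for i, c in zip(replicated, rep):
                coord[i] = c
            group.append(device_id(coord))
        grid[tile] = group
    return grid


grid = partial_tile_assignment((2, 1, 2, 2), ("dp", "fsdp", "tp", "sp"), ("dp", "sp"))
flat = [d for group in grid.values() for d in group]
print(flat)  # [0, 2, 1, 3, 4, 6, 5, 7]
```

Under these assumptions both the mark_sharding path and the send_cpu_data_to_device path would be expected to produce a [2, 2] tiling with a replication group of 2 in this device order.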


bhavya01 commented 5 days ago

Will look into this later today.