Open fengyang0317 opened 1 week ago
xm.send_cpu_data_to_device cannot support 2d data and 4d mesh
https://colab.research.google.com/drive/1URZVd3q0LUeZ8anzrkoJkehOxLdbG1k8?usp=sharing
import numpy as np
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
from torch_xla import runtime as xr
from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding

xr.use_spmd()
num_devices = xr.global_runtime_device_count()
device = xm.xla_device()
device_ids = np.arange(num_devices)
mesh = xs.Mesh(device_ids, (2, 1, 2, 2), ('dp', 'fsdp', 'tp', 'sp'))

xt = torch.zeros([8, 64]).to(device)
xs.mark_sharding(xt, mesh, ('dp', 'sp'))
print(torch_xla._XLAC._get_xla_sharding_spec(xt))
print(visualize_tensor_sharding(xt))
# This is the desired sharding

xt = torch.zeros([8, 64])
yt = xm.send_cpu_data_to_device(xt, device, input_sharding=xs.ShardingSpec(mesh, ('dp', 'sp')))[0]
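For reference, the layout that the ('dp', 'sp') partition spec implies on the (2, 1, 2, 2) mesh can be worked out by hand. The sketch below is plain Python and does not call torch_xla; shard_shape and replication_factor are hypothetical helpers written for illustration, and the sketch assumes each sharded dimension divides evenly by the size of its mesh axis.

```python
import math


def shard_shape(tensor_shape, mesh_shape, axis_names, partition_spec):
    """Per-device shard shape when tensor dim i is split over mesh axis partition_spec[i].

    Hypothetical helper for illustration only; not part of torch_xla.
    A partition_spec entry of None means that dimension is not split.
    """
    axis_size = dict(zip(axis_names, mesh_shape))
    result = []
    for dim, axis in zip(tensor_shape, partition_spec):
        parts = 1 if axis is None else axis_size[axis]
        if dim % parts != 0:
            # Assumption of this sketch: only even splits are handled.
            raise ValueError(f"dimension {dim} is not divisible by {parts}")
        result.append(dim // parts)
    return tuple(result)


def replication_factor(mesh_shape, axis_names, partition_spec):
    """Number of devices holding a copy of each shard (product of unused mesh axes)."""
    used = {axis for axis in partition_spec if axis is not None}
    return math.prod(size for size, name in zip(mesh_shape, axis_names) if name not in used)


mesh_shape = (2, 1, 2, 2)
axis_names = ('dp', 'fsdp', 'tp', 'sp')

# The [8, 64] tensor from the repro, split over 'dp' (dim 0) and 'sp' (dim 1).
print(shard_shape((8, 64), mesh_shape, axis_names, ('dp', 'sp')))
# The unused axes 'fsdp' and 'tp' only replicate the shards.
print(replication_factor(mesh_shape, axis_names, ('dp', 'sp')))
```

Under these assumptions each device would hold a [4, 32] shard, and every shard would be replicated across the two devices along 'tp' ('fsdp' has size 1).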
Will look into this later today.
🐛 Bug
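xm.send_cpu_data_to_device does not support sharding 2D data over a 4D mesh. Passing input_sharding=xs.ShardingSpec(mesh, ('dp', 'sp')) for an [8, 64] tensor fails, while xs.mark_sharding with the same mesh and partition spec produces the desired sharding.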
To Reproduce
https://colab.research.google.com/drive/1URZVd3q0LUeZ8anzrkoJkehOxLdbG1k8?usp=sharing
Steps to reproduce the behavior: run the Colab notebook linked above.
Expected behavior
The script runs without error.