pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Fix the fallback op in SPMD #8386

Closed JackCaoG closed 3 days ago

JackCaoG commented 6 days ago

For a simple repro:

import torch
import torch.nn as nn
import math
import torch_xla
import numpy as np
import torch_xla.runtime as xr

xr.use_spmd()

device = torch_xla.device()

theta: float = 10000
dim = 16
end = 2048

freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, device=device)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device)
freqs = torch.outer(t, freqs).float()
freqs_cis = torch.polar(torch.ones_like(freqs, device=device), freqs)  # complex64

# Print the debug info directly; nesting print() inside the f-string would also print "None".
print(f"before sync: {torch_xla._XLAC._get_xla_tensor_debug_info(freqs_cis)}")
torch_xla.sync()
print(f"after sync: {torch_xla._XLAC._get_xla_tensor_debug_info(freqs_cis)}")
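To see why the real tensor ends up as [2048, 8], here is a pure-Python sketch that mirrors the arithmetic of the repro without torch or torch_xla. It is an illustration only: it uses Python floats (float64) and Python complex numbers instead of float32/complex64, so individual values differ slightly from the torch version, but the shape is the same, with dim // 2 = 8 frequencies for each of end = 2048 positions.

```python
import cmath

# Same hyperparameters as the repro.
theta: float = 10000.0
dim = 16
end = 2048

# freqs[k] = 1 / theta ** (2k / dim), mirroring
# torch.arange(0, dim, 2)[: dim // 2].float() / dim in the repro.
freqs = [1.0 / (theta ** (i / dim)) for i in range(0, dim, 2)][: dim // 2]

# torch.outer(t, freqs): one row per position t, one column per frequency.
angles = [[t * f for f in freqs] for t in range(end)]

# torch.polar(ones, angles): unit-magnitude complex numbers e^{i * angle}.
freqs_cis = [[cmath.rect(1.0, a) for a in row] for row in angles]

shape = (len(freqs_cis), len(freqs_cis[0]))
print(shape)  # (2048, 8)
```

Position 0 has angle 0 for every frequency, so its row is all 1+0j, and every entry has magnitude 1.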

This prints the following, before and after the sync:

XLATensor {
TensorID: 19
Device: SPMD:0
XLA Shape: c64[2048,8]
ShardingSpec: {replicated}
IR: None
XLAData: None
Tensor on host: with size [2048, 8]
}

XLATensor {
TensorID: 19
Device: SPMD:0
XLA Shape: c64[0]
ShardingSpec: {replicated}
IR: None
XLAShardedData: 
  Data Device: SPMD:0
  Data Shape: c64[0]
  OpSharding: {replicated}
  NumShards: 4
Tensor on host: with size [2048, 8]
}

The second dump is where the error comes from:


Here XLAShardedData reports c64[0], but the real tensor is actually [2048, 8]. This happens because the in-place copy (triggered by the fallback) does not clear the tensor's previous ShardingSpec. That confuses PyTorch/XLA, since the ShardingSpec and the at::Tensor no longer agree on the size.
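The failure mode can be illustrated with a toy model. This is a hypothetical sketch, not the PyTorch/XLA implementation: `ToyShardingSpec`, `ToyXLATensor`, `inplace_copy_buggy`, and `inplace_copy_fixed` are illustrative names, and the model assumes the sharding spec caches the shape it was computed for and that this cached shape sizes the device data on sync. The destination starts at shape (0,) only to reproduce the c64[0] mismatch seen in the dump.

```python
from dataclasses import dataclass
from typing import Optional, Tuple

Shape = Tuple[int, ...]


@dataclass
class ToyShardingSpec:
    # Hypothetical stand-in: a sharding plus the shape it was computed for.
    sharding: str
    shape: Shape


@dataclass
class ToyXLATensor:
    # Hypothetical model of an XLA tensor; NOT the real torch_xla class.
    host_shape: Shape
    sharding_spec: Optional[ToyShardingSpec] = None

    def device_data_shape(self) -> Shape:
        # On sync, the device buffer is sized from the sharding spec when one
        # is attached, otherwise from the host tensor.
        if self.sharding_spec is not None:
            return self.sharding_spec.shape
        return self.host_shape


def inplace_copy_buggy(dst: ToyXLATensor, src_shape: Shape) -> None:
    # The fallback copy writes the new host data but keeps the old spec.
    dst.host_shape = src_shape


def inplace_copy_fixed(dst: ToyXLATensor, src_shape: Shape) -> None:
    dst.host_shape = src_shape
    # Drop the stale spec so it cannot disagree with the new data's size.
    dst.sharding_spec = None


def make_tensor() -> ToyXLATensor:
    # Empty destination carrying a replicated spec recorded for shape (0,).
    return ToyXLATensor(host_shape=(0,), sharding_spec=ToyShardingSpec("replicated", (0,)))


buggy = make_tensor()
inplace_copy_buggy(buggy, (2048, 8))
print(buggy.host_shape, buggy.device_data_shape())  # (2048, 8) (0,)

fixed = make_tensor()
inplace_copy_fixed(fixed, (2048, 8))
print(fixed.host_shape, fixed.device_data_shape())  # (2048, 8) (2048, 8)
```

In the buggy path the host data and the spec disagree, just like `Tensor on host: with size [2048, 8]` versus `Data Shape: c64[0]` in the dump; clearing the spec on the in-place copy keeps them consistent.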