For a simple repro:

import torch
import torch.nn as nn
import math
import torch_xla
import numpy as np
import torch_xla.runtime as xr
xr.use_spmd()
device = torch_xla.device()
theta: float = 10000
dim = 16
end=2048
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, device=device)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device)
freqs = torch.outer(t, freqs).float()
freqs_cis = torch.polar(torch.ones_like(freqs, device=device), freqs) # complex64
print(f"before sync{print(torch_xla._XLAC._get_xla_tensor_debug_info(freqs_cis))}")
torch_xla.sync()
print(f"after sync{print(torch_xla._XLAC._get_xla_tensor_debug_info(freqs_cis))}")
I saw:
XLATensor {
TensorID: 19
Device: SPMD:0
XLA Shape: c64[2048,8]
ShardingSpec: {replicated}
IR: None
XLAData: None
Tensor on host: with size [2048, 8]
}
XLATensor {
TensorID: 19
Device: SPMD:0
XLA Shape: c64[0]
ShardingSpec: {replicated}
IR: None
XLAShardedData:
Data Device: SPMD:0
Data Shape: c64[0]
OpSharding: {replicated}
NumShards: 4
Tensor on host: with size [2048, 8]
}
The second one is where the error comes from:
XLATensor {
TensorID: 19
Device: SPMD:0
XLA Shape: c64[0]
ShardingSpec: {replicated}
IR: None
XLAShardedData:
Data Device: SPMD:0
Data Shape: c64[0]
OpSharding: {replicated}
NumShards: 4
Tensor on host: with size [2048, 8]
}
where XLAShardedData says the shape is c64[0] but the real tensor is actually [2048, 8]. This happens because the in-place copy (triggered by the CPU fallback) does not clear the previous ShardingSpec, which confuses PyTorch/XLA since the ShardingSpec and the at::Tensor no longer agree on size.
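A minimal workaround sketch (my assumption, not confirmed: aten::polar is the op hitting the CPU fallback here): build freqs_cis entirely on CPU and only transfer the finished complex tensor to the XLA device, so no fallback op ever writes in place into a tensor that already carries a ShardingSpec.

import torch
import torch_xla
import torch_xla.runtime as xr

xr.use_spmd()
device = torch_xla.device()

theta: float = 10000
dim = 16
end = 2048

# Build the rotary-embedding frequencies on CPU with plain PyTorch.
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end)
freqs = torch.outer(t, freqs).float()
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64, still on CPU

# Only now move the finished tensor to the XLA device; no in-place fallback
# copy into an already-tracked XLA tensor, so no stale ShardingSpec.
freqs_cis = freqs_cis.to(device)
torch_xla.sync()
print(torch_xla._XLAC._get_xla_tensor_debug_info(freqs_cis))

This only sidesteps the symptom; the actual fix would presumably be to drop or refresh the stale ShardingSpec when the fallback's in-place copy replaces the tensor's data.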