pytorch/xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

QR decomposition on TPU? #2689

Closed guangli-dai closed 3 years ago

guangli-dai commented 3 years ago

❓ Questions and Help

I tried to run QR decomposition on a TPU, expecting a correct result and a speedup. However, I got root errors saying the operation is not implemented for TPU. Am I missing something in how I use it, or is this a feature that still needs to be added?

Environment: pytorch/xla compiled from source (last updated on Nov 24th); TPU: v3-8

Code:

import torch
import torch_xla.core.xla_builder as xb
import torch_xla.core.xla_op_registry as xor
import torch_xla.core.xla_model as xm
import torch_xla.utils.utils as xu
import time

n = 1000
dev = xm.xla_device(n=1, devkind='TPU')
A = torch.randn(n, n, dtype=torch.bfloat16, device=dev)
norm_sum = torch.ones(1, 1, dtype=torch.bfloat16, device=dev)
q, r = torch.qr(A)
norm_sum += torch.norm(q)
print(norm_sum)
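
As I understand it, torch_xla executes lazily, so torch.qr is only compiled and run when a value is materialized, here when norm_sum is printed, which is why the failure shows up on the print line. For reference, a minimal sanity check I would expect to use once the op compiles (the loose tolerance is just a guess for bfloat16, not anything prescribed by the library):

# Hypothetical check: reconstruct A from Q and R and compare on the host.
q, r = torch.qr(A)
recon = torch.matmul(q, r)
print(torch.allclose(recon.cpu().float(), A.cpu().float(), atol=1e-1))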

Error messages

2 root error(s) found.
  (0) Unimplemented: CustomCall for 'QrDecomposition' is not implemented for TPU.

Error encountered while compiling %custom-call = (bf16[1000,1000]{1,0:T(8,128)(2,1)}, bf16[1000,1000]{1,0:T(8,128)(2,1)}) custom-call(bf16[1000,1000]{1,0:T(8,128)(2,1)} %reshape.92), custom_call_target="QrDecomposition".
         [[{{node XRTCompile}}]]
  (1) Unimplemented: CustomCall for 'QrDecomposition' is not implemented for TPU.

A more detailed log with the error message:

2020-12-16 05:28:13.072373: E tensorflow/compiler/xla/xla_client/xla_util.cc:76] >>> Dumping Computation 0
2020-12-16 05:28:13.072443: E tensorflow/compiler/xla/xla_client/xla_util.cc:76] HloModule SyncTensorsGraph.100
2020-12-16 05:28:13.072451: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]
2020-12-16 05:28:13.072457: E tensorflow/compiler/xla/xla_client/xla_util.cc:76] %AddComputation.71 (x.72: bf16[], y.73: bf16[]) -> bf16[] {
2020-12-16 05:28:13.072463: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %x.72 = bf16[] parameter(0)
2020-12-16 05:28:13.072469: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %y.73 = bf16[] parameter(1)
2020-12-16 05:28:13.072474: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   ROOT %add.74 = bf16[] add(bf16[] %x.72, bf16[] %y.73)
2020-12-16 05:28:13.072480: E tensorflow/compiler/xla/xla_client/xla_util.cc:76] }
2020-12-16 05:28:13.072485: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]
2020-12-16 05:28:13.072491: E tensorflow/compiler/xla/xla_client/xla_util.cc:76] ENTRY %SyncTensorsGraph.100 (p0.2: s64[], p1.89: f32[]) -> (f32[1], pred[1]) {
2020-12-16 05:28:13.072496: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %constant.5 = s64[] constant(2531011)
2020-12-16 05:28:13.072502: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %constant.3 = s64[] constant(214013)
2020-12-16 05:28:13.072507: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %p0.2 = s64[] parameter(0)
2020-12-16 05:28:13.072512: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %multiply.4 = s64[] multiply(s64[] %constant.3, s64[] %p0.2)
2020-12-16 05:28:13.072518: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %add.6 = s64[] add(s64[] %constant.5, s64[] %multiply.4)
2020-12-16 05:28:13.072523: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %convert.17 = u64[] convert(s64[] %add.6)
2020-12-16 05:28:13.072529: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %reshape.23 = u64[1]{0} reshape(u64[] %convert.17)
2020-12-16 05:28:13.072534: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %constant.18 = u64[] constant(0)
2020-12-16 05:28:13.072539: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %reshape.24 = u64[1]{0} reshape(u64[] %constant.18)
2020-12-16 05:28:13.072545: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %concatenate.25 = u64[2]{0} concatenate(u64[1]{0} %reshape.23, u64[1]{0} %reshape.24), dimensions={0}
2020-12-16 05:28:13.072551: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %rng-bit-generator.26 = (u64[2]{0}, u32[1000000]{0}) rng-bit-generator(u64[2]{0} %concatenate.25), algorithm=rng_default
2020-12-16 05:28:13.072557: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %get-tuple-element.28 = u64[2]{0} get-tuple-element((u64[2]{0}, u32[1000000]{0}) %rng-bit-generator.26), index=0
2020-12-16 05:28:13.072563: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %constant.12 = bf16[] constant(0)
2020-12-16 05:28:13.072568: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %reshape.13 = bf16[1,1]{1,0} reshape(bf16[] %constant.12)
2020-12-16 05:28:13.072574: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %broadcast.14 = bf16[1,1]{1,0} broadcast(bf16[1,1]{1,0} %reshape.13), dimensions={0,1}
2020-12-16 05:28:13.072579: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %reshape.15 = bf16[] reshape(bf16[1,1]{1,0} %broadcast.14)
2020-12-16 05:28:13.072584: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %broadcast.16 = bf16[1000,1000]{1,0} broadcast(bf16[] %reshape.15), dimensions={}
2020-12-16 05:28:13.072590: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %convert.19 = f32[1000,1000]{1,0} convert(bf16[1000,1000]{1,0} %broadcast.16)
2020-12-16 05:28:13.072595: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %constant.46 = f32[] constant(6.28318548)
2020-12-16 05:28:13.072601: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %broadcast.47 = f32[500000]{0} broadcast(f32[] %constant.46), dimensions={}
2020-12-16 05:28:13.072606: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %get-tuple-element.27 = u32[1000000]{0} get-tuple-element((u64[2]{0}, u32[1000000]{0}) %rng-bit-generator.26), index=1
2020-12-16 05:28:13.072617: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %constant.29 = u32[] constant(9)
2020-12-16 05:28:13.072623: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %broadcast.30 = u32[1000000]{0} broadcast(u32[] %constant.29), dimensions={}
2020-12-16 05:28:13.072628: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %shift-right-logical.31 = u32[1000000]{0} shift-right-logical(u32[1000000]{0} %get-tuple-element.27, u32[1000000]{0} %broadcast.30)
2020-12-16 05:28:13.072634: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %convert.32 = f32[1000000]{0} convert(u32[1000000]{0} %shift-right-logical.31)
2020-12-16 05:28:13.072639: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %constant.33 = f32[] constant(1.1920929e-07)
2020-12-16 05:28:13.072645: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %broadcast.34 = f32[1000000]{0} broadcast(f32[] %constant.33), dimensions={}
2020-12-16 05:28:13.072650: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %multiply.35 = f32[1000000]{0} multiply(f32[1000000]{0} %convert.32, f32[1000000]{0} %broadcast.34)
2020-12-16 05:28:13.072655: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %constant.21 = f32[] constant(1)
2020-12-16 05:28:13.072661: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %constant.22 = f32[] constant(0)
2020-12-16 05:28:13.072666: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %subtract.36 = f32[] subtract(f32[] %constant.21, f32[] %constant.22)
2020-12-16 05:28:13.072672: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %broadcast.37 = f32[1000000]{0} broadcast(f32[] %subtract.36), dimensions={}
2020-12-16 05:28:13.072680: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %multiply.38 = f32[1000000]{0} multiply(f32[1000000]{0} %multiply.35, f32[1000000]{0} %broadcast.37)
2020-12-16 05:28:13.072686: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %broadcast.39 = f32[1000000]{0} broadcast(f32[] %constant.22), dimensions={}
2020-12-16 05:28:13.072691: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %add.40 = f32[1000000]{0} add(f32[1000000]{0} %multiply.38, f32[1000000]{0} %broadcast.39)
2020-12-16 05:28:13.072697: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %slice.42 = f32[500000]{0} slice(f32[1000000]{0} %add.40), slice={[500000:1000000]}
2020-12-16 05:28:13.072702: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %multiply.48 = f32[500000]{0} multiply(f32[500000]{0} %broadcast.47, f32[500000]{0} %slice.42)
2020-12-16 05:28:13.072707: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %sine.54 = f32[500000]{0} sine(f32[500000]{0} %multiply.48)
2020-12-16 05:28:13.072713: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %constant.50 = f32[] constant(-2)
2020-12-16 05:28:13.072718: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %broadcast.51 = f32[500000]{0} broadcast(f32[] %constant.50), dimensions={}
2020-12-16 05:28:13.072724: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %slice.41 = f32[500000]{0} slice(f32[1000000]{0} %add.40), slice={[0:500000]}
2020-12-16 05:28:13.072729: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %constant.43 = f32[] constant(1e-07)
2020-12-16 05:28:13.072734: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %broadcast.44 = f32[500000]{0} broadcast(f32[] %constant.43), dimensions={}
2020-12-16 05:28:13.072740: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %maximum.45 = f32[500000]{0} maximum(f32[500000]{0} %slice.41, f32[500000]{0} %broadcast.44)
2020-12-16 05:28:13.072748: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %log.49 = f32[500000]{0} log(f32[500000]{0} %maximum.45)
2020-12-16 05:28:13.072753: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %multiply.52 = f32[500000]{0} multiply(f32[500000]{0} %broadcast.51, f32[500000]{0} %log.49)
2020-12-16 05:28:13.072759: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %sqrt.53 = f32[500000]{0} sqrt(f32[500000]{0} %multiply.52)
2020-12-16 05:28:13.072767: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %multiply.55 = f32[500000]{0} multiply(f32[500000]{0} %sine.54, f32[500000]{0} %sqrt.53)
2020-12-16 05:28:13.072773: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %cosine.56 = f32[500000]{0} cosine(f32[500000]{0} %multiply.48)
2020-12-16 05:28:13.072778: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %multiply.57 = f32[500000]{0} multiply(f32[500000]{0} %cosine.56, f32[500000]{0} %sqrt.53)
2020-12-16 05:28:13.072783: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %concatenate.58 = f32[1000000]{0} concatenate(f32[500000]{0} %multiply.55, f32[500000]{0} %multiply.57), dimensions={0}
2020-12-16 05:28:13.072789: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %reshape.59 = f32[1000,1000]{1,0} reshape(f32[1000000]{0} %concatenate.58)
2020-12-16 05:28:13.072794: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %constant.7 = bf16[] constant(1)
2020-12-16 05:28:13.072800: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %reshape.8 = bf16[1,1]{1,0} reshape(bf16[] %constant.7)
2020-12-16 05:28:13.072805: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %broadcast.9 = bf16[1,1]{1,0} broadcast(bf16[1,1]{1,0} %reshape.8), dimensions={0,1}
2020-12-16 05:28:13.072810: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %reshape.10 = bf16[] reshape(bf16[1,1]{1,0} %broadcast.9)
2020-12-16 05:28:13.072816: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %broadcast.11 = bf16[1000,1000]{1,0} broadcast(bf16[] %reshape.10), dimensions={}
2020-12-16 05:28:13.072821: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %convert.20 = f32[1000,1000]{1,0} convert(bf16[1000,1000]{1,0} %broadcast.11)
2020-12-16 05:28:13.072827: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %multiply.60 = f32[1000,1000]{1,0} multiply(f32[1000,1000]{1,0} %reshape.59, f32[1000,1000]{1,0} %convert.20)
2020-12-16 05:28:13.072832: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %add.61 = f32[1000,1000]{1,0} add(f32[1000,1000]{1,0} %convert.19, f32[1000,1000]{1,0} %multiply.60)
2020-12-16 05:28:13.072838: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %convert.62 = bf16[1000,1000]{1,0} convert(f32[1000,1000]{1,0} %add.61)
2020-12-16 05:28:13.072844: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %custom-call.63 = (bf16[1000,1000]{1,0}, bf16[1000,1000]{1,0}) custom-call(bf16[1000,1000]{1,0} %convert.62), custom_call_target="QrDecomposition"
2020-12-16 05:28:13.072849: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %get-tuple-element.65 = bf16[1000,1000]{1,0} get-tuple-element((bf16[1000,1000]{1,0}, bf16[1000,1000]{1,0}) %custom-call.63), index=1
2020-12-16 05:28:13.072854: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %slice.67 = bf16[1000,1000]{1,0} slice(bf16[1000,1000]{1,0} %get-tuple-element.65), slice={[0:1000], [0:1000]}
2020-12-16 05:28:13.072860: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %constant.70 = s32[] constant(1000000)
2020-12-16 05:28:13.072865: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %constant.78 = bf16[] constant(1)
2020-12-16 05:28:13.072871: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %reshape.79 = bf16[1,1]{1,0} reshape(bf16[] %constant.78)
2020-12-16 05:28:13.072876: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %broadcast.80 = bf16[1,1]{1,0} broadcast(bf16[1,1]{1,0} %reshape.79), dimensions={0,1}
2020-12-16 05:28:13.072882: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %get-tuple-element.64 = bf16[1000,1000]{1,0} get-tuple-element((bf16[1000,1000]{1,0}, bf16[1000,1000]{1,0}) %custom-call.63), index=0
2020-12-16 05:28:13.072887: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %slice.66 = bf16[1000,1000]{1,0} slice(bf16[1000,1000]{1,0} %get-tuple-element.64), slice={[0:1000], [0:1000]}
2020-12-16 05:28:13.072892: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %multiply.68 = bf16[1000,1000]{1,0} multiply(bf16[1000,1000]{1,0} %slice.66, bf16[1000,1000]{1,0} %slice.66)
2020-12-16 05:28:13.072901: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %constant.69 = bf16[] constant(0)
2020-12-16 05:28:13.072906: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %reduce.75 = bf16[] reduce(bf16[1000,1000]{1,0} %multiply.68, bf16[] %constant.69), dimensions={0,1}, to_apply=%AddComputation.71
2020-12-16 05:28:13.072912: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %sqrt.76 = bf16[] sqrt(bf16[] %reduce.75)
2020-12-16 05:28:13.072917: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %constant.1 = bf16[] constant(1)
2020-12-16 05:28:13.072923: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %multiply.77 = bf16[] multiply(bf16[] %sqrt.76, bf16[] %constant.1)
2020-12-16 05:28:13.072928: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %broadcast.81 = bf16[1,1]{1,0} broadcast(bf16[] %multiply.77), dimensions={}
2020-12-16 05:28:13.072934: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %add.82 = bf16[1,1]{1,0} add(bf16[1,1]{1,0} %broadcast.80, bf16[1,1]{1,0} %broadcast.81)
2020-12-16 05:28:13.072939: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %convert.83 = f32[1,1]{1,0} convert(bf16[1,1]{1,0} %add.82)
2020-12-16 05:28:13.072945: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %reshape.84 = f32[1]{0} reshape(f32[1,1]{1,0} %convert.83)
2020-12-16 05:28:13.072950: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %compare.93 = pred[1]{0} compare(f32[1]{0} %reshape.84, f32[1]{0} %reshape.84), direction=EQ
2020-12-16 05:28:13.072955: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %convert.94 = u8[1]{0} convert(pred[1]{0} %compare.93)
2020-12-16 05:28:13.072961: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %abs.90 = f32[1]{0} abs(f32[1]{0} %reshape.84)
2020-12-16 05:28:13.072978: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %p1.89 = f32[] parameter(1)
2020-12-16 05:28:13.072984: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %broadcast.91 = f32[1]{0} broadcast(f32[] %p1.89), dimensions={}
2020-12-16 05:28:13.072989: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %compare.92 = pred[1]{0} compare(f32[1]{0} %abs.90, f32[1]{0} %broadcast.91), direction=NE
2020-12-16 05:28:13.072995: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %convert.95 = u8[1]{0} convert(pred[1]{0} %compare.92)
2020-12-16 05:28:13.073000: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %multiply.96 = u8[1]{0} multiply(u8[1]{0} %convert.94, u8[1]{0} %convert.95)
2020-12-16 05:28:13.073006: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %convert.97 = pred[1]{0} convert(u8[1]{0} %multiply.96)
2020-12-16 05:28:13.073011: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %constant.85 = s64[] constant(0)
2020-12-16 05:28:13.073016: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %convert.86 = f32[] convert(s64[] %constant.85)
2020-12-16 05:28:13.073022: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %broadcast.87 = f32[1]{0} broadcast(f32[] %convert.86), dimensions={}
2020-12-16 05:28:13.073027: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %compare.88 = pred[1]{0} compare(f32[1]{0} %reshape.84, f32[1]{0} %broadcast.87), direction=NE
2020-12-16 05:28:13.073033: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %and.98 = pred[1]{0} and(pred[1]{0} %convert.97, pred[1]{0} %compare.88)
2020-12-16 05:28:13.073038: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   ROOT %tuple.99 = (f32[1]{0}, pred[1]{0}) tuple(f32[1]{0} %reshape.84, pred[1]{0} %and.98)
2020-12-16 05:28:13.073043: E tensorflow/compiler/xla/xla_client/xla_util.cc:76] }
2020-12-16 05:28:13.073049: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]
2020-12-16 05:28:13.073054: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]
2020-12-16 05:28:13.073059: E tensorflow/compiler/xla/xla_client/xla_util.cc:76] OutputShape: (f32[1]{0}, pred[1]{0})
2020-12-16 05:28:13.073064: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]
2020-12-16 05:28:13.073070: E tensorflow/compiler/xla/xla_client/xla_util.cc:76] StackTrace:
2020-12-16 05:28:13.073078: E tensorflow/compiler/xla/xla_client/xla_util.cc:76] *** Begin stack trace ***
2020-12-16 05:28:13.073084: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]        tensorflow::CurrentStackTrace[abi:cxx11]()
2020-12-16 05:28:13.073090: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]        xla::util::ReportComputationError(tensorflow::Status const&, absl::lts_2020_02_25::Span<xla::XlaComputation const* const>, absl::lts_2020_02_25::Span<xla::Shape const* const>)
2020-12-16 05:28:13.073096: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]        xla::XrtComputationClient::CheckCompileStatus(tensorflow::Status const&, std::vector<xla::ComputationClient::CompileInstance, std::allocator<xla::ComputationClient::CompileInstance> > const&, xla::XrtComputationClient::SessionWork const&)
2020-12-16 05:28:13.073101: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]
2020-12-16 05:28:13.073107: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]        xla::util::MultiWait::Complete(std::function<void ()> const&)
2020-12-16 05:28:13.073112: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]
2020-12-16 05:28:13.073117: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]
2020-12-16 05:28:13.073122: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]
2020-12-16 05:28:13.073127: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]        clone
2020-12-16 05:28:13.073133: E tensorflow/compiler/xla/xla_client/xla_util.cc:76] *** End stack trace ***
2020-12-16 05:28:13.073138: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]
2020-12-16 05:28:13.073144: E tensorflow/compiler/xla/xla_client/xla_util.cc:76] Status: Unimplemented: From /job:tpu_worker/replica:0/task:0:
2020-12-16 05:28:13.073149: E tensorflow/compiler/xla/xla_client/xla_util.cc:76] 2 root error(s) found.
2020-12-16 05:28:13.073155: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   (0) Unimplemented: CustomCall for 'QrDecomposition' is not implemented for TPU.
2020-12-16 05:28:13.073160: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]
2020-12-16 05:28:13.073165: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]
2020-12-16 05:28:13.073171: E tensorflow/compiler/xla/xla_client/xla_util.cc:76] Error encountered while compiling %custom-call = (bf16[1000,1000]{1,0:T(8,128)(2,1)}, bf16[1000,1000]{1,0:T(8,128)(2,1)}) custom-call(bf16[1000,1000]{1,0:T(8,128)(2,1)} %reshape.92), custom_call_target="QrDecomposition".
2020-12-16 05:28:13.073177: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]         [[{{node XRTCompile}}]]
2020-12-16 05:28:13.073182: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]   (1) Unimplemented: CustomCall for 'QrDecomposition' is not implemented for TPU.
2020-12-16 05:28:13.073188: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]
2020-12-16 05:28:13.073193: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]
2020-12-16 05:28:13.073199: E tensorflow/compiler/xla/xla_client/xla_util.cc:76] Error encountered while compiling %custom-call = (bf16[1000,1000]{1,0:T(8,128)(2,1)}, bf16[1000,1000]{1,0:T(8,128)(2,1)}) custom-call(bf16[1000,1000]{1,0:T(8,128)(2,1)} %reshape.92), custom_call_target="QrDecomposition".
2020-12-16 05:28:13.073204: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]         [[{{node XRTCompile}}]]
2020-12-16 05:28:13.073210: E tensorflow/compiler/xla/xla_client/xla_util.cc:76]         [[XRTCompile_G3]]
2020-12-16 05:28:13.073215: E tensorflow/compiler/xla/xla_client/xla_util.cc:76] 0 successful operations.
2020-12-16 05:28:13.073221: E tensorflow/compiler/xla/xla_client/xla_util.cc:76] 0 derived errors ignored.
Traceback (most recent call last):
  File "test_qr.py", line 14, in <module>
    print(norm_sum)
  File "/usr/local/lib/python3.8/dist-packages/torch/tensor.py", line 179, in __repr__
    return torch._tensor_str._str(self)
  File "/usr/local/lib/python3.8/dist-packages/torch/_tensor_str.py", line 372, in _str
    return _str_intern(self)
  File "/usr/local/lib/python3.8/dist-packages/torch/_tensor_str.py", line 352, in _str_intern
    tensor_str = _tensor_str(self, indent)
  File "/usr/local/lib/python3.8/dist-packages/torch/_tensor_str.py", line 241, in _tensor_str
    formatter = _Formatter(get_summarized_data(self) if summarize else self)
  File "/usr/local/lib/python3.8/dist-packages/torch/_tensor_str.py", line 89, in __init__
    nonzero_finite_vals = torch.masked_select(tensor_view, torch.isfinite(tensor_view) & tensor_view.ne(0))
RuntimeError: Unimplemented: From /job:tpu_worker/replica:0/task:0:
2 root error(s) found.
  (0) Unimplemented: CustomCall for 'QrDecomposition' is not implemented for TPU.

Error encountered while compiling %custom-call = (bf16[1000,1000]{1,0:T(8,128)(2,1)}, bf16[1000,1000]{1,0:T(8,128)(2,1)}) custom-call(bf16[1000,1000]{1,0:T(8,128)(2,1)} %reshape.92), custom_call_target="QrDecomposition".
         [[{{node XRTCompile}}]]
  (1) Unimplemented: CustomCall for 'QrDecomposition' is not implemented for TPU.

Error encountered while compiling %custom-call = (bf16[1000,1000]{1,0:T(8,128)(2,1)}, bf16[1000,1000]{1,0:T(8,128)(2,1)}) custom-call(bf16[1000,1000]{1,0:T(8,128)(2,1)} %reshape.92), custom_call_target="QrDecomposition".
         [[{{node XRTCompile}}]]
         [[XRTCompile_G3]]
0 successful operations.
0 derived errors ignored.

Thank you in advance!

JackCaoG commented 3 years ago

It works for me:

>>> n=1000
>>> dev = xm.xla_device(n=1, devkind='TPU')
>>> A = torch.randn(n,n, dtype=torch.bfloat16, device=dev)
>>> norm_sum = torch.ones(1,1, dtype=torch.bfloat16, device=dev)
>>> q, r = torch.qr(A)
>>> q
tensor([[-0.0078,  0.0674,  0.0566,  ...,  0.0352,  0.0025, -0.0065],
        [-0.0012, -0.0311, -0.0031,  ..., -0.0039,  0.0461, -0.0339],
        [-0.0045, -0.0181, -0.0466,  ..., -0.0549, -0.0286, -0.0115],
        ...,
        [ 0.0181, -0.0192, -0.0049,  ..., -0.0214,  0.0177, -0.0014],
        [ 0.0317,  0.0476,  0.0179,  ..., -0.0194, -0.0486, -0.0332],
        [ 0.0669,  0.0133,  0.0206,  ..., -0.0065, -0.0026, -0.0264]],
       device='xla:1', dtype=torch.bfloat16)
>>> r
tensor([[-31.0000,  -2.4531,  -1.0703,  ...,   0.4824,   0.3086,   1.1250],
        [  0.0000,  30.6250,   0.6328,  ...,   0.1357,  -0.2334,   0.1465],
        [  0.0000,   0.0000, -32.5000,  ...,   1.8359,   0.6094,  -0.2637],
        ...,
        [  0.0000,   0.0000,   0.0000,  ...,   1.9531,   1.7578,   0.7617],
        [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   3.0469,   0.2314],
        [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,  -0.6172]],
       device='xla:1', dtype=torch.bfloat16)

This looks like a TPU runtime version issue. The best way to fix it is to restart your TPU node so that it picks up the latest nightly runtime.
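
For the speed part of the question: because of lazy execution, timing on TPU needs an explicit sync, and the first run also includes compilation time. A rough sketch of how you could time it after the restart (not an official recipe, just the usual pattern):

import time
import torch
import torch_xla.core.xla_model as xm

dev = xm.xla_device()
A = torch.randn(1000, 1000, dtype=torch.bfloat16, device=dev)

start = time.time()
q, r = torch.qr(A)
xm.mark_step()               # cut the pending graph and launch execution
norm = torch.norm(q).cpu()   # pulling the value to the host waits for the TPU
print(norm, "elapsed:", time.time() - start)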

guangli-dai commented 3 years ago

Thank you! It works fine now after updating the runtime version.