guangli-dai closed this issue 3 years ago
It is working for me:
>>> import torch
>>> import torch_xla.core.xla_model as xm
>>> n = 1000
>>> dev = xm.xla_device(n=1, devkind='TPU')
>>> A = torch.randn(n,n, dtype=torch.bfloat16, device=dev)
>>> norm_sum = torch.ones(1,1, dtype=torch.bfloat16, device=dev)
>>> q, r = torch.qr(A)
>>> q
tensor([[-0.0078, 0.0674, 0.0566, ..., 0.0352, 0.0025, -0.0065],
[-0.0012, -0.0311, -0.0031, ..., -0.0039, 0.0461, -0.0339],
[-0.0045, -0.0181, -0.0466, ..., -0.0549, -0.0286, -0.0115],
...,
[ 0.0181, -0.0192, -0.0049, ..., -0.0214, 0.0177, -0.0014],
[ 0.0317, 0.0476, 0.0179, ..., -0.0194, -0.0486, -0.0332],
[ 0.0669, 0.0133, 0.0206, ..., -0.0065, -0.0026, -0.0264]],
device='xla:1', dtype=torch.bfloat16)
>>> r
tensor([[-31.0000, -2.4531, -1.0703, ..., 0.4824, 0.3086, 1.1250],
[ 0.0000, 30.6250, 0.6328, ..., 0.1357, -0.2334, 0.1465],
[ 0.0000, 0.0000, -32.5000, ..., 1.8359, 0.6094, -0.2637],
...,
[ 0.0000, 0.0000, 0.0000, ..., 1.9531, 1.7578, 0.7617],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 3.0469, 0.2314],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, -0.6172]],
device='xla:1', dtype=torch.bfloat16)
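As a quick sanity check on factors like the ones above, QR output can be verified on CPU with NumPy: `q` should have orthonormal columns, `r` should be upper triangular, and their product should reconstruct the input. This is a sketch for verification only, not part of the original thread:

```python
import numpy as np

# Random square matrix, same shape as in the TPU example above.
rng = np.random.default_rng(0)
a = rng.standard_normal((1000, 1000)).astype(np.float32)

q, r = np.linalg.qr(a)

# q @ r reconstructs a (up to float32 rounding).
assert np.allclose(q @ r, a, atol=1e-3)
# Columns of q are orthonormal: q^T q = I.
assert np.allclose(q.T @ q, np.eye(1000), atol=1e-3)
# r is upper triangular: everything below the diagonal is zero.
assert np.allclose(np.tril(r, -1), 0)
```

The bfloat16 factors printed above will have much larger reconstruction error than float32, so the tolerance would need to be loosened accordingly.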
This looks like a TPU runtime version issue. The best way to fix it is to restart your TPU node so that it picks up the latest nightly runtime.
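For a Cloud TPU node, restarting typically means stopping and starting it via `gcloud`. A sketch, where `NODE_NAME` and `ZONE` are placeholders for your own node and zone:

```shell
# Stop the TPU node (NODE_NAME and ZONE are placeholders).
gcloud compute tpus stop NODE_NAME --zone=ZONE
# Start it again; it will come back with the current runtime image.
gcloud compute tpus start NODE_NAME --zone=ZONE
```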
Thank you! It works fine now after updating the runtime version.
❓ Questions and Help
I tried to run QR decomposition on a TPU, expecting a correct result and faster speed. However, the root error says this is not implemented. Am I missing something in how I'm using it, or is this a feature yet to be added?
Environment: pytorch/xla compiled from source (last updated on Nov 24th); TPU: v3-8
Code:
Error messages
A more detailed log with the error message:
Thank you in advance!