pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Einsum does not get dispatched to `xla::einsum` #4023

Closed JackCaoG closed 2 years ago

JackCaoG commented 2 years ago

🐛 Bug

With https://github.com/pytorch/xla/commit/aa73509890fb10b7f3bc2b9e8598f96382f2c979, torch.einsum should be dispatched to xla::einsum, but it is not.

To Reproduce

Using the nightly build and this script:

import torch
import torch_xla

res2 = torch.einsum('ii->i', torch.randn(4, 4, device='xla:0'))
print(torch_xla._XLAC._get_xla_tensors_text([res2]))

import torch_xla.debug.metrics as met
print(met.metrics_report())

I saw

IR {
  %0 = s64[] xla::device_data(), location=<module>@test_einsum.py:4, device=CPU:0
  %1 = s64[] prim::Constant(), location=<module>@test_einsum.py:4, value=214013
  %2 = s64[] aten::mul(%1, %0), location=<module>@test_einsum.py:4
  %3 = s64[] prim::Constant(), location=<module>@test_einsum.py:4, value=2531011
  %4 = s64[] aten::add(%3, %2), location=<module>@test_einsum.py:4
  %5 = f32[] prim::Constant(), location=<module>@test_einsum.py:4, value=1
  %6 = f32[4,4]{1,0} aten::expand(%5), location=<module>@test_einsum.py:4, size=(4, 4)
  %7 = f32[] prim::Constant(), location=<module>@test_einsum.py:4, value=0
  %8 = f32[4,4]{1,0} aten::expand(%7), location=<module>@test_einsum.py:4, size=(4, 4)
  %9 = f32[4,4]{1,0} aten::normal(%8, %6, %4), location=<module>@test_einsum.py:4
  %10 = f32[4]{0} aten::diagonal(%9), location=<module>@test_einsum.py:5, offset=0, dim1=0, dim2=1
  %11 = f32[4]{0} aten::permute(%10), location=<module>@test_einsum.py:5, dims=(0)
  %12 = f32[4]{0} aten::permute(%11), location=<module>@test_einsum.py:5, dims=(0), ROOT=0
}

...

Counter: xla::diagonal
  Value: 1
Counter: xla::einsum
  Value: 1
Counter: xla::empty_symint
  Value: 1
Counter: xla::normal_
  Value: 1
Counter: xla::permute
  Value: 2

It seems like einsum incorrectly falls back at https://github.com/pytorch/xla/blob/master/torch_xla/csrc/aten_xla_type.cpp#L1065
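For reference, 'ii->i' involves no contraction at all: the repeated index i on a single operand selects the main diagonal, which is why the IR above contains aten::diagonal (plus two permutes) instead of a single aten::einsum node. It is equivalent to torch.diagonal(m) for a square 2-D input. A minimal pure-Python sketch of the expected semantics, useful as a CPU reference when checking the XLA result; the helper name einsum_ii_to_i is hypothetical and not a torch or torch_xla API:

```python
def einsum_ii_to_i(m):
    """Reference semantics of einsum('ii->i', m): out[i] = m[i][i].

    The index i appears twice in the single operand and once in the output,
    so nothing is summed; the result is the main diagonal of a square matrix.
    """
    n = len(m)
    if any(len(row) != n for row in m):
        raise ValueError("'ii->i' requires a square matrix")
    return [m[i][i] for i in range(n)]


m = [[1, 2, 3, 4],
     [5, 6, 7, 8],
     [9, 10, 11, 12],
     [13, 14, 15, 16]]
print(einsum_ii_to_i(m))  # [1, 6, 11, 16]
```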

JackCaoG commented 2 years ago

@steventk-g @ronghanghu FYI

JackCaoG commented 2 years ago

If I try

x = torch.randn(5, device='xla:0')
y = torch.randn(4, device='xla:0')
res3 = torch.einsum('i,j->ij', x, y)
print(torch_xla._XLAC._get_xla_tensors_text([res3]))
print(torch_xla._XLAC._get_xla_tensors_hlo([res3]))

I do see aten::einsum in

IR {
  %0 = s64[] xla::device_data(), location=<module>@test_einsum.py:10, device=CPU:0
  %1 = s64[] prim::Constant(), location=<module>@test_einsum.py:10, value=214013
  %2 = s64[] aten::mul(%1, %0), location=<module>@test_einsum.py:10
  %3 = s64[] prim::Constant(), location=<module>@test_einsum.py:10, value=2531011
  %4 = s64[] aten::add(%3, %2), location=<module>@test_einsum.py:10
  %5 = s64[] prim::Constant(), location=<module>@test_einsum.py:11, value=214013
  %6 = s64[] aten::mul(%5, %4), location=<module>@test_einsum.py:11
  %7 = s64[] prim::Constant(), location=<module>@test_einsum.py:11, value=2531011
  %8 = s64[] aten::add(%7, %6), location=<module>@test_einsum.py:11
  %9 = f32[] prim::Constant(), location=<module>@test_einsum.py:11, value=1
  %10 = f32[4]{0} aten::expand(%9), location=<module>@test_einsum.py:11, size=(4)
  %11 = f32[] prim::Constant(), location=<module>@test_einsum.py:11, value=0
  %12 = f32[4]{0} aten::expand(%11), location=<module>@test_einsum.py:11, size=(4)
  %13 = f32[4]{0} aten::normal(%12, %10, %8), location=<module>@test_einsum.py:11
  %14 = f32[] prim::Constant(), location=<module>@test_einsum.py:10, value=1
  %15 = f32[5]{0} aten::expand(%14), location=<module>@test_einsum.py:10, size=(5)
  %16 = f32[] prim::Constant(), location=<module>@test_einsum.py:10, value=0
  %17 = f32[5]{0} aten::expand(%16), location=<module>@test_einsum.py:10, size=(5)
  %18 = f32[5]{0} aten::normal(%17, %15, %4), location=<module>@test_einsum.py:10
  %19 = f32[5,4]{1,0} aten::einsum(%18, %13), location=einsum@functional.py:364, equation=(i,j->ij), ROOT=0
}

but the HLO still does not have xla::einsum

HloModule IrToHlo.120, entry_computation_layout={(s64[])->(f32[5,4]{1,0})}

ENTRY %IrToHlo.120 (p0.1: s64[]) -> (f32[5,4]) {
  %constant.8 = s64[] constant(2531011), metadata={op_type="prim__Constant" op_name="prim__Constant" source_file="<module>@test_einsum.py" source_line=11}
  %constant.6 = s64[] constant(214013), metadata={op_type="prim__Constant" op_name="prim__Constant" source_file="<module>@test_einsum.py" source_line=11}
  %constant.4 = s64[] constant(2531011), metadata={op_type="prim__Constant" op_name="prim__Constant" source_file="<module>@test_einsum.py" source_line=10}
  %constant.2 = s64[] constant(214013), metadata={op_type="prim__Constant" op_name="prim__Constant" source_file="<module>@test_einsum.py" source_line=10}
  %p0.1 = s64[] parameter(0), metadata={op_type="xla__device_data" op_name="xla__device_data" source_file="<module>@test_einsum.py" source_line=10}
  %multiply.3 = s64[] multiply(s64[] %constant.2, s64[] %p0.1), metadata={op_type="aten__mul" op_name="aten__mul" source_file="<module>@test_einsum.py" source_line=10}
  %add.5 = s64[] add(s64[] %constant.4, s64[] %multiply.3), metadata={op_type="aten__add" op_name="aten__add" source_file="<module>@test_einsum.py" source_line=10}
  %multiply.7 = s64[] multiply(s64[] %constant.6, s64[] %add.5), metadata={op_type="aten__mul" op_name="aten__mul" source_file="<module>@test_einsum.py" source_line=11}
  %add.9 = s64[] add(s64[] %constant.8, s64[] %multiply.7), metadata={op_type="aten__add" op_name="aten__add" source_file="<module>@test_einsum.py" source_line=11}
  %convert.20 = u64[] convert(s64[] %add.9), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
  %reshape.24 = u64[1]{0} reshape(u64[] %convert.20), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
  %constant.21 = u64[] constant(0), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
  %reshape.25 = u64[1]{0} reshape(u64[] %constant.21), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
  %concatenate.26 = u64[2]{0} concatenate(u64[1]{0} %reshape.24, u64[1]{0} %reshape.25), dimensions={0}, metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
  %rng-bit-generator.27 = (u64[2]{0}, u32[2,2]{1,0}) rng-bit-generator(u64[2]{0} %concatenate.26), algorithm=rng_default, metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
  %get-tuple-element.29 = u64[2]{0} get-tuple-element((u64[2]{0}, u32[2,2]{1,0}) %rng-bit-generator.27), index=0, metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
  %convert.73 = u64[] convert(s64[] %add.5), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
  %reshape.77 = u64[1]{0} reshape(u64[] %convert.73), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
  %constant.74 = u64[] constant(0), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
  %reshape.78 = u64[1]{0} reshape(u64[] %constant.74), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
  %concatenate.79 = u64[2]{0} concatenate(u64[1]{0} %reshape.77, u64[1]{0} %reshape.78), dimensions={0}, metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
  %rng-bit-generator.80 = (u64[2]{0}, u32[3,2]{1,0}) rng-bit-generator(u64[2]{0} %concatenate.79), algorithm=rng_default, metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
  %get-tuple-element.82 = u64[2]{0} get-tuple-element((u64[2]{0}, u32[3,2]{1,0}) %rng-bit-generator.80), index=0, metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
  %constant.68 = f32[] constant(0), metadata={op_type="prim__Constant" op_name="prim__Constant" source_file="<module>@test_einsum.py" source_line=10}
  %reshape.69 = f32[1]{0} reshape(f32[] %constant.68), metadata={op_type="aten__expand" op_name="aten__expand" source_file="<module>@test_einsum.py" source_line=10}
  %broadcast.70 = f32[1]{0} broadcast(f32[1]{0} %reshape.69), dimensions={0}, metadata={op_type="aten__expand" op_name="aten__expand" source_file="<module>@test_einsum.py" source_line=10}
  %reshape.71 = f32[] reshape(f32[1]{0} %broadcast.70), metadata={op_type="aten__expand" op_name="aten__expand" source_file="<module>@test_einsum.py" source_line=10}
  %broadcast.72 = f32[5]{0} broadcast(f32[] %reshape.71), dimensions={}, metadata={op_type="aten__expand" op_name="aten__expand" source_file="<module>@test_einsum.py" source_line=10}
  %constant.100 = f32[] constant(6.28318548), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
  %broadcast.101 = f32[3,1]{1,0} broadcast(f32[] %constant.100), dimensions={}, metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
  %get-tuple-element.81 = u32[3,2]{1,0} get-tuple-element((u64[2]{0}, u32[3,2]{1,0}) %rng-bit-generator.80), index=1, metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
  %constant.83 = u32[] constant(9), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
  %broadcast.84 = u32[3,2]{1,0} broadcast(u32[] %constant.83), dimensions={}, metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
  %shift-right-logical.85 = u32[3,2]{1,0} shift-right-logical(u32[3,2]{1,0} %get-tuple-element.81, u32[3,2]{1,0} %broadcast.84), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
  %convert.86 = f32[3,2]{1,0} convert(u32[3,2]{1,0} %shift-right-logical.85), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
  %constant.87 = f32[] constant(1.1920929e-07), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
  %broadcast.88 = f32[3,2]{1,0} broadcast(f32[] %constant.87), dimensions={}, metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
  %multiply.89 = f32[3,2]{1,0} multiply(f32[3,2]{1,0} %convert.86, f32[3,2]{1,0} %broadcast.88), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
  %constant.75 = f32[] constant(1), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
  %constant.76 = f32[] constant(0), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
  %subtract.90 = f32[] subtract(f32[] %constant.75, f32[] %constant.76), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
  %broadcast.91 = f32[3,2]{1,0} broadcast(f32[] %subtract.90), dimensions={}, metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
  %multiply.92 = f32[3,2]{1,0} multiply(f32[3,2]{1,0} %multiply.89, f32[3,2]{1,0} %broadcast.91), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
  %broadcast.93 = f32[3,2]{1,0} broadcast(f32[] %constant.76), dimensions={}, metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
  %add.94 = f32[3,2]{1,0} add(f32[3,2]{1,0} %multiply.92, f32[3,2]{1,0} %broadcast.93), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
  %slice.96 = f32[3,1]{1,0} slice(f32[3,2]{1,0} %add.94), slice={[0:3], [1:2]}, metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
  %multiply.102 = f32[3,1]{1,0} multiply(f32[3,1]{1,0} %broadcast.101, f32[3,1]{1,0} %slice.96), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
  %sine.108 = f32[3,1]{1,0} sine(f32[3,1]{1,0} %multiply.102), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
  %constant.104 = f32[] constant(-2), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
  %broadcast.105 = f32[3,1]{1,0} broadcast(f32[] %constant.104), dimensions={}, metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
  %slice.95 = f32[3,1]{1,0} slice(f32[3,2]{1,0} %add.94), slice={[0:3], [0:1]}, metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
  %constant.97 = f32[] constant(1e-07), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
  %broadcast.98 = f32[3,1]{1,0} broadcast(f32[] %constant.97), dimensions={}, metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
  %maximum.99 = f32[3,1]{1,0} maximum(f32[3,1]{1,0} %slice.95, f32[3,1]{1,0} %broadcast.98), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
  %log.103 = f32[3,1]{1,0} log(f32[3,1]{1,0} %maximum.99), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
  %multiply.106 = f32[3,1]{1,0} multiply(f32[3,1]{1,0} %broadcast.105, f32[3,1]{1,0} %log.103), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
  %sqrt.107 = f32[3,1]{1,0} sqrt(f32[3,1]{1,0} %multiply.106), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
  %multiply.109 = f32[3,1]{1,0} multiply(f32[3,1]{1,0} %sine.108, f32[3,1]{1,0} %sqrt.107), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
  %cosine.110 = f32[3,1]{1,0} cosine(f32[3,1]{1,0} %multiply.102), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
  %multiply.111 = f32[3,1]{1,0} multiply(f32[3,1]{1,0} %cosine.110, f32[3,1]{1,0} %sqrt.107), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
  %concatenate.112 = f32[3,2]{1,0} concatenate(f32[3,1]{1,0} %multiply.109, f32[3,1]{1,0} %multiply.111), dimensions={1}, metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
  %reshape.113 = f32[6]{0} reshape(f32[3,2]{1,0} %concatenate.112), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
  %slice.114 = f32[5]{0} slice(f32[6]{0} %reshape.113), slice={[0:5]}, metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
  %constant.63 = f32[] constant(1), metadata={op_type="prim__Constant" op_name="prim__Constant" source_file="<module>@test_einsum.py" source_line=10}
  %reshape.64 = f32[1]{0} reshape(f32[] %constant.63), metadata={op_type="aten__expand" op_name="aten__expand" source_file="<module>@test_einsum.py" source_line=10}
  %broadcast.65 = f32[1]{0} broadcast(f32[1]{0} %reshape.64), dimensions={0}, metadata={op_type="aten__expand" op_name="aten__expand" source_file="<module>@test_einsum.py" source_line=10}
  %reshape.66 = f32[] reshape(f32[1]{0} %broadcast.65), metadata={op_type="aten__expand" op_name="aten__expand" source_file="<module>@test_einsum.py" source_line=10}
  %broadcast.67 = f32[5]{0} broadcast(f32[] %reshape.66), dimensions={}, metadata={op_type="aten__expand" op_name="aten__expand" source_file="<module>@test_einsum.py" source_line=10}
  %multiply.115 = f32[5]{0} multiply(f32[5]{0} %slice.114, f32[5]{0} %broadcast.67), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
  %add.116 = f32[5]{0} add(f32[5]{0} %broadcast.72, f32[5]{0} %multiply.115), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
  %constant.15 = f32[] constant(0), metadata={op_type="prim__Constant" op_name="prim__Constant" source_file="<module>@test_einsum.py" source_line=11}
  %reshape.16 = f32[1]{0} reshape(f32[] %constant.15), metadata={op_type="aten__expand" op_name="aten__expand" source_file="<module>@test_einsum.py" source_line=11}
  %broadcast.17 = f32[1]{0} broadcast(f32[1]{0} %reshape.16), dimensions={0}, metadata={op_type="aten__expand" op_name="aten__expand" source_file="<module>@test_einsum.py" source_line=11}
  %reshape.18 = f32[] reshape(f32[1]{0} %broadcast.17), metadata={op_type="aten__expand" op_name="aten__expand" source_file="<module>@test_einsum.py" source_line=11}
  %broadcast.19 = f32[4]{0} broadcast(f32[] %reshape.18), dimensions={}, metadata={op_type="aten__expand" op_name="aten__expand" source_file="<module>@test_einsum.py" source_line=11}
  %constant.47 = f32[] constant(6.28318548), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
  %broadcast.48 = f32[2,1]{1,0} broadcast(f32[] %constant.47), dimensions={}, metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
  %get-tuple-element.28 = u32[2,2]{1,0} get-tuple-element((u64[2]{0}, u32[2,2]{1,0}) %rng-bit-generator.27), index=1, metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
  %constant.30 = u32[] constant(9), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
  %broadcast.31 = u32[2,2]{1,0} broadcast(u32[] %constant.30), dimensions={}, metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
  %shift-right-logical.32 = u32[2,2]{1,0} shift-right-logical(u32[2,2]{1,0} %get-tuple-element.28, u32[2,2]{1,0} %broadcast.31), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
  %convert.33 = f32[2,2]{1,0} convert(u32[2,2]{1,0} %shift-right-logical.32), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
  %constant.34 = f32[] constant(1.1920929e-07), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
  %broadcast.35 = f32[2,2]{1,0} broadcast(f32[] %constant.34), dimensions={}, metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
  %multiply.36 = f32[2,2]{1,0} multiply(f32[2,2]{1,0} %convert.33, f32[2,2]{1,0} %broadcast.35), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
  %constant.22 = f32[] constant(1), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
  %constant.23 = f32[] constant(0), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
  %subtract.37 = f32[] subtract(f32[] %constant.22, f32[] %constant.23), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
  %broadcast.38 = f32[2,2]{1,0} broadcast(f32[] %subtract.37), dimensions={}, metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
  %multiply.39 = f32[2,2]{1,0} multiply(f32[2,2]{1,0} %multiply.36, f32[2,2]{1,0} %broadcast.38), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
  %broadcast.40 = f32[2,2]{1,0} broadcast(f32[] %constant.23), dimensions={}, metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
  %add.41 = f32[2,2]{1,0} add(f32[2,2]{1,0} %multiply.39, f32[2,2]{1,0} %broadcast.40), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
  %slice.43 = f32[2,1]{1,0} slice(f32[2,2]{1,0} %add.41), slice={[0:2], [1:2]}, metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
  %multiply.49 = f32[2,1]{1,0} multiply(f32[2,1]{1,0} %broadcast.48, f32[2,1]{1,0} %slice.43), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
  %sine.55 = f32[2,1]{1,0} sine(f32[2,1]{1,0} %multiply.49), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
  %constant.51 = f32[] constant(-2), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
  %broadcast.52 = f32[2,1]{1,0} broadcast(f32[] %constant.51), dimensions={}, metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
  %slice.42 = f32[2,1]{1,0} slice(f32[2,2]{1,0} %add.41), slice={[0:2], [0:1]}, metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
  %constant.44 = f32[] constant(1e-07), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
  %broadcast.45 = f32[2,1]{1,0} broadcast(f32[] %constant.44), dimensions={}, metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
  %maximum.46 = f32[2,1]{1,0} maximum(f32[2,1]{1,0} %slice.42, f32[2,1]{1,0} %broadcast.45), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
  %log.50 = f32[2,1]{1,0} log(f32[2,1]{1,0} %maximum.46), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
  %multiply.53 = f32[2,1]{1,0} multiply(f32[2,1]{1,0} %broadcast.52, f32[2,1]{1,0} %log.50), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
  %sqrt.54 = f32[2,1]{1,0} sqrt(f32[2,1]{1,0} %multiply.53), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
  %multiply.56 = f32[2,1]{1,0} multiply(f32[2,1]{1,0} %sine.55, f32[2,1]{1,0} %sqrt.54), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
  %cosine.57 = f32[2,1]{1,0} cosine(f32[2,1]{1,0} %multiply.49), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
  %multiply.58 = f32[2,1]{1,0} multiply(f32[2,1]{1,0} %cosine.57, f32[2,1]{1,0} %sqrt.54), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
  %concatenate.59 = f32[2,2]{1,0} concatenate(f32[2,1]{1,0} %multiply.56, f32[2,1]{1,0} %multiply.58), dimensions={1}, metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
  %reshape.60 = f32[4]{0} reshape(f32[2,2]{1,0} %concatenate.59), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
  %constant.10 = f32[] constant(1), metadata={op_type="prim__Constant" op_name="prim__Constant" source_file="<module>@test_einsum.py" source_line=11}
  %reshape.11 = f32[1]{0} reshape(f32[] %constant.10), metadata={op_type="aten__expand" op_name="aten__expand" source_file="<module>@test_einsum.py" source_line=11}
  %broadcast.12 = f32[1]{0} broadcast(f32[1]{0} %reshape.11), dimensions={0}, metadata={op_type="aten__expand" op_name="aten__expand" source_file="<module>@test_einsum.py" source_line=11}
  %reshape.13 = f32[] reshape(f32[1]{0} %broadcast.12), metadata={op_type="aten__expand" op_name="aten__expand" source_file="<module>@test_einsum.py" source_line=11}
  %broadcast.14 = f32[4]{0} broadcast(f32[] %reshape.13), dimensions={}, metadata={op_type="aten__expand" op_name="aten__expand" source_file="<module>@test_einsum.py" source_line=11}
  %multiply.61 = f32[4]{0} multiply(f32[4]{0} %reshape.60, f32[4]{0} %broadcast.14), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
  %add.62 = f32[4]{0} add(f32[4]{0} %broadcast.19, f32[4]{0} %multiply.61), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
  %dot.117 = f32[5,4]{1,0} dot(f32[5]{0} %add.116, f32[4]{0} %add.62), lhs_contracting_dims={}, rhs_contracting_dims={}, metadata={op_type="aten__einsum" op_name="aten__einsum" source_file="einsum@functional.py" source_line=364}
  %transpose.118 = f32[5,4]{1,0} transpose(f32[5,4]{1,0} %dot.117), dimensions={0,1}, metadata={op_type="aten__einsum" op_name="aten__einsum" source_file="einsum@functional.py" source_line=364}
  ROOT %tuple.119 = (f32[5,4]{1,0}) tuple(f32[5,4]{1,0} %transpose.118)
}

JackCaoG commented 2 years ago

Cleaning up the test a bit:

x = torch.tensor([1,2,3,4,5], device='xla:0')
y = torch.tensor([2,3,4,5], device='xla:0')

res3 = torch.einsum('i,j->ij', x, y)
print(torch_xla._XLAC._get_xla_tensors_text([res3]))
print(torch_xla._XLAC._get_xla_tensors_hlo([res3]))

I saw

IR {
  %0 = s64[4]{0} xla::device_data(), location=<module>@test_einsum.py:13, device=CPU:0
  %1 = s64[5]{0} xla::device_data(), location=<module>@test_einsum.py:12, device=CPU:0
  %2 = s64[5,4]{1,0} aten::einsum(%1, %0), location=einsum@functional.py:364, equation=(i,j->ij), ROOT=0
}

and

HloModule IrToHlo.6, entry_computation_layout={(s64[4]{0},s64[5]{0})->(s64[5,4]{1,0})}

ENTRY %IrToHlo.6 (p0.1: s64[4], p1.2: s64[5]) -> (s64[5,4]) {
  %p1.2 = s64[5]{0} parameter(1), metadata={op_type="xla__device_data" op_name="xla__device_data" source_file="<module>@test_einsum.py" source_line=12}
  %p0.1 = s64[4]{0} parameter(0), metadata={op_type="xla__device_data" op_name="xla__device_data" source_file="<module>@test_einsum.py" source_line=13}
  %dot.3 = s64[5,4]{1,0} dot(s64[5]{0} %p1.2, s64[4]{0} %p0.1), lhs_contracting_dims={}, rhs_contracting_dims={}, metadata={op_type="aten__einsum" op_name="aten__einsum" source_file="einsum@functional.py" source_line=364}
  %transpose.4 = s64[5,4]{1,0} transpose(s64[5,4]{1,0} %dot.3), dimensions={0,1}, metadata={op_type="aten__einsum" op_name="aten__einsum" source_file="einsum@functional.py" source_line=364}
  ROOT %tuple.5 = (s64[5,4]{1,0}) tuple(s64[5,4]{1,0} %transpose.4)
}

and I can confirm that xla::einsum is being called. I mistakenly thought einsum was an XLA op, but it is actually a helper function that calls DotGeneral, which is optimal on XLA.
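Concretely, 'i,j->ij' is an outer product, which is why the HLO above is a single dot with lhs_contracting_dims={} and rhs_contracting_dims={}: with nothing to contract and no batch dimensions, each output element is just x[i] * y[j]. The trailing transpose with dimensions={0,1} is an identity permutation. A minimal pure-Python sketch of that semantics, using the same x and y values as the test (the function name outer is illustrative only):

```python
def outer(x, y):
    """einsum('i,j->ij', x, y): a dot_general with no contracting and no batch dims.

    The output shape is (len(x), len(y)) and out[i][j] = x[i] * y[j].
    """
    return [[xi * yj for yj in y] for xi in x]


x = [1, 2, 3, 4, 5]
y = [2, 3, 4, 5]
res3 = outer(x, y)
print(len(res3), len(res3[0]))  # 5 4
print(res3[0])   # [2, 3, 4, 5]
print(res3[-1])  # [10, 15, 20, 25]
```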

JackCaoG commented 2 years ago

I tried another one

As = torch.randn(3,2,5)
Bs = torch.randn(3,5,4)

As = As.to('xla:0')
Bs = Bs.to('xla:0')
res4 = torch.einsum('bij,bjk->bik', As, Bs)

print(torch_xla._XLAC._get_xla_tensors_text([res4]))
print(torch_xla._XLAC._get_xla_tensors_hlo([res4]))

and saw

IR {
  %0 = f32[3,5,4]{2,1,0} xla::device_data(), location=<module>@test_einsum.py:23, device=CPU:0
  %1 = f32[3,2,5]{2,1,0} xla::device_data(), location=<module>@test_einsum.py:22, device=CPU:0
  %2 = f32[3,2,4]{2,1,0} aten::einsum(%1, %0), location=einsum@functional.py:364, equation=(bij,bjk->bik), ROOT=0
}

and

HloModule IrToHlo.6, entry_computation_layout={(f32[3,5,4]{2,1,0},f32[3,2,5]{2,1,0})->(f32[3,2,4]{2,1,0})}

ENTRY %IrToHlo.6 (p0.1: f32[3,5,4], p1.2: f32[3,2,5]) -> (f32[3,2,4]) {
  %p1.2 = f32[3,2,5]{2,1,0} parameter(1), metadata={op_type="xla__device_data" op_name="xla__device_data" source_file="<module>@test_einsum.py" source_line=22}
  %p0.1 = f32[3,5,4]{2,1,0} parameter(0), metadata={op_type="xla__device_data" op_name="xla__device_data" source_file="<module>@test_einsum.py" source_line=23}
  %dot.3 = f32[3,2,4]{2,1,0} dot(f32[3,2,5]{2,1,0} %p1.2, f32[3,5,4]{2,1,0} %p0.1), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}, metadata={op_type="aten__einsum" op_name="aten__einsum" source_file="einsum@functional.py" source_line=364}
  %transpose.4 = f32[3,2,4]{2,1,0} transpose(f32[3,2,4]{2,1,0} %dot.3), dimensions={0,1,2}, metadata={op_type="aten__einsum" op_name="aten__einsum" source_file="einsum@functional.py" source_line=364}
  ROOT %tuple.5 = (f32[3,2,4]{2,1,0}) tuple(f32[3,2,4]{2,1,0} %transpose.4)
}

This looks right too, so I guess the only problematic case is 'ii->i'.
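The dimension numbers in that dot map directly onto the equation 'bij,bjk->bik': b is the batch dimension (lhs_batch_dims={0}, rhs_batch_dims={0}) and j is summed over (lhs_contracting_dims={2}, rhs_contracting_dims={1}). A minimal pure-Python sketch of the same computation makes that mapping explicit; the function name bmm and the all-ones/all-twos inputs are illustrative, chosen only to match the [3,2,5] and [3,5,4] shapes from the test:

```python
def bmm(A, B):
    """einsum('bij,bjk->bik', A, B) with A indexed [b][i][j] and B indexed [b][j][k].

    Dim 0 of both operands is the batch dim; A's dim 2 and B's dim 1 (both
    labelled j) are contracted, leaving an output indexed [b][i][k].
    """
    if len(A) != len(B):
        raise ValueError("batch sizes differ")
    out = []
    for a_mat, b_mat in zip(A, B):  # walk the shared batch dimension
        n_j = len(b_mat)
        if any(len(row) != n_j for row in a_mat):
            raise ValueError("contracting dims differ")
        n_k = len(b_mat[0])
        out.append([
            [sum(a_mat[i][j] * b_mat[j][k] for j in range(n_j)) for k in range(n_k)]
            for i in range(len(a_mat))
        ])
    return out


A = [[[1] * 5 for _ in range(2)] for _ in range(3)]  # shape [3, 2, 5]
B = [[[2] * 4 for _ in range(5)] for _ in range(3)]  # shape [3, 5, 4]
res4 = bmm(A, B)
print(len(res4), len(res4[0]), len(res4[0][0]))  # 3 2 4
print(res4[0][0])  # [10, 10, 10, 10]
```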

steventk-g commented 2 years ago

We should enforce in our unit tests that xla::einsum is called when appropriate, e.g. for ii->i: https://github.com/pytorch/xla/blob/master/test/cpp/test_aten_xla_tensor.cpp#L3888

Is ExpectCounterChanged("xla::einsum", cpp_test::GetIgnoredCounters()); sufficient to guarantee that the dispatch is correct? Or is there a more accurate check I can add to these unit tests?

JackCaoG commented 2 years ago

ExpectCounterChanged("xla::einsum", cpp_test::GetIgnoredCounters()); won't be enough: even if we fall back at https://github.com/pytorch/xla/blob/master/torch_xla/csrc/aten_xla_type.cpp#L1065, we still increment the einsum counter, because the increment happens earlier, at https://github.com/pytorch/xla/blob/master/torch_xla/csrc/aten_xla_type.cpp#L1060

What you want to do is add a counter at https://github.com/pytorch/xla/blob/master/torch_xla/csrc/aten_xla_type.cpp#L1065 using

XLA_COUNTER("EinsumFallback", 1);

and make sure the fallback counter is not incremented in the test.
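On the Python side, the same before/after check can be sketched by diffing two metrics_report() strings. This is a minimal sketch that assumes only the "Counter: NAME" / "  Value: N" layout visible in the report excerpt earlier in this issue; parse_counters and counter_delta are hypothetical helpers, and "EinsumFallback" only appears once the XLA_COUNTER call proposed above exists. Where available, torch_xla.debug.metrics.counter_value(name) is a more direct way to read a single counter.

```python
def parse_counters(report):
    """Map counter name -> value from a metrics_report()-style string.

    Assumes each counter is a 'Counter: NAME' line immediately followed by a
    'Value: N' line (leading spaces ignored); all other lines are skipped.
    """
    counters = {}
    name = None
    for line in report.splitlines():
        stripped = line.strip()
        if stripped.startswith("Counter: "):
            name = stripped[len("Counter: "):]
        elif name is not None and stripped.startswith("Value: "):
            counters[name] = int(stripped[len("Value: "):])
            name = None
        else:
            name = None
    return counters


def counter_delta(before, after, name):
    """How much counter `name` grew between two reports (missing counts as 0)."""
    return parse_counters(after).get(name, 0) - parse_counters(before).get(name, 0)


# Hypothetical reports taken before and after running einsum('ii->i', ...).
before = "Counter: xla::einsum\n  Value: 1\n"
after = ("Counter: EinsumFallback\n  Value: 1\n"
         "Counter: xla::einsum\n  Value: 2\n")

# The einsum counter grows whether or not the op fell back...
print(counter_delta(before, after, "xla::einsum"))     # 1
# ...so a passing test must also see zero growth in the fallback counter.
print(counter_delta(before, after, "EinsumFallback"))  # 1 -> this run fell back
```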

steventk-g commented 2 years ago

Got it, I'll add that check and investigate this issue. Thanks for the context!

steventk-g commented 2 years ago

Closing this after https://github.com/pytorch/xla/pull/4027

Remaining einsum edge cases are tracked in https://github.com/pytorch/xla/issues/4032