Closed JackCaoG closed 2 years ago
@steventk-g @ronghanghu FYI
If I try
x = torch.randn(5, device='xla:0')
y = torch.randn(4, device='xla:0')
res3 = torch.einsum('i,j->ij', x, y)
print(torch_xla._XLAC._get_xla_tensors_text([res3]))
print(torch_xla._XLAC._get_xla_tensors_hlo([res3]))
I do see aten::einsum
in
IR {
%0 = s64[] xla::device_data(), location=<module>@test_einsum.py:10, device=CPU:0
%1 = s64[] prim::Constant(), location=<module>@test_einsum.py:10, value=214013
%2 = s64[] aten::mul(%1, %0), location=<module>@test_einsum.py:10
%3 = s64[] prim::Constant(), location=<module>@test_einsum.py:10, value=2531011
%4 = s64[] aten::add(%3, %2), location=<module>@test_einsum.py:10
%5 = s64[] prim::Constant(), location=<module>@test_einsum.py:11, value=214013
%6 = s64[] aten::mul(%5, %4), location=<module>@test_einsum.py:11
%7 = s64[] prim::Constant(), location=<module>@test_einsum.py:11, value=2531011
%8 = s64[] aten::add(%7, %6), location=<module>@test_einsum.py:11
%9 = f32[] prim::Constant(), location=<module>@test_einsum.py:11, value=1
%10 = f32[4]{0} aten::expand(%9), location=<module>@test_einsum.py:11, size=(4)
%11 = f32[] prim::Constant(), location=<module>@test_einsum.py:11, value=0
%12 = f32[4]{0} aten::expand(%11), location=<module>@test_einsum.py:11, size=(4)
%13 = f32[4]{0} aten::normal(%12, %10, %8), location=<module>@test_einsum.py:11
%14 = f32[] prim::Constant(), location=<module>@test_einsum.py:10, value=1
%15 = f32[5]{0} aten::expand(%14), location=<module>@test_einsum.py:10, size=(5)
%16 = f32[] prim::Constant(), location=<module>@test_einsum.py:10, value=0
%17 = f32[5]{0} aten::expand(%16), location=<module>@test_einsum.py:10, size=(5)
%18 = f32[5]{0} aten::normal(%17, %15, %4), location=<module>@test_einsum.py:10
%19 = f32[5,4]{1,0} aten::einsum(%18, %13), location=einsum@functional.py:364, equation=(i,j->ij), ROOT=0
}
but the HLO still does not have xla::einsum
HloModule IrToHlo.120, entry_computation_layout={(s64[])->(f32[5,4]{1,0})}
ENTRY %IrToHlo.120 (p0.1: s64[]) -> (f32[5,4]) {
%constant.8 = s64[] constant(2531011), metadata={op_type="prim__Constant" op_name="prim__Constant" source_file="<module>@test_einsum.py" source_line=11}
%constant.6 = s64[] constant(214013), metadata={op_type="prim__Constant" op_name="prim__Constant" source_file="<module>@test_einsum.py" source_line=11}
%constant.4 = s64[] constant(2531011), metadata={op_type="prim__Constant" op_name="prim__Constant" source_file="<module>@test_einsum.py" source_line=10}
%constant.2 = s64[] constant(214013), metadata={op_type="prim__Constant" op_name="prim__Constant" source_file="<module>@test_einsum.py" source_line=10}
%p0.1 = s64[] parameter(0), metadata={op_type="xla__device_data" op_name="xla__device_data" source_file="<module>@test_einsum.py" source_line=10}
%multiply.3 = s64[] multiply(s64[] %constant.2, s64[] %p0.1), metadata={op_type="aten__mul" op_name="aten__mul" source_file="<module>@test_einsum.py" source_line=10}
%add.5 = s64[] add(s64[] %constant.4, s64[] %multiply.3), metadata={op_type="aten__add" op_name="aten__add" source_file="<module>@test_einsum.py" source_line=10}
%multiply.7 = s64[] multiply(s64[] %constant.6, s64[] %add.5), metadata={op_type="aten__mul" op_name="aten__mul" source_file="<module>@test_einsum.py" source_line=11}
%add.9 = s64[] add(s64[] %constant.8, s64[] %multiply.7), metadata={op_type="aten__add" op_name="aten__add" source_file="<module>@test_einsum.py" source_line=11}
%convert.20 = u64[] convert(s64[] %add.9), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
%reshape.24 = u64[1]{0} reshape(u64[] %convert.20), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
%constant.21 = u64[] constant(0), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
%reshape.25 = u64[1]{0} reshape(u64[] %constant.21), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
%concatenate.26 = u64[2]{0} concatenate(u64[1]{0} %reshape.24, u64[1]{0} %reshape.25), dimensions={0}, metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
%rng-bit-generator.27 = (u64[2]{0}, u32[2,2]{1,0}) rng-bit-generator(u64[2]{0} %concatenate.26), algorithm=rng_default, metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
%get-tuple-element.29 = u64[2]{0} get-tuple-element((u64[2]{0}, u32[2,2]{1,0}) %rng-bit-generator.27), index=0, metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
%convert.73 = u64[] convert(s64[] %add.5), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
%reshape.77 = u64[1]{0} reshape(u64[] %convert.73), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
%constant.74 = u64[] constant(0), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
%reshape.78 = u64[1]{0} reshape(u64[] %constant.74), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
%concatenate.79 = u64[2]{0} concatenate(u64[1]{0} %reshape.77, u64[1]{0} %reshape.78), dimensions={0}, metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
%rng-bit-generator.80 = (u64[2]{0}, u32[3,2]{1,0}) rng-bit-generator(u64[2]{0} %concatenate.79), algorithm=rng_default, metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
%get-tuple-element.82 = u64[2]{0} get-tuple-element((u64[2]{0}, u32[3,2]{1,0}) %rng-bit-generator.80), index=0, metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
%constant.68 = f32[] constant(0), metadata={op_type="prim__Constant" op_name="prim__Constant" source_file="<module>@test_einsum.py" source_line=10}
%reshape.69 = f32[1]{0} reshape(f32[] %constant.68), metadata={op_type="aten__expand" op_name="aten__expand" source_file="<module>@test_einsum.py" source_line=10}
%broadcast.70 = f32[1]{0} broadcast(f32[1]{0} %reshape.69), dimensions={0}, metadata={op_type="aten__expand" op_name="aten__expand" source_file="<module>@test_einsum.py" source_line=10}
%reshape.71 = f32[] reshape(f32[1]{0} %broadcast.70), metadata={op_type="aten__expand" op_name="aten__expand" source_file="<module>@test_einsum.py" source_line=10}
%broadcast.72 = f32[5]{0} broadcast(f32[] %reshape.71), dimensions={}, metadata={op_type="aten__expand" op_name="aten__expand" source_file="<module>@test_einsum.py" source_line=10}
%constant.100 = f32[] constant(6.28318548), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
%broadcast.101 = f32[3,1]{1,0} broadcast(f32[] %constant.100), dimensions={}, metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
%get-tuple-element.81 = u32[3,2]{1,0} get-tuple-element((u64[2]{0}, u32[3,2]{1,0}) %rng-bit-generator.80), index=1, metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
%constant.83 = u32[] constant(9), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
%broadcast.84 = u32[3,2]{1,0} broadcast(u32[] %constant.83), dimensions={}, metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
%shift-right-logical.85 = u32[3,2]{1,0} shift-right-logical(u32[3,2]{1,0} %get-tuple-element.81, u32[3,2]{1,0} %broadcast.84), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
%convert.86 = f32[3,2]{1,0} convert(u32[3,2]{1,0} %shift-right-logical.85), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
%constant.87 = f32[] constant(1.1920929e-07), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
%broadcast.88 = f32[3,2]{1,0} broadcast(f32[] %constant.87), dimensions={}, metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
%multiply.89 = f32[3,2]{1,0} multiply(f32[3,2]{1,0} %convert.86, f32[3,2]{1,0} %broadcast.88), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
%constant.75 = f32[] constant(1), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
%constant.76 = f32[] constant(0), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
%subtract.90 = f32[] subtract(f32[] %constant.75, f32[] %constant.76), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
%broadcast.91 = f32[3,2]{1,0} broadcast(f32[] %subtract.90), dimensions={}, metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
%multiply.92 = f32[3,2]{1,0} multiply(f32[3,2]{1,0} %multiply.89, f32[3,2]{1,0} %broadcast.91), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
%broadcast.93 = f32[3,2]{1,0} broadcast(f32[] %constant.76), dimensions={}, metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
%add.94 = f32[3,2]{1,0} add(f32[3,2]{1,0} %multiply.92, f32[3,2]{1,0} %broadcast.93), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
%slice.96 = f32[3,1]{1,0} slice(f32[3,2]{1,0} %add.94), slice={[0:3], [1:2]}, metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
%multiply.102 = f32[3,1]{1,0} multiply(f32[3,1]{1,0} %broadcast.101, f32[3,1]{1,0} %slice.96), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
%sine.108 = f32[3,1]{1,0} sine(f32[3,1]{1,0} %multiply.102), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
%constant.104 = f32[] constant(-2), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
%broadcast.105 = f32[3,1]{1,0} broadcast(f32[] %constant.104), dimensions={}, metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
%slice.95 = f32[3,1]{1,0} slice(f32[3,2]{1,0} %add.94), slice={[0:3], [0:1]}, metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
%constant.97 = f32[] constant(1e-07), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
%broadcast.98 = f32[3,1]{1,0} broadcast(f32[] %constant.97), dimensions={}, metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
%maximum.99 = f32[3,1]{1,0} maximum(f32[3,1]{1,0} %slice.95, f32[3,1]{1,0} %broadcast.98), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
%log.103 = f32[3,1]{1,0} log(f32[3,1]{1,0} %maximum.99), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
%multiply.106 = f32[3,1]{1,0} multiply(f32[3,1]{1,0} %broadcast.105, f32[3,1]{1,0} %log.103), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
%sqrt.107 = f32[3,1]{1,0} sqrt(f32[3,1]{1,0} %multiply.106), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
%multiply.109 = f32[3,1]{1,0} multiply(f32[3,1]{1,0} %sine.108, f32[3,1]{1,0} %sqrt.107), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
%cosine.110 = f32[3,1]{1,0} cosine(f32[3,1]{1,0} %multiply.102), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
%multiply.111 = f32[3,1]{1,0} multiply(f32[3,1]{1,0} %cosine.110, f32[3,1]{1,0} %sqrt.107), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
%concatenate.112 = f32[3,2]{1,0} concatenate(f32[3,1]{1,0} %multiply.109, f32[3,1]{1,0} %multiply.111), dimensions={1}, metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
%reshape.113 = f32[6]{0} reshape(f32[3,2]{1,0} %concatenate.112), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
%slice.114 = f32[5]{0} slice(f32[6]{0} %reshape.113), slice={[0:5]}, metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
%constant.63 = f32[] constant(1), metadata={op_type="prim__Constant" op_name="prim__Constant" source_file="<module>@test_einsum.py" source_line=10}
%reshape.64 = f32[1]{0} reshape(f32[] %constant.63), metadata={op_type="aten__expand" op_name="aten__expand" source_file="<module>@test_einsum.py" source_line=10}
%broadcast.65 = f32[1]{0} broadcast(f32[1]{0} %reshape.64), dimensions={0}, metadata={op_type="aten__expand" op_name="aten__expand" source_file="<module>@test_einsum.py" source_line=10}
%reshape.66 = f32[] reshape(f32[1]{0} %broadcast.65), metadata={op_type="aten__expand" op_name="aten__expand" source_file="<module>@test_einsum.py" source_line=10}
%broadcast.67 = f32[5]{0} broadcast(f32[] %reshape.66), dimensions={}, metadata={op_type="aten__expand" op_name="aten__expand" source_file="<module>@test_einsum.py" source_line=10}
%multiply.115 = f32[5]{0} multiply(f32[5]{0} %slice.114, f32[5]{0} %broadcast.67), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
%add.116 = f32[5]{0} add(f32[5]{0} %broadcast.72, f32[5]{0} %multiply.115), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=10}
%constant.15 = f32[] constant(0), metadata={op_type="prim__Constant" op_name="prim__Constant" source_file="<module>@test_einsum.py" source_line=11}
%reshape.16 = f32[1]{0} reshape(f32[] %constant.15), metadata={op_type="aten__expand" op_name="aten__expand" source_file="<module>@test_einsum.py" source_line=11}
%broadcast.17 = f32[1]{0} broadcast(f32[1]{0} %reshape.16), dimensions={0}, metadata={op_type="aten__expand" op_name="aten__expand" source_file="<module>@test_einsum.py" source_line=11}
%reshape.18 = f32[] reshape(f32[1]{0} %broadcast.17), metadata={op_type="aten__expand" op_name="aten__expand" source_file="<module>@test_einsum.py" source_line=11}
%broadcast.19 = f32[4]{0} broadcast(f32[] %reshape.18), dimensions={}, metadata={op_type="aten__expand" op_name="aten__expand" source_file="<module>@test_einsum.py" source_line=11}
%constant.47 = f32[] constant(6.28318548), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
%broadcast.48 = f32[2,1]{1,0} broadcast(f32[] %constant.47), dimensions={}, metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
%get-tuple-element.28 = u32[2,2]{1,0} get-tuple-element((u64[2]{0}, u32[2,2]{1,0}) %rng-bit-generator.27), index=1, metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
%constant.30 = u32[] constant(9), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
%broadcast.31 = u32[2,2]{1,0} broadcast(u32[] %constant.30), dimensions={}, metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
%shift-right-logical.32 = u32[2,2]{1,0} shift-right-logical(u32[2,2]{1,0} %get-tuple-element.28, u32[2,2]{1,0} %broadcast.31), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
%convert.33 = f32[2,2]{1,0} convert(u32[2,2]{1,0} %shift-right-logical.32), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
%constant.34 = f32[] constant(1.1920929e-07), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
%broadcast.35 = f32[2,2]{1,0} broadcast(f32[] %constant.34), dimensions={}, metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
%multiply.36 = f32[2,2]{1,0} multiply(f32[2,2]{1,0} %convert.33, f32[2,2]{1,0} %broadcast.35), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
%constant.22 = f32[] constant(1), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
%constant.23 = f32[] constant(0), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
%subtract.37 = f32[] subtract(f32[] %constant.22, f32[] %constant.23), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
%broadcast.38 = f32[2,2]{1,0} broadcast(f32[] %subtract.37), dimensions={}, metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
%multiply.39 = f32[2,2]{1,0} multiply(f32[2,2]{1,0} %multiply.36, f32[2,2]{1,0} %broadcast.38), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
%broadcast.40 = f32[2,2]{1,0} broadcast(f32[] %constant.23), dimensions={}, metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
%add.41 = f32[2,2]{1,0} add(f32[2,2]{1,0} %multiply.39, f32[2,2]{1,0} %broadcast.40), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
%slice.43 = f32[2,1]{1,0} slice(f32[2,2]{1,0} %add.41), slice={[0:2], [1:2]}, metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
%multiply.49 = f32[2,1]{1,0} multiply(f32[2,1]{1,0} %broadcast.48, f32[2,1]{1,0} %slice.43), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
%sine.55 = f32[2,1]{1,0} sine(f32[2,1]{1,0} %multiply.49), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
%constant.51 = f32[] constant(-2), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
%broadcast.52 = f32[2,1]{1,0} broadcast(f32[] %constant.51), dimensions={}, metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
%slice.42 = f32[2,1]{1,0} slice(f32[2,2]{1,0} %add.41), slice={[0:2], [0:1]}, metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
%constant.44 = f32[] constant(1e-07), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
%broadcast.45 = f32[2,1]{1,0} broadcast(f32[] %constant.44), dimensions={}, metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
%maximum.46 = f32[2,1]{1,0} maximum(f32[2,1]{1,0} %slice.42, f32[2,1]{1,0} %broadcast.45), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
%log.50 = f32[2,1]{1,0} log(f32[2,1]{1,0} %maximum.46), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
%multiply.53 = f32[2,1]{1,0} multiply(f32[2,1]{1,0} %broadcast.52, f32[2,1]{1,0} %log.50), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
%sqrt.54 = f32[2,1]{1,0} sqrt(f32[2,1]{1,0} %multiply.53), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
%multiply.56 = f32[2,1]{1,0} multiply(f32[2,1]{1,0} %sine.55, f32[2,1]{1,0} %sqrt.54), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
%cosine.57 = f32[2,1]{1,0} cosine(f32[2,1]{1,0} %multiply.49), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
%multiply.58 = f32[2,1]{1,0} multiply(f32[2,1]{1,0} %cosine.57, f32[2,1]{1,0} %sqrt.54), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
%concatenate.59 = f32[2,2]{1,0} concatenate(f32[2,1]{1,0} %multiply.56, f32[2,1]{1,0} %multiply.58), dimensions={1}, metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
%reshape.60 = f32[4]{0} reshape(f32[2,2]{1,0} %concatenate.59), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
%constant.10 = f32[] constant(1), metadata={op_type="prim__Constant" op_name="prim__Constant" source_file="<module>@test_einsum.py" source_line=11}
%reshape.11 = f32[1]{0} reshape(f32[] %constant.10), metadata={op_type="aten__expand" op_name="aten__expand" source_file="<module>@test_einsum.py" source_line=11}
%broadcast.12 = f32[1]{0} broadcast(f32[1]{0} %reshape.11), dimensions={0}, metadata={op_type="aten__expand" op_name="aten__expand" source_file="<module>@test_einsum.py" source_line=11}
%reshape.13 = f32[] reshape(f32[1]{0} %broadcast.12), metadata={op_type="aten__expand" op_name="aten__expand" source_file="<module>@test_einsum.py" source_line=11}
%broadcast.14 = f32[4]{0} broadcast(f32[] %reshape.13), dimensions={}, metadata={op_type="aten__expand" op_name="aten__expand" source_file="<module>@test_einsum.py" source_line=11}
%multiply.61 = f32[4]{0} multiply(f32[4]{0} %reshape.60, f32[4]{0} %broadcast.14), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
%add.62 = f32[4]{0} add(f32[4]{0} %broadcast.19, f32[4]{0} %multiply.61), metadata={op_type="aten__normal" op_name="aten__normal" source_file="<module>@test_einsum.py" source_line=11}
%dot.117 = f32[5,4]{1,0} dot(f32[5]{0} %add.116, f32[4]{0} %add.62), lhs_contracting_dims={}, rhs_contracting_dims={}, metadata={op_type="aten__einsum" op_name="aten__einsum" source_file="einsum@functional.py" source_line=364}
%transpose.118 = f32[5,4]{1,0} transpose(f32[5,4]{1,0} %dot.117), dimensions={0,1}, metadata={op_type="aten__einsum" op_name="aten__einsum" source_file="einsum@functional.py" source_line=364}
ROOT %tuple.119 = (f32[5,4]{1,0}) tuple(f32[5,4]{1,0} %transpose.118)
}
clean up the test a bit
x = torch.tensor([1,2,3,4,5], device='xla:0')
y = torch.tensor([2,3,4,5], device='xla:0')
res3 = torch.einsum('i,j->ij', x, y)
print(torch_xla._XLAC._get_xla_tensors_text([res3]))
print(torch_xla._XLAC._get_xla_tensors_hlo([res3]))
I saw
IR {
%0 = s64[4]{0} xla::device_data(), location=<module>@test_einsum.py:13, device=CPU:0
%1 = s64[5]{0} xla::device_data(), location=<module>@test_einsum.py:12, device=CPU:0
%2 = s64[5,4]{1,0} aten::einsum(%1, %0), location=einsum@functional.py:364, equation=(i,j->ij), ROOT=0
}
and
HloModule IrToHlo.6, entry_computation_layout={(s64[4]{0},s64[5]{0})->(s64[5,4]{1,0})}
ENTRY %IrToHlo.6 (p0.1: s64[4], p1.2: s64[5]) -> (s64[5,4]) {
%p1.2 = s64[5]{0} parameter(1), metadata={op_type="xla__device_data" op_name="xla__device_data" source_file="<module>@test_einsum.py" source_line=12}
%p0.1 = s64[4]{0} parameter(0), metadata={op_type="xla__device_data" op_name="xla__device_data" source_file="<module>@test_einsum.py" source_line=13}
%dot.3 = s64[5,4]{1,0} dot(s64[5]{0} %p1.2, s64[4]{0} %p0.1), lhs_contracting_dims={}, rhs_contracting_dims={}, metadata={op_type="aten__einsum" op_name="aten__einsum" source_file="einsum@functional.py" source_line=364}
%transpose.4 = s64[5,4]{1,0} transpose(s64[5,4]{1,0} %dot.3), dimensions={0,1}, metadata={op_type="aten__einsum" op_name="aten__einsum" source_file="einsum@functional.py" source_line=364}
ROOT %tuple.5 = (s64[5,4]{1,0}) tuple(s64[5,4]{1,0} %transpose.4)
}
and I can confirm that xla::einsum
is being called. I guess I mistakenly thought einsum
was an XLA op, but it is actually a helper function that calls DotGeneral,
which is optimal on XLA.
I tried another one
As = torch.randn(3,2,5)
Bs = torch.randn(3,5,4)
As = As.to('xla:0')
Bs = Bs.to('xla:0')
res4 = torch.einsum('bij,bjk->bik', As, Bs)
print(torch_xla._XLAC._get_xla_tensors_text([res4]))
print(torch_xla._XLAC._get_xla_tensors_hlo([res4]))
and saw
IR {
%0 = f32[3,5,4]{2,1,0} xla::device_data(), location=<module>@test_einsum.py:23, device=CPU:0
%1 = f32[3,2,5]{2,1,0} xla::device_data(), location=<module>@test_einsum.py:22, device=CPU:0
%2 = f32[3,2,4]{2,1,0} aten::einsum(%1, %0), location=einsum@functional.py:364, equation=(bij,bjk->bik), ROOT=0
}
and
HloModule IrToHlo.6, entry_computation_layout={(f32[3,5,4]{2,1,0},f32[3,2,5]{2,1,0})->(f32[3,2,4]{2,1,0})}
ENTRY %IrToHlo.6 (p0.1: f32[3,5,4], p1.2: f32[3,2,5]) -> (f32[3,2,4]) {
%p1.2 = f32[3,2,5]{2,1,0} parameter(1), metadata={op_type="xla__device_data" op_name="xla__device_data" source_file="<module>@test_einsum.py" source_line=22}
%p0.1 = f32[3,5,4]{2,1,0} parameter(0), metadata={op_type="xla__device_data" op_name="xla__device_data" source_file="<module>@test_einsum.py" source_line=23}
%dot.3 = f32[3,2,4]{2,1,0} dot(f32[3,2,5]{2,1,0} %p1.2, f32[3,5,4]{2,1,0} %p0.1), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}, metadata={op_type="aten__einsum" op_name="aten__einsum" source_file="einsum@functional.py" source_line=364}
%transpose.4 = f32[3,2,4]{2,1,0} transpose(f32[3,2,4]{2,1,0} %dot.3), dimensions={0,1,2}, metadata={op_type="aten__einsum" op_name="aten__einsum" source_file="einsum@functional.py" source_line=364}
ROOT %tuple.5 = (f32[3,2,4]{2,1,0}) tuple(f32[3,2,4]{2,1,0} %transpose.4)
}
This looks right too. I guess the only problematic one is 'ii->i'.
We should enforce that xla::einsum
is called when appropriate in our unit tests, e.g. for ii->i
https://github.com/pytorch/xla/blob/master/test/cpp/test_aten_xla_tensor.cpp#L3888
Is ExpectCounterChanged("xla::einsum", cpp_test::GetIgnoredCounters());
sufficient to guarantee that the dispatch is correct? Or is there a more accurate check I can add to these unit tests?
ExpectCounterChanged("xla::einsum", cpp_test::GetIgnoredCounters());
won't be enough: even if we fall back in https://github.com/pytorch/xla/blob/master/torch_xla/csrc/aten_xla_type.cpp#L1065, we still increment the einsum
counter, because the increment happens earlier, in https://github.com/pytorch/xla/blob/master/torch_xla/csrc/aten_xla_type.cpp#L1060
What you want to do is to add a counter in https://github.com/pytorch/xla/blob/master/torch_xla/csrc/aten_xla_type.cpp#L1065 using
XLA_COUNTER("EinsumFallback", 1);
and make sure the fallback counter is not incremented for the test.
Got it, I'll add that check and investigate this issue. Thanks for the context!
Closing this after https://github.com/pytorch/xla/pull/4027
Remaining einsum edge cases are tracked in https://github.com/pytorch/xla/issues/4032
🐛 Bug
with https://github.com/pytorch/xla/commit/aa73509890fb10b7f3bc2b9e8598f96382f2c979,
torch.einsum
should be dispatched to xla::einsum
, but it is not.

To Reproduce
using nightly and
I saw
It seems like einsum incorrectly falls back in https://github.com/pytorch/xla/blob/master/torch_xla/csrc/aten_xla_type.cpp#L1065