Nullkooland opened this issue 2 months ago.
The commit is 5 years old, and thus unlikely to be reverted.

> However, when lowering `stablehlo.gather` to MLIR tensor dialect's `tensor.gather`

This seems outside the scope of OpenXLA; it looks more like a StableHLO/tensor dialect interop issue?
I don't think this is a stablehlo-to-tensor dialect conversion issue, since `stablehlo.gather` can take `i64` indices. After removing the code in XLA that forces the `u32` indices conversion, the same torch gather op is exported as:

I don't see why we need this `u32` HACK.
This commit: https://github.com/openxla/xla/commit/a1c04b81ea9e2e23a362fdf5f7677800d2042f8f introduces code that forces `xla::TorchGather` to convert an indices tensor with element type `i64` to `u32`: https://github.com/openxla/xla/blob/a1c04b81ea9e2e23a362fdf5f7677800d2042f8f/xla/client/lib/slicing.cc#L148-L152

This causes the StableHLO IR exported with `torch_xla` to contain such a pattern around `stablehlo.gather`:

However, when lowering `stablehlo.gather` to the MLIR tensor dialect's `tensor.gather`, the `tensor.gather` op requires its `indices` operand tensor to have a signless integer element type, so `ui32` causes an error. It is also inconsistent: every other index type in the IR is `i64`, while only the index type of this gather is `ui32`.

@blakehechtman Could you have a look? This commit looks like a HACK to me. Is it possible to revert it?
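To make the type constraint concrete: MLIR spells integer types as `iN` (signless), `siN` (signed), and `uiN` (unsigned), so an `i64` indices tensor passes a signless-integer check while a `ui32` one does not. The following is a minimal Python sketch of such a check on tensor type strings. The helper names are hypothetical, the parser only handles simple ranked types like `tensor<4x1xui32>`, and treating `index` as also accepted is an assumption about `tensor.gather` rather than something stated in this issue. It is an illustration, not MLIR's actual verifier.

```python
import re

# Hypothetical helpers illustrating MLIR integer signedness; not MLIR's verifier.
# MLIR integer types: iN = signless, siN = signed, uiN = unsigned.
_SIGNLESS_INT = re.compile(r"i[1-9][0-9]*")
# Matches simple ranked tensor types such as tensor<4x1xui32> or tensor<?x2xi64>.
_TENSOR_TYPE = re.compile(r"tensor<(?:(?:\d+|\?)x)*(\w+)>")


def element_type(tensor_type: str) -> str:
    """Extract the element type from a simple ranked tensor type string."""
    match = _TENSOR_TYPE.fullmatch(tensor_type)
    if match is None:
        raise ValueError(f"unsupported tensor type: {tensor_type}")
    return match.group(1)


def is_valid_gather_index_type(elem_type: str) -> bool:
    """True for signless integers (iN); also accepting `index` is an assumption."""
    return elem_type == "index" or _SIGNLESS_INT.fullmatch(elem_type) is not None


if __name__ == "__main__":
    # The i64 form (no forced conversion) versus the ui32 form (forced conversion).
    for ty in ["tensor<4x1xi64>", "tensor<4x1xui32>"]:
        elem = element_type(ty)
        status = "ok" if is_valid_gather_index_type(elem) else "rejected"
        print(f"{ty}: {status}")
```

The check keys purely on the type spelling: the `u` prefix is what marks `ui32` as unsigned, which is why the forced conversion, and not the bit width, is what trips the signless requirement.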