openxla / xla

A machine learning compiler for GPUs, CPUs, and ML accelerators
Apache License 2.0

[XLA:CLIENT] Why force using U32 indices type when converting torch gather op #15654

Open Nullkooland opened 2 months ago

Nullkooland commented 2 months ago

This commit, https://github.com/openxla/xla/commit/a1c04b81ea9e2e23a362fdf5f7677800d2042f8f, introduces code that forces xla::TorchGather to convert an indices tensor with element type i64 to u32: https://github.com/openxla/xla/blob/a1c04b81ea9e2e23a362fdf5f7677800d2042f8f/xla/client/lib/slicing.cc#L148-L152

This causes the StableHLO IR exported with torch_xla to contain the following pattern around stablehlo.gather:

[Image: gather_ui32_indices]

However, when lowering stablehlo.gather to the MLIR tensor dialect's tensor.gather, the tensor.gather op requires its indices operand to have a signless integer element type, so ui32 causes an error. It is also inconsistent: every other index type in the IR is i64, and only this gather's index type is ui32.
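To make the conversion concrete, here is a minimal standalone sketch of that kind of narrowing. It is not the XLA code: `CanNarrowToU32` and `NarrowIndices` are hypothetical names, and the guard (narrow only when the indexed dimension fits in 32 bits) is an assumption about what the linked lines do, since they are not reproduced here.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <vector>

// Hypothetical helper, not part of XLA: reports whether every valid index
// into a dimension of size `dim_size` is representable as uint32_t.
// The guard itself is an assumption about the linked code.
bool CanNarrowToU32(int64_t dim_size) {
  return dim_size >= 0 &&
         static_cast<uint64_t>(dim_size) <
             std::numeric_limits<uint32_t>::max();
}

// Hypothetical helper, not part of XLA: narrows in-range i64 indices to u32.
// Values are preserved; only the element type (width and signedness) changes.
std::vector<uint32_t> NarrowIndices(const std::vector<int64_t>& indices,
                                    int64_t dim_size) {
  std::vector<uint32_t> narrowed;
  narrowed.reserve(indices.size());
  for (int64_t index : indices) {
    // Out-of-range or negative indices are outside the scope of this sketch.
    assert(index >= 0 && index < dim_size);
    narrowed.push_back(static_cast<uint32_t>(index));
  }
  return narrowed;
}
```

Under these assumptions the index values survive the narrowing unchanged; what changes is the element type, which is exactly what surfaces as ui32 in the exported IR.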

@blakehechtman Could you have a look? This commit looks like a HACK to me. Is it possible to revert this?

cheshire commented 2 months ago

The commit is 5 years old, and thus unlikely to be reverted.

> However, when lowering stablehlo.gather to MLIR tensor dialect's tensor.gather

This seems outside of the scope of OpenXLA, more like a stablehlo/tensor dialect interop issue?

Nullkooland commented 2 months ago

> The commit is 5 years old, and thus unlikely to be reverted.
>
> > However, when lowering stablehlo.gather to MLIR tensor dialect's tensor.gather
>
> This seems outside of the scope of OpenXLA, more like a stablehlo/tensor dialect interop issue?

I don't think this is a stablehlo-to-tensor dialect conversion issue, since stablehlo.gather accepts i64 indices. After removing the forced u32 index conversion in XLA, the same torch gather op is exported as:

[Image: stablehlo_gather_i64_indices]

I don't see why we need this u32 HACK.