openxla / xla

A machine learning compiler for GPUs, CPUs, and ML accelerators
Apache License 2.0

Fix implicit index handling in ScatterDeterminismExpander #19656

Closed: sergey-kozub closed this 12 hours ago

sergey-kozub commented 1 day ago

This PR fixes a bug in how ScatterDeterminismExpander handles missing (implicit) indices and adds the corresponding tests.

  1. When the size of scatter_dims_to_operand_dims is not equal to the operand rank, the out_of_bound_tensor has incorrect dimensions, so the operands of the select op have mismatched shapes. This is fixed at line 718.
  2. When the update is not scalar, the indices are recalculated, which requires updating the out_of_bound_tensor as well (lines 757-761).
  3. After expanding the indices, the has_scalar_indices flag has to be updated (line 777).
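
To illustrate the shape issue in item 1, here is a minimal pure-Python sketch, not the XLA implementation. It assumes the usual scatter convention that an index vector has one component per entry of scatter_dims_to_operand_dims, that operand dimensions not listed there get an implicit start index of 0, and that an update whose start index falls outside the valid range is skipped. The helper names (index_upper_bounds, out_of_bound_mask, expand_implicit_indices) and the example shapes are hypothetical and chosen only for illustration.

```python
def index_upper_bounds(operand_shape, window_shape, scatter_dims_to_operand_dims):
    """Largest valid start index for each index-vector component.

    The result has one entry per entry of scatter_dims_to_operand_dims,
    not one per operand dimension, so it lines up with the index vectors.
    """
    return [operand_shape[d] - window_shape[d] for d in scatter_dims_to_operand_dims]


def out_of_bound_mask(indices, bounds):
    """True for every index vector with a component outside [0, bound]."""
    mask = []
    for index_vector in indices:
        if len(index_vector) != len(bounds):
            # Analogue of the mismatched shapes described in item 1.
            raise ValueError("index vector and bounds have different lengths")
        mask.append(any(i < 0 or i > b for i, b in zip(index_vector, bounds)))
    return mask


def expand_implicit_indices(index_vector, scatter_dims_to_operand_dims, operand_rank):
    """Fill in the implicit zero for operand dimensions that are not indexed."""
    full = [0] * operand_rank
    for i, d in zip(index_vector, scatter_dims_to_operand_dims):
        full[d] = i
    return full


# Hypothetical example: rank-2 operand, but only dimension 0 is indexed.
operand_shape = (4, 5)
window_shape = (1, 5)  # each update covers one full row
scatter_dims_to_operand_dims = [0]
indices = [[0], [3], [4], [-1]]

bounds = index_upper_bounds(operand_shape, window_shape, scatter_dims_to_operand_dims)
mask = out_of_bound_mask(indices, bounds)
print(bounds)
print(mask)
print(expand_implicit_indices([3], scatter_dims_to_operand_dims, len(operand_shape)))
```

Computing one bound per operand dimension instead (two entries here) would not line up with the one-component index vectors, which is the kind of mismatch the first fix addresses.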

Also added a few cosmetic changes:

  1. Removed the is_one_dimensional branch in ExpandIndices, as it is never taken (probably an artefact of a prior implementation).
  2. Broadcast the boundary constants instead of generating a (possibly large) literal.