tensorflow / tensorflow

An Open Source Machine Learning Framework for Everyone
https://tensorflow.org
Apache License 2.0

Introduce ReplaceWhileOperandShape API to HloInstruction. This function provides an interface to change the shape of a tuple operand in while loops. #67127

Status: Open. copybara-service[bot] opened this pull request 1 week ago.

The default implementation (ShapeTransformer) propagates the shape change through instructions whose output shape is inferred directly from their operands, namely get-tuple-element (gte), tuple, and nested while loops.
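To make the propagation rule concrete, here is a minimal Python sketch over a hypothetical toy IR. The `Inst` class, string shapes such as `"f32[4]"`, and the `infer_shape`/`propagate` helpers are illustrative assumptions, not XLA's C++ `HloInstruction` API or the actual ShapeTransformer. When a tuple-shaped value changes, each get-tuple-element user re-reads its element shape and each tuple user rebuilds its shape from its operands. The nested-while case, which would also require updating the inner loop's body and condition, is left out.

```python
from dataclasses import dataclass, field
from typing import List, Optional, Union

# Hypothetical toy IR, not XLA's HloInstruction.
# Array shapes are strings like "f32[4]"; tuple shapes are Python tuples.
Shape = Union[str, tuple]


@dataclass(eq=False)
class Inst:
    opcode: str                  # e.g. "parameter", "get-tuple-element", "tuple"
    operands: List["Inst"]
    shape: Shape
    index: Optional[int] = None  # tuple index, used only by "get-tuple-element"
    users: List["Inst"] = field(default_factory=list, repr=False)

    def __post_init__(self) -> None:
        # Register this instruction as a user of each of its operands.
        for op in self.operands:
            op.users.append(self)


def infer_shape(inst: Inst) -> Shape:
    """Recompute the output shape of a shape-forwarding op from its operands."""
    if inst.opcode == "get-tuple-element":
        return inst.operands[0].shape[inst.index]
    if inst.opcode == "tuple":
        return tuple(op.shape for op in inst.operands)
    raise ValueError(f"shape of {inst.opcode} is not inferred here")


def propagate(inst: Inst, new_shape: Shape) -> None:
    """Set inst's shape, then push the change to its forwarding users."""
    inst.shape = new_shape
    for user in inst.users:
        if user.opcode in ("get-tuple-element", "tuple"):
            updated = infer_shape(user)
            if updated != user.shape:
                propagate(user, updated)


# Usage: widen element 0 of a (f32[4], s32[]) loop state to f32[8].
param = Inst("parameter", [], ("f32[4]", "s32[]"))
elem0 = Inst("get-tuple-element", [param], "f32[4]", index=0)
elem1 = Inst("get-tuple-element", [param], "s32[]", index=1)
root = Inst("tuple", [elem0, elem1], ("f32[4]", "s32[]"))
propagate(param, ("f32[8]", "s32[]"))
print(root.shape)  # ('f32[8]', 's32[]')
```

Recomputing a user only when its inferred shape actually differs stops the walk at users, such as the get-tuple-element of an untouched element, whose shape is unaffected.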

As a result, if the changed operand of the while loop is used only by these instructions, a call to ReplaceWhileOperandShape is guaranteed to leave the HLO graph valid after the shape replacement. If there are any other users, the function currently bails out and returns false.
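The bail-out condition can be sketched as a check over the operand's users. The graph encoding (a dict from instruction name to its opcode and user names) and `only_forwarding_users` are hypothetical. The PR text only says the function returns false when other users exist; running the check before any shape is rewritten, so that a rejected call leaves the graph untouched, is an assumption made here. The check is deliberately conservative: it rejects any non-forwarding user reachable through forwarding ops, even one reading an unchanged tuple element, and it does not descend into nested loop bodies.

```python
from typing import Dict, List, Tuple

# Hypothetical toy graph: instruction name -> (opcode, names of its users).
Graph = Dict[str, Tuple[str, List[str]]]

# Opcodes whose output shape is inferred directly from their operands.
FORWARDING_OPS = {"get-tuple-element", "tuple", "while"}


def only_forwarding_users(graph: Graph, start: str) -> bool:
    """Return True iff every transitive user of `start` is a forwarding op.

    Conservative: users reached through any forwarding op are checked, even
    if they read a tuple element whose shape does not change.
    """
    stack = list(graph[start][1])
    seen = set()
    while stack:
        name = stack.pop()
        if name in seen:
            continue
        seen.add(name)
        opcode, users = graph[name]
        if opcode not in FORWARDING_OPS:
            return False  # unsupported user: bail out before mutating anything
        stack.extend(users)
    return True


graph: Graph = {
    "param": ("parameter", ["gte0", "gte1"]),
    "gte0": ("get-tuple-element", ["root"]),
    "gte1": ("get-tuple-element", ["root"]),
    "root": ("tuple", []),
}
print(only_forwarding_users(graph, "param"))  # True

# An arithmetic user makes the replacement unsupported.
graph["gte0"] = ("get-tuple-element", ["add"])
graph["add"] = ("add", ["root"])
print(only_forwarding_users(graph, "param"))  # False
```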