ROCm / torch_migraphx

Libraries integrating migraphx with pytorch
BSD 3-Clause "New" or "Revised" License

propagate "bool_output" flag through shape ops #172

Closed shivadbhavsar closed 1 month ago

shivadbhavsar commented 1 month ago

Comparison ops in MIGraphX do not output bool_type; they keep the input datatype. E.g., the greater op in MIGraphX outputs float_type if the inputs are float_type.

This is problematic when such an op is an output node, i.e., the MIGraphX output will be float_type in the above case, where torch-migraphx expects bool_type. We have a bool_output flag that is set by ops affected by this problem. However, if any shape ops (e.g., reshape, transpose) follow such an op, the flag is currently not propagated.

See added test case: a bool op (eg. torch.Tensor.gt) followed by a shape op fails without this change.
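
To illustrate the idea, here is a minimal sketch of propagating such a flag through shape ops on a toy graph. It is not the torch_migraphx implementation: the `Node` class, the `COMPARISON_OPS` and `SHAPE_OPS` sets, and `mark_bool_output` are hypothetical names used only for illustration.

```python
from dataclasses import dataclass, field

# Hypothetical op sets for illustration; not torch_migraphx's actual lists.
COMPARISON_OPS = {"greater", "less", "equal"}
SHAPE_OPS = {"reshape", "transpose", "squeeze", "unsqueeze"}


@dataclass
class Node:
    op: str
    inputs: list = field(default_factory=list)
    bool_output: bool = False


def mark_bool_output(nodes):
    """Set bool_output on comparison ops and carry it through shape ops.

    `nodes` must be in topological order (inputs before their users).
    """
    for node in nodes:
        if node.op in COMPARISON_OPS:
            node.bool_output = True
        elif node.op in SHAPE_OPS:
            # A shape op only rearranges elements, so it inherits the
            # flag from its data input (assumed to be the first input).
            node.bool_output = bool(node.inputs) and node.inputs[0].bool_output
    return nodes


x = Node("input")
gt = Node("greater", [x, x])
out = Node("reshape", [gt])
mark_bool_output([x, gt, out])
print(out.bool_output)  # True
```

With the flag carried to the final node, the caller can cast the output back to a bool tensor even though the comparison is no longer the last op in the graph.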