Comparison ops in MIGraphX do not output `bool_type`; they keep the input datatype. E.g. the `greater` op in MIGraphX will output `float_type` if the inputs are `float_type`.
This is problematic when such an op is an output node: the MIGraphX output will be `float_type` in the above case, whereas torch-migraphx expects `bool_type`. We have a `bool_output` flag that is set by ops affected by this problem. However, if the comparison is followed by any shape ops (e.g. reshape, transpose), this flag currently isn't propagated to the output.
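A minimal sketch of the propagation idea, assuming a simplified graph model: shape ops only rearrange data, so they can inherit `bool_output` from their single input, and the flag then survives to the output node, where it tells the caller to cast back to bool. `Node`, `make_node`, `SHAPE_OPS`, `BOOL_OPS` and `output_dtype` are hypothetical names for illustration, not torch-migraphx's actual converter API.

```python
from dataclasses import dataclass, field
from typing import List

# Hypothetical op groupings, for illustration only.
BOOL_OPS = {"gt", "lt", "ge", "le", "eq", "ne"}
SHAPE_OPS = {"reshape", "transpose", "permute", "squeeze", "unsqueeze", "flatten"}


@dataclass
class Node:
    op: str
    inputs: List["Node"] = field(default_factory=list)
    bool_output: bool = False


def make_node(op: str, *inputs: Node) -> Node:
    node = Node(op, list(inputs))
    if op in BOOL_OPS:
        # The comparison keeps the input dtype, so mark the result as logically bool.
        node.bool_output = True
    elif op in SHAPE_OPS and len(inputs) == 1:
        # Shape ops don't change element values: inherit the flag from the input.
        node.bool_output = inputs[0].bool_output
    return node


def output_dtype(node: Node, mgx_dtype: str) -> str:
    # Report bool for flagged outputs; otherwise use the dtype MIGraphX produced.
    return "bool" if node.bool_output else mgx_dtype
```

Usage: for `x -> gt -> reshape -> transpose`, the final node still carries the flag, so `output_dtype(t, "float")` returns `"bool"`. Without the `SHAPE_OPS` branch, the flag would stop at the `gt` node and the output would be reported as `"float"`, which is the failure described above.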
See the added test case: a bool op (e.g. `torch.Tensor.gt`) followed by a shape op fails without this change.