microsoft / tensorflow-directml

Fork of TensorFlow accelerated by DirectML
Apache License 2.0
454 stars 32 forks

Fix ArgMin and ArgMax crash when output type is int16 or uint16 #392

Closed. PatriceVignola closed this 1 year ago.

PatriceVignola commented 1 year ago

TensorFlow supports int16 and uint16 as output types for ArgMin and ArgMax, but DirectML only supports int32, int64, uint32, and uint64. Therefore, we need to call DirectML with an int32 output type and then cast the result back to the requested 16-bit type (int16 or uint16) once the operation completes.
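The widen-then-narrow workaround can be illustrated with a small simulation. This is a minimal sketch in plain Python, not the fork's actual C++ kernel code: `backend_arg_reduce`, `arg_reduce`, `BACKEND_TYPES`, and `NARROW_TYPES` are hypothetical names, and `ctypes` fixed-width arrays merely stand in for typed tensors. The fake backend rejects 16-bit output types, so the wrapper runs it with int32 and then casts each index to the requested 16-bit type, checking that the value survived the cast.

```python
import ctypes

# Output types the hypothetical backend accepts (mirrors the list in the PR text).
BACKEND_TYPES = {
    "int32": ctypes.c_int32,
    "int64": ctypes.c_int64,
    "uint32": ctypes.c_uint32,
    "uint64": ctypes.c_uint64,
}

# 16-bit output types TensorFlow allows but the backend does not.
NARROW_TYPES = {"int16": ctypes.c_int16, "uint16": ctypes.c_uint16}


def backend_arg_reduce(rows, output_type, use_min):
    """Hypothetical stand-in for the DirectML kernel: rejects 16-bit output types."""
    if output_type not in BACKEND_TYPES:
        raise TypeError(f"backend does not support output type {output_type}")
    pick = min if use_min else max
    # min/max return the first extreme element, so ties resolve to the lowest index.
    indices = [pick(range(len(r)), key=r.__getitem__) for r in rows]
    return (BACKEND_TYPES[output_type] * len(indices))(*indices)


def arg_reduce(rows, output_type="int64", use_min=False):
    """ArgMin/ArgMax over the last axis, routing 16-bit outputs through int32."""
    if output_type in BACKEND_TYPES:
        return backend_arg_reduce(rows, output_type, use_min)
    if output_type not in NARROW_TYPES:
        raise TypeError(f"unsupported output type {output_type}")
    # Step 1: run the backend with a type it supports (int32).
    wide = backend_arg_reduce(rows, "int32", use_min)
    # Step 2: cast the int32 indices back to the requested 16-bit type.
    out = (NARROW_TYPES[output_type] * len(wide))()
    for i, idx in enumerate(wide):
        out[i] = idx
        # ctypes truncates silently on overflow, so verify the index round-trips.
        if out[i] != idx:
            raise OverflowError(f"index {idx} does not fit in {output_type}")
    return out


if __name__ == "__main__":
    data = [[1.0, 5.0, 3.0], [7.0, 2.0, 7.0]]
    print(list(arg_reduce(data, "int16")))                # [1, 0]
    print(list(arg_reduce(data, "uint16", use_min=True)))  # [0, 1]
```

The overflow check is an addition for illustration: an index along an axis longer than 32767 elements cannot be represented as int16, and this sketch chooses to raise rather than silently wrap. Whether and how the real kernel handles that case is not stated in the PR.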