microsoft / tensorflow-directml-plugin

DirectML PluggableDevice plugin for TensorFlow 2
Apache License 2.0

Add an int32 kernel registration for Fill #316

Closed by PatriceVignola 1 year ago

PatriceVignola commented 1 year ago

Unfortunately, the DEVICE_DEFAULT int32 registration for Fill in TensorFlow core is mistakenly wrapped in a CUDA #ifdef, so the plugin cannot use it. To work around this, we emulate that registration in our plugin, as we did for Pack and StridedSlice.

Forcing this operator onto DML, with its int32 tensors kept in host memory, causes some elementwise operators to run on the GPU. This gives a significant performance improvement when running the model described here: https://github.com/microsoft/tensorflow-directml-plugin/discussions/315
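
To illustrate the idea, here is a toy model of a kernel lookup. It is a sketch only: it is not TensorFlow's kernel registry or placer, and the `Kernel` record, the `find_kernel` helper, and the registry contents are all hypothetical names made up for this example. It assumes that an op with no kernel for the preferred device and dtype falls back to CPU.

```python
from dataclasses import dataclass
from typing import Dict, Tuple


@dataclass(frozen=True)
class Kernel:
    """Hypothetical kernel record: the device the op runs on and where its tensors live."""
    device: str
    tensor_memory: str  # "device" or "host"


# Key: (op name, device, dtype). All entries are illustrative, not real registrations.
Registry = Dict[Tuple[str, str, str], Kernel]


def find_kernel(registry: Registry, op: str, dtype: str, preferred: str = "DML") -> Kernel:
    """Return the kernel for the preferred device if one exists, else the CPU kernel."""
    kernel = registry.get((op, preferred, dtype))
    if kernel is not None:
        return kernel
    return registry[(op, "CPU", dtype)]


registry: Registry = {
    ("Fill", "CPU", "int32"): Kernel("CPU", "host"),
    ("Fill", "DML", "float32"): Kernel("DML", "device"),
}

# Without an int32 entry for DML, the lookup falls back to the CPU kernel.
before = find_kernel(registry, "Fill", "int32")

# The workaround in this toy model: an int32 Fill entry on DML with tensors in host memory.
registry[("Fill", "DML", "int32")] = Kernel("DML", "host")
after = find_kernel(registry, "Fill", "int32")

print(before)
print(after)
```

In this model the int32 data stays in host memory in both cases. What changes is the device the op is assigned to.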