To enable derivatives on GPU via PyTorch, ops need to return handles to the input arguments. This is not needed in the base/GpuFloatScalarOps case. We could instead create C++ bindings that don't return these handles, so we don't have to worry about cleaning them up in Kotlin code.
Ops affected as of this writing: Matmul, plus
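A minimal sketch of what such a binding could look like, assuming a JNI-style layer over LibTorch where tensors cross the boundary as raw pointer handles. The function name, handle encoding, and signature here are hypothetical, not the project's actual API:

```cpp
#include <torch/torch.h>
#include <cstdint>

// Hypothetical binding for the no-derivative (base/GpuFloatScalarOps) path.
// It borrows the input handles and returns only the result handle, so the
// Kotlin side never receives input handles it would have to clean up.
extern "C" int64_t plus_no_input_handles(int64_t lhsHandle, int64_t rhsHandle) {
    auto* lhs = reinterpret_cast<torch::Tensor*>(lhsHandle);
    auto* rhs = reinterpret_cast<torch::Tensor*>(rhsHandle);
    // Heap-allocate the result; ownership of this single handle passes to
    // the caller, which frees it when done.
    auto* result = new torch::Tensor(*lhs + *rhs);
    return reinterpret_cast<int64_t>(result);
}
```

With this shape, Kotlin only owns the one returned result handle in the base case; the derivative-enabled bindings that also return input handles would remain a separate path.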