mcourteaux opened 3 years ago
The "right" way to do this is probably to support autodiff for custom defined extern functions. Unfortunately it has been on the todo list for a long time https://github.com/halide/Halide/issues/4118
I'm happy to help if anyone wants to work on this.
I'm currently porting my TensorFlow code to C++ with Halide. Once I get to the gradient-descent part, I'll start reporting my findings.
I am looking into rewriting my PhD software in Halide. The gradients look very powerful, but I think I will run into numerical stability issues due to non-optimized gradient expressions. One example I'm thinking about is the softmax function: TensorFlow defines a custom gradient for it, which has proven to be an important improvement for the success of my work.
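For reference, the identity that makes a custom softmax gradient worthwhile (as far as I can tell, this is what TensorFlow's registered softmax gradient computes): for $y = \operatorname{softmax}(x)$ with incoming adjoint $\bar{y}$, the vector-Jacobian product collapses to

$$\bar{x}_i = y_i \Big( \bar{y}_i - \sum_j \bar{y}_j \, y_j \Big),$$

so the Jacobian is never materialized, and none of the mechanically differentiated `exp()`/`sum()` subexpressions (the ones that overflow or cancel) appear in the gradient.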
If I understood correctly what `Call` in `IR.h` defines, I think these custom gradients should be inserted here: https://github.com/halide/Halide/blob/0b297f2944a0fe2076f8febdc1226796e5a13376/src/Derivative.cpp#L1150

Although, I am not sure how to make sure it is seen as a C-ABI call, because the compiler might inline the call to my `softmax()` function. Looking at the example of `sinh()` shows that it's more complicated than just writing a function: https://github.com/halide/Halide/blob/72284a20f60aa2428e22db482feb27314e29e51b/src/IROperator.cpp#L2059-L2068

To me, that doesn't look like a C-ABI call, which might mean that the documentation of this `enum` is somewhat misleading: https://github.com/halide/Halide/blob/0b297f2944a0fe2076f8febdc1226796e5a13376/src/IR.h#L466-L477

Misleading, because it seems these function calls are identified by name in the different backends and replaced by whatever implementation the backend supports.

So, all in all, I'm a bit lost. I want to define `softmax()` as such (similar to this, but adapted for numerical stability to this):
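Sketching the idea here (my real code differs in details; this assumes a 1-D `input` Func of length `size`, and uses Halide's inline reductions):

```cpp
// Numerically stable softmax: subtract the running max before exp(),
// so the exponentials never overflow. 'input' and 'size' are assumed
// to be a 1-D Func and its extent.
Var x{"x"};
RDom r(0, size);

Expr max_val = maximum(input(r));  // inline max-reduction over the input
Func softmax{"softmax"};
softmax(x) = exp(input(x) - max_val) / sum(exp(input(r) - max_val));
```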
This is a numerically very stable implementation. Now it still needs a stable expression for its gradients. Wrapping it in a `Call::make()` seems possible using the function-pointer variant of `Call::make()`: https://github.com/halide/Halide/blob/0b297f2944a0fe2076f8febdc1226796e5a13376/src/IR.h#L617-L619

But then associating a gradient with it still needs a hook in `Derivative.cpp`.
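Concretely, the stable gradient I would want the autodiff to produce is the identity from above, which in Halide would look something like this (again a sketch, reusing `x` and `r` from the snippet above; `d_softmax` is the hypothetical incoming adjoint of the `softmax` Func):

```cpp
// Vector-Jacobian product for softmax: d_input = y * (g - <g, y>),
// with y = softmax and g = d_softmax. This never re-differentiates
// the exp()/sum() subexpressions, so it stays numerically stable.
Func d_input{"d_input"};
Expr inner = sum(d_softmax(r) * softmax(r));      // <g, y>
d_input(x) = softmax(x) * (d_softmax(x) - inner);
```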
Next, I'm also confused by the `Func`-vs-`Expr` distinction. Softmax should return a `Func`, as the sample code I copied and modified above does, but wrapping it in a `Call` to be able to hook gradients onto it expects an `Expr` instead. So, I'm not sure whether any of the above is even applicable.
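For what it's worth, my current understanding of that distinction is that applying a `Func` to arguments already yields an `Expr` (a `Call` node under the hood), so the two might be closer than they look:

```cpp
// Applying a Func to Vars produces an Expr (internally a Call node),
// so a Func-returning softmax can still appear inside other Exprs.
Func f{"f"};
Var x{"x"};
f(x) = x * 2;       // the right-hand side of a definition is an Expr
Expr e = f(x) + 1;  // f(x) is itself an Expr that refers to f
```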