Open PetroZarytskyi opened 3 days ago
This is a big change so having different opinions on the PR would be great. We discussed the idea with @vgvassilev. @vaithak I'd love to know your thoughts.
This looks really good 👍🏼 Thanks, @PetroZarytskyi, for improving this.
Currently, on master, we have two reverse diff modes: `DiffMode::reverse` for gradients and `DiffMode::experimental_pullback` for pullbacks. In this PR, they are essentially merged into `DiffMode::reverse`. This has been achieved by placing a pullback in the gradient overload instead of a gradient function. Let's consider an example:
On master:
In this PR:
Note: To make this system work with error estimation, I had to enable overloads there. To do that, I had to change the type of the `_final_error` parameters from `double&` to `double*`.

Advantages:
1) On master, we have 11 DiffModes, many of which use the same visitors. Having a unified reverse DiffMode makes the system easier to understand.
2) In RMV, `Derive` and `DerivePullback` do almost the same job. This PR removes `DerivePullback` completely.
3) With this PR, clad does not use overloads for the reverse mode anymore: just one gradient function and one pullback function. This is a great step towards supporting C, which does not have overloads.
4) Differentiating recursive functions used to generate both the gradient and the pullback. Now only the pullback is generated.

Disadvantages:
1) Now gradient forward declaration is only supported with `void*` adjoint parameter types. E.g., for a function `double f(double a, double b)`, it doesn't make sense anymore to forward declare `void f_grad(double a, double b, double *_d_a, double *_d_b)`. The options `void f_grad(double a, double b, void *_d_a, void *_d_b)` and `void f_pullback(double a, double b, double _d_y, void *_d_a, void *_d_b)` still work. However, forward declarations don't seem to be that widely used. For example, when we changed all array_ref adjoint types to pointers in the gradient signature, this didn't break a single ROOT test. The main way to execute derivatives (with `CladFunction`) works as before.
2) Now all differentiated functions have the pullback `_d_y` parameter. This may make it harder to understand the derivative code. Moreover, every time the function has a parameter named `y`, the pullback parameter will be renamed to `_d_y0` to avoid name collisions. This could make the code even more confusing. However, we can fix the last problem by giving the pullback parameter a different name.
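For readers less familiar with the gradient/pullback relationship this PR relies on, here is a minimal hand-written sketch. It is not clad's generated output: the body of `f`, the derivative bodies, and the `demo` helper are assumptions made up for illustration, and the adjoints use typed `double*` parameters rather than the `void*` form discussed for forward declarations. Only the names `f_grad`, `f_pullback`, `_d_y`, `_d_a` and `_d_b` follow the signatures quoted above. The idea is that a gradient is a pullback seeded with `_d_y = 1`, which is why a single pullback can serve both roles.

```cpp
#include <cassert>

// Hypothetical primal function, chosen only for illustration.
double f(double a, double b) { return a * b + a; }

// Hand-written pullback: accumulates _d_y * df/da and _d_y * df/db
// into the adjoints. df/da = b + 1, df/db = a.
void f_pullback(double a, double b, double _d_y, double* _d_a, double* _d_b) {
  *_d_a += _d_y * (b + 1.0);
  *_d_b += _d_y * a;
}

// The gradient is the pullback with the output adjoint seeded to 1.
void f_grad(double a, double b, double* _d_a, double* _d_b) {
  f_pullback(a, b, /*_d_y=*/1.0, _d_a, _d_b);
}

// Small usage check: both entry points agree when _d_y == 1.
void demo() {
  double ga = 0.0, gb = 0.0;
  f_grad(3.0, 4.0, &ga, &gb);

  double pa = 0.0, pb = 0.0;
  f_pullback(3.0, 4.0, 1.0, &pa, &pb);

  assert(ga == pa);
  assert(gb == pb);
}
```

With a different seed, e.g. `_d_y = 2.0`, the same pullback yields the adjoints scaled by 2. That is what makes it reusable when `f` is called inside another differentiated function.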