jessebett opened 5 years ago
Hmmm, I don't think we currently have that interface. I would be very interested to know what you think it should look like, though, preferably with reference to the existing API.
So I think I worked out the API this should have.
It's basically `frule` (cf. #74), except that instead of giving perturbations (input sensitivities) as a single differential, you pass them in as a series of tensor coefficients (to use Griewank's term) of length `N`, and the Taylor rule returns the output and a series of tensor coefficients, also of length `N`, which is the perturbation pushed forward (output sensitivities).
We don't take `N` directly; we get it implicitly from whatever we are pushing forward.
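A minimal sketch of that API shape, assuming scaled Taylor coefficients (x(t) = x0 + x1·t + x2·t² + …) rather than tensor coefficients. `taylor_exp` is a hypothetical name, not part of ChainRulesCore. Note that `N` is never passed; it is read off the length of the input tuple, and the output series has the same length.

```julia
# Hypothetical Taylor rule for exp (illustration only, not a ChainRulesCore API).
# Input: primal x0 and a tuple (x1, ..., xN) of scaled Taylor coefficients of the
# input curve x(t) = x0 + x1*t + ... + xN*t^N. Output: exp(x0) and the N scaled
# Taylor coefficients of exp(x(t)). N is taken from the tuple's length.
function taylor_exp(x0::Real, xs::NTuple{N,Real}) where {N}
    x = (float(x0), float.(xs)...)   # x[j+1] holds x_j
    y = zeros(N + 1)                 # y[k+1] will hold y_k
    y[1] = exp(x[1])
    for k in 1:N
        # Matching coefficients of t^(k-1) in y' = x' * y gives
        # k * y_k = sum_{j=1}^{k} j * x_j * y_{k-j}.
        acc = 0.0
        for j in 1:k
            acc += j * x[j+1] * y[k-j+1]
        end
        y[k+1] = acc / k
    end
    return y[1], Tuple(y[2:end])
end

# x(t) = 0 + 1*t, so exp(x(t)) = exp(t), whose coefficients are 1, 1/2, 1/6.
y0, ys = taylor_exp(0.0, (1.0, 0.0, 0.0))
```

Each new coefficient reuses all previously computed `y_{k-j}`, so the whole series costs O(N²) operations instead of re-deriving each order from scratch.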
@oxinabox this is exactly what we've done with `jax.jet`. Though I'm not sure whether our choice to pass around tensor coefficients instead of scaled Taylor coefficients (again, Griewank's distinction) was actually a good choice. Something else to consider.
Originally posted by @shashi in https://github.com/JuliaDiff/ChainRulesCore.jl/issues/74#issuecomment-575364105
It turns out we need to change #88 to be of the form:

```julia
res = frule(f, x..., partials...)
if res !== nothing  # frule returns nothing when no rule is defined
    fx, pushforward = res
    partials = pushforward(Zero(), partials...)
end
```
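A toy sketch of why the split matters, using a hypothetical `frule_sin` that returns the primal and a pushforward closure (with `nothing` standing in for `Zero()` so it runs without ChainRulesCore). The closure captures `cos(x)` once, so it can be called repeatedly without re-running the primal; a counter makes that visible.

```julia
const PRIMAL_CALLS = Ref(0)   # counts how often the primal is evaluated

# Hypothetical rule returning (primal, pushforward); not ChainRulesCore's API.
function frule_sin(x::Real)
    PRIMAL_CALLS[] += 1
    s, c = sincos(x)                   # primal value and the cos(x) the pushforward needs
    pushforward(_Δself, Δx) = c * Δx   # closure: captures c, does not recompute sin/cos
    return s, pushforward
end

fx, pushforward = frule_sin(0.3)
dy1 = pushforward(nothing, 1.0)   # `nothing` stands in for Zero()
dy2 = pushforward(nothing, 2.0)   # second perturbation, primal not re-run
```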
We're starting to think about Taylor-mode FD, where we need to differentiate through `pushforward`. If we don't have `pushforward` as a separate function, then we'd have to differentiate a call to `frule`, which also re-runs the primal computation.
I don't think that's right.
What I think we need to do is define `frule`s (or perhaps instead call them `taylorrule`s) that give you the Taylor series of the right dimension.
If they do that by defining some `pushforward` function that they then call AD on (cf. #68), then that's fine.
Though I suspect that most of the `frule`s we need have well-known Taylor series that we could write out much more efficiently, or that we would like to get using symbolic AD.
Paraphrasing the second part of my post in https://github.com/JuliaDiff/ChainRulesCore.jl/pull/102#issuecomment-575764876
We can't generate higher-order `frule`s from lower-order `frule`s (with either fused or unfused pushforwards) via AD, because that runs into the same problems that recursive forward mode runs into in the first place: the problems that Taylor mode is meant to avoid, or that at the very least need to be carefully programmed around.
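To make the recursive forward-mode problem concrete, here is a minimal, illustration-only nested dual number (not ForwardDiff.jl's implementation). Nesting once yields a second derivative, but `cos(x)` is computed twice, once per nesting level. At depth n a nested dual carries 2^n components, while a degree-n Taylor polynomial has only n + 1 coefficients.

```julia
# Minimal forward-mode dual number, for illustration only.
struct Dual{T}
    val::T   # primal part
    eps::T   # derivative (perturbation) part
end

Base.:+(a::Dual, b::Dual) = Dual(a.val + b.val, a.eps + b.eps)
Base.:*(a::Dual, b::Dual) = Dual(a.val * b.val, a.val * b.eps + a.eps * b.val)
Base.:-(a::Dual) = Dual(-a.val, -a.eps)
Base.sin(a::Dual) = Dual(sin(a.val), cos(a.val) * a.eps)
Base.cos(a::Dual) = Dual(cos(a.val), -(sin(a.val) * a.eps))

x = 0.7
X = Dual(Dual(x, 1.0), Dual(1.0, 0.0))   # seed both nesting levels with direction 1
Y = sin(X)
# Y.val.val is sin(x) and Y.eps.eps is -sin(x) (the second derivative),
# but cos(x) appears twice: in Y.val.eps and again in Y.eps.val.
```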
Because all the functions we want to write `frule`s for call other functions themselves, any `frule` we write for them implicitly (or explicitly) invokes the chain rule.
The nth-order generalization of that is Faà di Bruno's formula, which dates from the 1800s.
The nth-order derivative of `f(g(x))` needs a bunch of different combinations of intermediate values that have already been computed when taking the (n-1)th derivative (or earlier).
So generating higher-order rules by applying AD to lower-order ones robs us of that possible efficiency.
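For reference, the general form of Faà di Bruno's formula, written with the partial exponential Bell polynomials $B_{n,k}$:

$$
\frac{d^n}{dx^n} f(g(x)) = \sum_{k=1}^{n} f^{(k)}(g(x))\, B_{n,k}\!\left(g'(x), g''(x), \dots, g^{(n-k+1)}(x)\right)
$$

For example, $B_{3,1} = g'''$, $B_{3,2} = 3\,g'\,g''$ and $B_{3,3} = (g')^3$, which gives the third-order case below.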
Here are the Faà di Bruno formulas for the derivatives of `f(g(x))`:

* 0th: `f(g(x))`
* 1st: `g'(x) f'(g(x))` - reuses `g(x)`
* 2nd: `g'(x)^2 f''(g(x)) + g''(x) f'(g(x))` - reuses `g(x)`, `g'(x)` and `f'(g(x))`
* 3rd: `f'''(g(x)) g'(x)^3 + 3 g'(x) g''(x) f''(g(x)) + g'''(x) f'(g(x))` - reuses `g(x)`, `g'(x)`, `g'(x)^2`, `g''(x)`, `f''(g(x))` and `f'(g(x))`
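A small sketch of the reuse the list describes, hard-coded to third order (`faa_di_bruno3` is a made-up name for illustration). It is checked on f = sin and g(x) = x², whose derivatives are easy to write by hand.

```julia
# Derivatives 0..3 of f(g(x)) via Faà di Bruno's formula.
# fd = (f, f', f'', f''') evaluated at g(x); gd = (g, g', g'', g''') evaluated at x.
# Shared subexpressions (e.g. g'(x)^2) are computed once and reused by higher orders.
function faa_di_bruno3(fd::NTuple{4,Float64}, gd::NTuple{4,Float64})
    f0, f1, f2, f3 = fd
    _g0, g1, g2, g3 = gd
    g1sq = g1^2                                  # reused by the 2nd and 3rd orders
    d0 = f0
    d1 = g1 * f1
    d2 = g1sq * f2 + g2 * f1
    d3 = f3 * g1sq * g1 + 3 * g1 * g2 * f2 + g3 * f1
    return (d0, d1, d2, d3)
end

x = 0.8
u = x^2                        # g(x), computed once
gd = (u, 2x, 2.0, 0.0)         # g(x) = x^2 and its derivatives
s, c = sincos(u)               # sin and cos at g(x), computed once
fd = (s, c, -s, -c)            # sin and its derivatives at g(x)
d = faa_di_bruno3(fd, gd)
```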
So naive attempts at using AD to generate the rules via recursive calls will not put us in a position to easily reuse these values. If we use symbolic AD to do it all at once we might be better off, provided ModelingToolkit can do very aggressive common subexpression elimination (CSE) on the final result.
@jessebett for Taylor mode, does one even want the nth-order adjoint? Or does one want the nth Taylor coefficient? Those stop being equal for n ≥ 2, since they differ by a factor of n!.
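The nth Taylor coefficient is the nth derivative divided by n!, so converting between them is cheap, but a rule has to be explicit about which one it returns. A quick check for `sin` (the helper names are made up for illustration):

```julia
# n-th derivative of sin at x: the derivatives cycle sin, cos, -sin, -cos.
nth_derivative_sin(x, n) = (sin(x), cos(x), -sin(x), -cos(x))[mod(n, 4) + 1]

# n-th Taylor coefficient of sin around x: the derivative scaled by 1/n!.
taylor_coefficient_sin(x, n) = nth_derivative_sin(x, n) / factorial(n)

x = 0.5
for n in 0:4
    println(n, ": derivative = ", nth_derivative_sin(x, n),
            ", coefficient = ", taylor_coefficient_sin(x, n))
end
```

For n = 0 and n = 1 the two columns agree; from n = 2 on they differ.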
I am leaning towards leaving `frule` as it is (first order only) and adding `trule` for Taylor-mode rules.
Then `frule` can fall back to `trule` on a single term, so providing the `trule` gives you the `frule`.
That way, in Taylor mode you either hit good `trule`s that give you exactly what you want, or you don't and you just do the normal thing of doing the computation on the polynomial, since for higher orders that computation will be more efficient than, e.g., an AD'd `frule` anyway.
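A minimal sketch of that `trule`/`frule` relationship, assuming scaled Taylor coefficients and using made-up names (`trule_sin`, `frule_via_trule_sin`) rather than any existing ChainRulesCore API. The first-order rule is the Taylor rule applied to a one-term series, because the first Taylor coefficient equals the first derivative.

```julia
# Hypothetical Taylor rule for sin on scaled Taylor coefficients:
# input x(t) = x0 + x1*t + ... + xN*t^N, given as x0 and the tuple (x1, ..., xN).
function trule_sin(x0::Real, xs::NTuple{N,Real}) where {N}
    x = (float(x0), float.(xs)...)   # x[j+1] holds x_j
    s = zeros(N + 1)                 # coefficients of sin(x(t))
    c = zeros(N + 1)                 # coefficients of cos(x(t)), needed by the recurrence
    s[1], c[1] = sincos(x[1])
    for k in 1:N
        # From s' = c * x' and c' = -s * x', matching coefficients of t^(k-1).
        accs = 0.0
        accc = 0.0
        for j in 1:k
            accs += j * x[j+1] * c[k-j+1]
            accc -= j * x[j+1] * s[k-j+1]
        end
        s[k+1] = accs / k
        c[k+1] = accc / k
    end
    return s[1], Tuple(s[2:end])
end

# First-order rule recovered from the Taylor rule on a single-term series.
function frule_via_trule_sin(x::Real, dx::Real)
    y, (dy,) = trule_sin(x, (dx,))
    return y, dy
end
```

With this design only the `trule` has to be written per function; the `frule` is a thin wrapper.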
Can the API for adding an adjoint rule allow explicitly specifying the rule for a higher-order adjoint?
For example, `D(sin, x) = v -> v * cos(x)`, but I also know that `D(sin, x; n=2) = (v1, v2) -> v2 * v1 * (-sin(x))`.
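One way such an API could look is to dispatch on the function and take the order as a keyword. A sketch using the `D(sin, x; n)` notation from the comment above (hypothetical, not an existing ChainRulesCore function), where each returned closure takes one direction per order:

```julia
# Hypothetical higher-order directional-derivative rules for sin.
function D(::typeof(sin), x::Real; n::Int = 1)
    s, c = sincos(x)   # computed once, shared by every order
    if n == 1
        return v -> v * c                           # sin'(x) = cos(x)
    elseif n == 2
        return (v1, v2) -> v2 * v1 * (-s)           # sin''(x) = -sin(x)
    elseif n == 3
        return (v1, v2, v3) -> v3 * v2 * v1 * (-c)  # sin'''(x) = -cos(x)
    else
        throw(ArgumentError("no rule written for n = $n"))
    end
end
```

Passing the same direction `v` in every slot gives the nth derivative times `v^n`; dividing by `factorial(n)` turns that into the nth Taylor coefficient along `v`.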