JuliaDiff / ChainRulesCore.jl

AD-backend agnostic system defining custom forward and reverse mode rules. This is the lightweight core to allow you to define rules for your functions in your packages, without depending on any particular AD system.

Explicit Rules for Higher Order "Adjoints" #67

Open jessebett opened 5 years ago

jessebett commented 5 years ago

Can the API for adding an adjoint rule allow for explicitly specifying the rule for a higher order adjoint?

e.g. `D(sin, x) = v -> v * cos(x)`, but I also know that `D(sin, x; n=2) = (v1, v2) -> v2 * v1 * (-sin(x))`
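As a rough illustration of the request, here is a minimal plain-Python sketch (not the ChainRulesCore API). The name `d_sin` and its `n` keyword are hypothetical stand-ins for `D(sin, x; n=...)`; only orders 1 and 2 are covered.

```python
import math

def d_sin(x, n=1):
    """Hypothetical helper: the n-th order derivative of sin at x, as a multilinear map."""
    if n == 1:
        # First order: v -> v * cos(x)
        return lambda v: v * math.cos(x)
    if n == 2:
        # Second order: (v1, v2) -> v2 * v1 * (-sin(x))
        return lambda v1, v2: v2 * v1 * (-math.sin(x))
    raise NotImplementedError("only n = 1 and n = 2 are sketched here")

x = 0.7
first = d_sin(x)(1.0)              # directional derivative along v = 1
second = d_sin(x, n=2)(1.0, 1.0)   # second-order term along v1 = v2 = 1
```

With unit directions, `first` is `cos(x)` and `second` is `-sin(x)`, which can be checked against finite differences.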

willtebbutt commented 5 years ago

Hmmm I don't think that we currently have that interface. I would be very interested to know what you think it should look like though, preferably with reference to the existing API.

oxinabox commented 4 years ago

So I think I worked out the API this should have. It's basically frule (cf #74), except instead of giving perturbations (input sensitivities) as a single differential, you pass them in as a series of Tensor Coefficients (to use Griewank's term) of length N, and the Taylor rule returns the output and a series of tensor coefficients, also of length N, which is the perturbation having been pushed forward (output sensitivities).

We don't take N directly; we get it implicitly from whatever we are pushing forward.
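A minimal plain-Python sketch of what such a rule could look like for `exp`. The name `taylor_rule_exp` is hypothetical, and for simplicity the sketch assumes scaled Taylor coefficients (coefficient `k` multiplies `t**k`) rather than tensor coefficients. Note that N is never passed; it is read off the length of the input series.

```python
import math

def taylor_rule_exp(x0, x_coeffs):
    """Hypothetical Taylor rule for exp.

    x_coeffs = [x_1, ..., x_N] describes the input path x(t) = x0 + sum_k x_k t^k.
    Returns exp(x0) and [y_1, ..., y_N], the coefficients of exp(x(t)).
    """
    n = len(x_coeffs)          # N is implicit in what is being pushed forward
    y = [math.exp(x0)]
    for k in range(1, n + 1):
        # From y' = x' * y:  k * y_k = sum_{j=1..k} j * x_j * y_{k-j}
        acc = sum(j * x_coeffs[j - 1] * y[k - j] for j in range(1, k + 1))
        y.append(acc / k)
    return y[0], y[1:]

y0, y_coeffs = taylor_rule_exp(0.3, [1.0, 0.0, 0.0])
```

For the straight-line input `x(t) = x0 + t`, the outputs are `exp(x0) / k!`, as expected from the Taylor series of `exp`.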

jessebett commented 4 years ago

@oxinabox this is exactly what we've done with jax.jet. Though, I'm not sure if our choice to pass around Tensor Coefficients instead of scaled Taylor Coefficients (again, Griewank's distinction) was actually a good choice. Something else to consider.
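For reference, a sketch of the two conventions, assuming "tensor coefficients" here means unscaled derivatives along the path and "scaled Taylor coefficients" means the same quantities divided by `k!`:

```latex
x(t) = \sum_{k=0}^{N} x_k\, t^k,
\qquad
x_k = \frac{1}{k!} \left. \frac{d^k x}{dt^k} \right|_{t=0},
\qquad
X_k := \left. \frac{d^k x}{dt^k} \right|_{t=0} = k!\, x_k .
```

The choice affects how rules are written. For a product `z = x y`, the scaled coefficients combine by a plain convolution, while the unscaled ones pick up binomial factors:

```latex
z_k = \sum_{j=0}^{k} x_j\, y_{k-j},
\qquad
Z_k = \sum_{j=0}^{k} \binom{k}{j} X_j\, Y_{k-j} .
```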

oxinabox commented 4 years ago

Originally posted by @shashi in https://github.com/JuliaDiff/ChainRulesCore.jl/issues/74#issuecomment-575364105

It turns out we need to change #88 to be of the form:

res = frule(f, x..., partials...)
if res !== nothing
    fx, pushforward = res
    partials = pushforward(Zero(), partials...)
end

We're starting to think about Taylor mode FD where we need to differentiate through pushforward. If we don't have pushforward as a separate function, then we'd have to differentiate a call to frule which also re-runs the primal computation.
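To make the concern concrete, here is a plain-Python sketch of an unfused rule. The signature is simplified and hypothetical (it is not the `frule` signature quoted above), and the `primal_calls` counter exists only to show that the pushforward can be reused without re-running the primal.

```python
import math

primal_calls = 0  # illustration only: counts how often the primal work runs

def frule_sin(x):
    """Hypothetical unfused rule: run the primal once, return a reusable pushforward."""
    global primal_calls
    primal_calls += 1
    s, c = math.sin(x), math.cos(x)

    def pushforward(dx):
        # Only uses the cached cos(x); never recomputes sin(x)
        return c * dx

    return s, pushforward

fx, pushforward = frule_sin(0.5)
tangents = [pushforward(dx) for dx in (1.0, 2.0, -3.0)]
```

Anything that differentiates through `pushforward` only sees the tangent computation, whereas a fused rule would drag the primal computation along on every call.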

oxinabox commented 4 years ago

I don't think that's right. What I think we need to do is define frules (or perhaps call them taylorrules instead) that give you the Taylor series of the right dimension. And if they do that by defining some pushforward function that they then call AD on (cf #68), then that's fine.

Though I suspect that most of the frules we need have well-known Taylor series that we could write out much more efficiently.

Or perhaps that we would like to get using symbolic AD.

oxinabox commented 4 years ago

Paraphrasing the second part of my post in https://github.com/JuliaDiff/ChainRulesCore.jl/pull/102#issuecomment-575764876

We can't generate higher order frules from lower order frules (either with fused or unfused pushforwards) via AD, because it runs into the same problems that recursive forward mode runs into in the first place: the problems that Taylor mode wants to avoid, or that at the very least need to be carefully programmed around.

Because all functions we want to write frules for call other functions themselves, any frule we write for them implicitly (or explicitly) invokes the chain rule. The nth order generalization of that is Faà di Bruno's formula, which is from the 1800s.
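For reference, one standard way to state the formula uses the partial Bell polynomials `B_{n,k}`:

```latex
\frac{d^n}{dx^n} f(g(x))
  = \sum_{k=1}^{n} f^{(k)}(g(x)) \,
    B_{n,k}\!\left(g'(x),\, g''(x),\, \dots,\, g^{(n-k+1)}(x)\right)
```

For example, `B_{2,1}(g', g'') = g''` and `B_{2,2}(g') = g'^2`, which recovers the familiar second-order chain rule.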

The nth order derivative of f(g(x)) needs a bunch of different combinations of intermediate values that have already been computed when taking the (n-1)th derivative (or earlier). So generating the rules by recursive AD robs us of that possible efficiency.

Here are the Faà di Bruno formulas for the derivatives of f(g(x)):

* 0th: `f(g(x))`

* 1st:  `g'(x) f'(g(x))`
  - reuses `g(x)`

* 2nd:  `g'(x)^2 f''(g(x)) + g''(x) f'(g(x))`
  - reuses `g(x)`, `g'(x)` and `f'(g(x))`

* 3rd:  `f'''(g(x)) g'(x)^3 + 3 g'(x) g''(x) f''(g(x)) + g'''(x) f'(g(x))`
  - reuses `g(x)`, `g'(x)`, `g'(x)^2`, `g''(x)`, `f''(g(x))` and `f'(g(x))`
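The formulas above can be evaluated in one pass that shares the intermediates. A minimal plain-Python sketch follows; the helper name `compose_derivatives` is hypothetical, and the test function `sin(x^2)` is chosen only for illustration.

```python
import math

def compose_derivatives(f_derivs, g_derivs):
    """Derivatives 0..3 of f(g(x)).

    f_derivs = (f, f', f'', f''') evaluated at g(x)
    g_derivs = (g, g', g'', g''') evaluated at x
    """
    f0, f1, f2, f3 = f_derivs
    _, g1, g2, g3 = g_derivs
    g1_sq = g1 * g1                      # shared by the 2nd and 3rd order terms
    d0 = f0
    d1 = g1 * f1
    d2 = g1_sq * f2 + g2 * f1
    d3 = f3 * g1_sq * g1 + 3 * g1 * g2 * f2 + g3 * f1
    return d0, d1, d2, d3

# Example: f = sin, g(x) = x^2
x = 0.4
g = x * x
g_derivs = (g, 2 * x, 2.0, 0.0)
f_derivs = (math.sin(g), math.cos(g), -math.sin(g), -math.cos(g))
d0, d1, d2, d3 = compose_derivatives(f_derivs, g_derivs)
```

Every `f^(k)(g(x))` and `g^(k)(x)` is computed exactly once, which is the reuse that a recursively AD-generated rule would struggle to recover.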

So naive attempts at using AD to generate the rules via recursive calls will fail to put us in a position to easily reuse the values. If we use symbolic AD to do it all at once we might be better off, provided ModelingToolkit can do very aggressive common subexpression elimination (CSE) on the final product.

oxinabox commented 4 years ago

@jessebett for Taylor mode, does one even want the nth order adjoint? Or does one want the nth Taylor coefficient? Given that those stop being equal for n>2.

I am leaning towards us leaving frule as it is (for first order only), and then adding trule for Taylor mode rules. Then frule can fall back to trule on a single term. So providing the trule gives you the frule.
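A plain-Python sketch of that fallback, for `sin`. Both `trule_sin` and `frule_sin` are hypothetical names, and the sketch assumes scaled Taylor coefficients; it is only meant to show how a first-order rule drops out of a Taylor rule given a single term.

```python
import math

def trule_sin(x0, x_coeffs):
    """Hypothetical Taylor rule: push scaled Taylor coefficients through sin."""
    n = len(x_coeffs)
    x = [x0] + list(x_coeffs)
    s = [math.sin(x0)]   # coefficients of sin(x(t))
    c = [math.cos(x0)]   # coefficients of cos(x(t)), needed by the recurrence
    for k in range(1, n + 1):
        # From s' = x' * c and c' = -x' * s
        s.append(sum(j * x[j] * c[k - j] for j in range(1, k + 1)) / k)
        c.append(-sum(j * x[j] * s[k - j] for j in range(1, k + 1)) / k)
    return s[0], s[1:]

def frule_sin(x0, dx):
    """First-order rule obtained by calling the Taylor rule on a single term."""
    y0, (dy,) = trule_sin(x0, [dx])
    return y0, dy

y0, dy = frule_sin(0.5, 2.0)
```

Here `dy` comes out as `2 * cos(0.5)`, the usual first-order pushforward, without a separate frule having been written.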

That way in Taylor mode you either hit good trules that give you exactly what you want, or you don't and you just do the normal thing of doing the computation on the polynomial. For higher orders that computation will be more efficient than e.g. an AD'd frule anyway.
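"Doing the computation on the polynomial" can be sketched in plain Python with a truncated-series type that overloads `+` and `*`. The `Taylor` class and the function `h` are hypothetical, and the sketch assumes scaled Taylor coefficients and operands of equal length.

```python
class Taylor:
    """Truncated Taylor polynomial; coeffs[k] is the scaled coefficient of t**k."""

    def __init__(self, coeffs):
        self.coeffs = list(coeffs)

    def __add__(self, other):
        return Taylor(a + b for a, b in zip(self.coeffs, other.coeffs))

    def __mul__(self, other):
        # Cauchy product, truncated at the order of the inputs
        n = len(self.coeffs)
        return Taylor(
            sum(self.coeffs[j] * other.coeffs[k - j] for j in range(k + 1))
            for k in range(n)
        )

def h(x):
    # An ordinary function with no dedicated rule
    return x * x * x + x

x = Taylor([2.0, 1.0, 0.0, 0.0])   # x(t) = 2 + t, truncated at order 3
y = h(x)
```

Since `h(2 + t) = 10 + 13 t + 6 t^2 + t^3`, `y.coeffs` holds those four coefficients, obtained purely by running `h` on the polynomial.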