compintell / Mooncake.jl

https://compintell.github.io/Mooncake.jl/

`Forward`-mode S2S autograd compiler to derive rules for some low-level functions #337

Open yebai opened 2 weeks ago

yebai commented 2 weeks ago

Due to the lack of a backward pass, forward-mode autograd often has considerably different implementation properties from reverse-mode autograd. Given its different performance trade-offs, I wonder whether a forward-mode transformation could be friendlier to autograd compilers than reverse-mode (like Mooncake/Zygote), or at least compensate for some extreme cases where reverse-mode autograd struggles.

For example, the sum_1000 example is a vector-input, scalar-output function, which is a perfect example for forward-mode autograd but likely hard (at least requiring significantly more compiler optimization effort) for a reverse-mode compiler to handle well. This advantage goes further if we have chunk-mode forward-mode autograd.

I am talking about the source-to-source approach for both forward- and reverse-mode autograd implementations.

willtebbutt commented 1 week ago

> For example, the sum_1000 example is a vector-input, scalar-output function, which is a perfect example for forward-mode autograd but likely hard (at least requiring significantly more compiler optimization effort) for a reverse-mode compiler to handle well. This advantage goes further if we have chunk-mode forward-mode autograd.

I don't believe that this is correct. The fact that this is a many-input single-output function means that it is precisely the wrong kind of function to target with forwards-mode AD, no?
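
For concreteness, here's a rough stand-in for sum_1000 (I'm assuming it's essentially a 1000-element reduction; the actual benchmark may differ), annotated with the cost asymmetry:

```julia
# Hypothetical stand-in for the sum_1000 benchmark: 1000 inputs, one output.
function sum_1000(x::Vector{Float64})
    s = 0.0
    for v in x
        s += v
    end
    return s
end

# Forward mode propagates one tangent (or one chunk of tangents) per pass, so the
# full gradient of a 1000-input function takes ~1000 (or 1000 / chunk_size) passes.
# Reverse mode yields the whole gradient in a single forward + backward sweep.
```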

> Due to the lack of a backward pass, forward-mode autograd often has considerably different implementation properties from reverse-mode autograd. Given its different performance trade-offs, I wonder whether a forward-mode transformation could be friendlier to autograd compilers than reverse-mode (like Mooncake/Zygote), or at least compensate for some extreme cases where reverse-mode autograd struggles.

I do agree that it's considerably easier to produce a high-quality source-to-source forwards-mode AD than it is a high-quality reverse-mode AD. I think it's something we should consider doing at some point, but I don't think it's something to do right now.

yebai commented 1 week ago

> I don't believe that this is correct. The fact that this is a many-input single-output function means that it is precisely the wrong kind of function to target with forwards-mode AD, no?

You are correct here. I was thinking that chunk-mode forward-mode AD could handle a large number of low-dimensional vector-input, scalar/vector-output functions efficiently. This is still helpful, but won't address the general problem.
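
To be clear about what I mean by chunk mode, here is a sketch using ForwardDiff.jl's chunking (the function below is just an arbitrary small example):

```julia
using ForwardDiff

# A small, arbitrary 4-input scalar function.
f(x) = sum(abs2, x) + prod(x)
x = randn(4)

# With a chunk size equal to the input dimension, a single forward pass carries
# all four tangents, so one pass gives the whole gradient.
cfg = ForwardDiff.GradientConfig(f, x, ForwardDiff.Chunk{4}())
ForwardDiff.gradient(f, x, cfg)
```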

> I think it's something we should consider doing at some point, but I don't think it's something to do right now.

I wrote the issue to start some discussions.

EDIT: another interesting question is whether a forward-mode autograd compiler would allow us to handle more Julia language features, for example the try-catch-end block discovered recently in https://github.com/compintell/Mooncake.jl/issues/326#issue-2624603524

willtebbutt commented 1 week ago

> EDIT: another interesting question is whether a forward-mode autograd compiler would allow us to handle more Julia language features, for example the try-catch-end block discovered recently in https://github.com/compintell/Mooncake.jl/issues/326#issue-2624603524

It's possible that it would, but I think there's probably more that can be done on the reverse-mode side of things to extend our current functionality to support try-catch-end blocks which contain Upsilon / PhiC nodes but don't wind up throwing, as is the case in the linked example and in another mentioned in https://github.com/compintell/Mooncake.jl/issues/31#issuecomment-2297801898 .
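
To make the pattern concrete, here is a minimal example of the kind of code in question (my own illustration, not the code from the linked issues):

```julia
# A try-catch-end block that never actually throws for the inputs we care about.
# The assignments that cross the try/catch boundary are what typically lower to
# UpsilonNode / PhiCNode in the IR.
function safe_log1p(x::Float64)
    y = 0.0
    try
        y = log(1 + x)   # assigned inside the try block
    catch
        y = -Inf         # fallback, only reached if log throws
    end
    return y
end
```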

> I wrote the issue to start some discussions.

Fair enough. Certainly, I think it's true that having forwards-mode which composes with reverse-mode would be nice -- I would quite like to be able to compute Hessian-vector products (and, by extension, Hessians) by doing forwards-mode over reverse-mode.
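
Roughly the shape I have in mind (a sketch only; `gradient` below stands in for whatever reverse-mode gradient function we end up exposing):

```julia
using ForwardDiff

# Hessian-vector product via forwards-over-reverse:
#   hvp(f, x, v) = d/dt ∇f(x + t*v) at t = 0.
# `gradient` is a placeholder for a reverse-mode gradient that forward-mode can
# differentiate through.
function hvp(gradient, f, x::Vector{Float64}, v::Vector{Float64})
    ForwardDiff.derivative(t -> gradient(f, x .+ t .* v), 0.0)
end
```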

In terms of what would need to be done:

We would need rules. This is a bit tedious, but we'd get to import roughly the same amount of stuff from ChainRules as we do for reverse mode, so I don't anticipate this being too much of a pain.

yebai commented 1 week ago

> We would need rules. This is a bit tedious, but we'd get to import roughly the same amount of stuff from ChainRules as we do for reverse mode, so I don't anticipate this being too much of a pain.

Now that we can freely compose forward and reverse modes, e.g. forward over reverse or reverse over forward, is it possible to use the reverse-mode rules for forward-mode autograd here?

willtebbutt commented 1 week ago

In general, no: you would have to run the rule N times per forward-mode call, where N is the dimension of the output, so you would have terrible performance. In scalar cases you might be able to get away with it, but to be honest those aren't the cases that are hard to write rules for anyway.
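
To spell out the cost, recovering a single forward-mode call (a JVP) from a reverse-mode rule looks roughly like this (`vjp` is a hypothetical reverse-mode primitive computing uᵀJ):

```julia
using LinearAlgebra

# One reverse pass per output dimension, just to assemble a single JVP.
function jvp_from_vjp(vjp, f, x::Vector{Float64}, v::Vector{Float64})
    y = f(x)
    Jv = zeros(length(y))
    for i in eachindex(y)
        eᵢ = zeros(length(y)); eᵢ[i] = 1.0   # basis vector for output component i
        Jv[i] = dot(vjp(f, x, eᵢ), v)        # (Jv)ᵢ = eᵢᵀ J v
    end
    return Jv
end
```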

yebai commented 1 week ago

This paper might be of interest.

Decomposing reverse-mode automatic differentiation
Roy Frostig, Matthew J. Johnson, Dougal Maclaurin, Adam Paszke, Alexey Radul

We decompose reverse-mode automatic differentiation into (forward-mode) linearization followed by transposition. Doing so isolates the essential difference between forward- and reverse-mode AD, and simplifies their joint implementation. In particular, once forward-mode AD rules are defined for every primitive operation in a source language, only linear primitives require an additional transposition rule in order to arrive at a complete reverse-mode AD implementation. This is how reverse-mode AD is written in JAX and Dex.

https://arxiv.org/abs/2105.09469
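
For concreteness, a toy rendering of the two-step recipe on a single primitive (my own sketch, not code from the paper):

```julia
# Step 1, linearization (the forward-mode rule) for y = a * x, with `a` constant:
# return the primal together with the linear tangent map dx -> a*dx.
linearize_mul(a, x) = (a * x, dx -> a * dx)

# Step 2, transposition of that linear tangent map: the adjoint of dx -> a*dx is
# dy -> a*dy, which is exactly the reverse-mode pullback for this primitive.
transpose_mul(a) = dy -> a * dy
```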

willtebbutt commented 1 week ago

I'm familiar with this paper. It's a great way to frame AD, and it very nicely explains what's going on. Mooncake's docs essentially frame things in the same way; in fact, I'm pretty sure that we reference their follow-up paper. We just don't break the 1-1 mapping between linearisation and transposition (read: computing the Frechet derivative and finding its adjoint operator).

I've always been a bit sceptical about the claim that you need to implement many fewer "transpose" rules than you do reverse-mode rules, because there are surprisingly many linear operators in a language. I'm also reasonably sure that you wouldn't decompose many of the more monolithic functions (e.g. the cholesky factorisation) into a sequence of simpler linear transformations at the linearisation step, but would instead wind up with a single "linearised cholesky" operator.
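
To make the cholesky point concrete, here's a hand-derived sketch of what that linearised operator looks like (a standard matrix-calculus identity, not taken from any particular implementation):

```julia
using LinearAlgebra

# "Linearised cholesky": for A = L*L', the tangent of L is dL = L * Φ(L⁻¹ dA L⁻ᵀ),
# where Φ keeps the lower triangle and halves the diagonal. The whole thing remains
# one monolithic linear map in dA rather than a composition of simpler primitives.
function cholesky_pushforward(A::Symmetric, dA::Symmetric)
    L = cholesky(A).L
    X = L \ Matrix(dA) / L'                     # L⁻¹ dA L⁻ᵀ
    Φ = LowerTriangular(X) - Diagonal(X) / 2    # lower triangle, halved diagonal
    return L, L * Φ                             # primal L and its tangent dL
end
```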

That being said, I've also not dug into it in any great depth, so it might be worth me revisiting this.