diku-dk / futhark

A data-parallel functional programming language
http://futhark-lang.org
ISC License

Supporting algorithmic differentiation (AD) in the Futhark language and compiler #1249

Closed: athas closed this issue 2 years ago

athas commented 3 years ago

We (meaning Cosmin, Robert, and myself) have begun working seriously on AD in Futhark. Currently the work is being done in the branches troels-ad and cosmin-ad. Our basic approach is conventional (personally I am going by the book Evaluating Derivatives, and its nomenclature will generally be used in our implementation). Apart from straightforward implementations of forward- and reverse-mode AD done on already optimised core language code, we investigate two novel areas:

I suspect we will not support AD for every possible program, because a fully general tape is difficult to implement in our context (for the same reasons we don't support extremely irregular GPU code). But we'll see; it's not an insurmountable challenge. I just suspect the cost/benefit ratio becomes lopsided after a certain point. We should, however, be able to handle every program that has only regular nested parallelism.

Source language representation

The current prototypes just apply AD to all entry points, which is clearly not useful. We need to figure out how we expose these facilities to the user. There is a subtle constraint in our design freedom, which is that whatever we come up with here must also be translatable to the essentially first-order core language (with a few ad-hoc second-order constructs). In particular, my draft does not directly expose a type of linear maps that support transposition; instead they are merely black-box functions.

So here's a design: a function jvp for the Jacobian-vector product and a function vjp for the vector-Jacobian product. The former is implemented internally with forward-mode AD, and the latter with reverse mode.

val jvp : (f: a -> b) -> (x: a) -> (b, a -> b)
val vjp : (f: a -> b) -> (x: a) -> (b, b -> a)

Note that both return a pair of the primal part (b) and a function that is essentially a linear map (or its transpose) at x. An important (but technically not critical) feature is that we should be able to apply the linear map multiple times without having to recompute the application of f to x. This is semantically trivial but tricky in a language without first-class functions, which brings us to...
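
To make the curried vjp design concrete, here is a minimal Python sketch (not Futhark, and not how the Futhark compiler works) of a dynamic, tape-based reverse mode. Calling vjp runs f once and records the computation graph. It then returns the primal together with a pullback closure that can be applied to many cotangents without re-running f. The names Var, cos, sin and pullback are hypothetical, chosen only for this illustration.

```python
import math

class Var:
    """Scalar node in a reverse-mode computation graph (the 'tape')."""
    def __init__(self, value, parents=()):
        self.value = value
        self.parents = parents  # tuples of (parent Var, local partial derivative)

    def __mul__(self, other):
        # d(a*b)/da = b and d(a*b)/db = a
        return Var(self.value * other.value,
                   ((self, other.value), (other, self.value)))

def cos(v):
    return Var(math.cos(v.value), ((v, -math.sin(v.value)),))

def sin(v):
    return Var(math.sin(v.value), ((v, math.cos(v.value)),))

def _topological_order(root):
    """Post-order DFS: every node appears after all of its parents."""
    order, seen = [], set()
    def visit(v):
        if id(v) not in seen:
            seen.add(id(v))
            for parent, _ in v.parents:
                visit(parent)
            order.append(v)
    visit(root)
    return order

def vjp(f, x):
    """Curried reverse mode: returns the primal and a reusable pullback."""
    x_var = Var(x)
    y = f(x_var)                   # f is applied to x exactly once, here
    order = _topological_order(y)  # recorded once, shared by every pullback call

    def pullback(y_bar):
        adjoint = {id(v): 0.0 for v in order}
        adjoint[id(y)] = y_bar
        for v in reversed(order):  # consumers are processed before producers
            for parent, partial in v.parents:
                adjoint[id(parent)] += adjoint[id(v)] * partial
        return adjoint.get(id(x_var), 0.0)

    return y.value, pullback

y, f_bar = vjp(lambda x: cos(x) * sin(x), 10.0)
print(f_bar(1.0), f_bar(2.0))  # two cotangents, one evaluation of f
```

In a language with first-class closures, the recorded graph simply lives in the closure's environment. That environment is exactly the thing that becomes awkward in a first-order core language.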

Core language representation

The core language supports some second-order operations, but it does not permit curried functions like jvp/vjp. This means that jvp (which I'll use as the sole example) needs to be split into parts, probably at least partially during defunctionalisation (which happens long before AD). Source-level code of the form

let (y, f') = jvp f x
let g = f' t

must be turned into something like

let (y, f_closure) = jvp f x
let g = f_lifted f_closure t

where f_closure is of some non-function type. But it's not clear what that type should be, and maybe it cannot be determined at this early stage, since it depends crucially on how AD decides to transform f. Maybe this really cannot be done without extending the core language with some placeholder linear map type that will be replaced by the AD pass. Or maybe I'm just too tired to design a solution tonight.
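
A toy Python illustration of this translation, using the hypothetical names jvp_square, SquareClosure and square_lifted. For f(x) = x², the linear map returned by jvp is t ↦ 2xt. Defunctionalisation replaces the closure with a record of its free values plus a first-order function that takes that record as an argument. Here the record's single field was picked by hand. That choice is precisely the information which, as discussed above, may not be known until AD has transformed f.

```python
from dataclasses import dataclass

# Higher-order form: jvp returns the primal and a closure (the linear map).
def jvp_square(x: float):
    return x * x, (lambda t: 2.0 * x * t)

# Defunctionalised form: the closure becomes ordinary data ...
@dataclass(frozen=True)
class SquareClosure:
    # The free value the linear map needs. Chosen by hand here; in a
    # compiler it would depend on how AD decides to transform f.
    two_x: float

def jvp_square_fo(x: float):
    return x * x, SquareClosure(two_x=2.0 * x)

# ... and applying the closure becomes a call to a first-order "lifted" function.
def square_lifted(clo: SquareClosure, t: float) -> float:
    return clo.two_x * t

y, f_prime = jvp_square(3.0)           # let (y, f') = jvp f x
y2, f_closure = jvp_square_fo(3.0)     # let (y, f_closure) = jvp f x
print(f_prime(0.5), square_lifted(f_closure, 0.5))  # both apply the same linear map
```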

A compromise API

These source-level functions are not as elegant, but are simpler to implement, as they don't have this bothersome currying:

val jvp : (f: a -> b) -> (x: a) -> (x': a) -> (b, b)
val vjp : (f: a -> b) -> (x: a) -> (y': b) -> (b, a)

And when the primal part is not used, dead code elimination can easily remove it.
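
In forward mode, the primal comes essentially for free because it is carried alongside the tangent. Below is a minimal dual-number sketch in Python of the compromise jvp. It uses the hypothetical names Dual, cos, sin and jvp, and supports only multiplication, cos and sin.

```python
import math
from dataclasses import dataclass

@dataclass(frozen=True)
class Dual:
    primal: float
    tangent: float

    def __mul__(self, other):
        # Product rule on the tangent component.
        return Dual(self.primal * other.primal,
                    self.tangent * other.primal + self.primal * other.tangent)

def cos(d):
    return Dual(math.cos(d.primal), -math.sin(d.primal) * d.tangent)

def sin(d):
    return Dual(math.sin(d.primal), math.cos(d.primal) * d.tangent)

def jvp(f, x, x_dot):
    """Compromise API: a single forward pass yields (primal, tangent)."""
    out = f(Dual(x, x_dot))
    return out.primal, out.tangent

y, y_dot = jvp(lambda x: cos(x) * sin(x), 10.0, 1.0)
```

A caller that only needs the tangent can ignore the first component. This is the Python analogue of dead code elimination removing an unused primal.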

athas commented 3 years ago

After discussion with Cosmin and Martin, I've decided to implement the simplest possible interface, which Barak Pearlmutter calls his ideal API. It doesn't produce the primal part at all:

val jvp : (f: a -> b) -> (x: a) -> (x': a) -> b
val vjp : (f: a -> b) -> (x: a) -> (y': b) -> a

(The last function arrow is really a linear map here.)
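
In standard linear-algebra terms, both functions apply the Jacobian or its transpose, which is why the final arrow in each is a linear map. This assumes f maps between real vector spaces and writes J_f(x) for its Jacobian at x:

```latex
\begin{aligned}
\operatorname{jvp}\, f\, x\, x' &= J_f(x)\, x' \\
\operatorname{vjp}\, f\, x\, y' &= J_f(x)^{\top}\, y' \\
\langle y',\ \operatorname{jvp}\, f\, x\, x' \rangle &= \langle \operatorname{vjp}\, f\, x\, y',\ x' \rangle
\end{aligned}
```

The last line is the defining property of the transpose. It is also a convenient test for checking that a jvp/vjp pair is mutually consistent.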

athas commented 3 years ago

This is now implemented. It works great. Test program:

let f x = f64.cos x * f64.sin x

entry fwd x = jvp f x 1

entry rev x = vjp f x 1
$ futhark c test.fut
$ echo 10 | ./test -e fwd
0.408082f64
$ echo 10 | ./test -e rev
0.408082f64
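
The printed value can be checked by hand. Since f(x) = cos x · sin x = ½ sin 2x, it follows that f'(x) = cos 2x and f'(10) = cos 20 ≈ 0.408082. Because f maps a scalar to a scalar, its Jacobian is 1×1, so seeding jvp and vjp with 1 must give the same number. Here is a small Python cross-check that uses a central finite difference purely as an independent reference, not as AD:

```python
import math

def f(x):
    return math.cos(x) * math.sin(x)

def central_difference(g, x, h=1e-6):
    """Numerical derivative, used only as an independent cross-check."""
    return (g(x + h) - g(x - h)) / (2 * h)

x = 10.0
analytic = math.cos(2 * x)  # d/dx (cos x * sin x) = cos^2 x - sin^2 x = cos 2x
numeric = central_difference(f, x)
print(f"{analytic:.6f} {numeric:.6f}")  # both print 0.408082
```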
athas commented 3 years ago

One major limitation is that jvp and vjp do not work in the interpreter, and supporting them there would be a major project. I can't justify spending that amount of time, so they will likely be compiler-only for the foreseeable future. It might be fun to do a dynamic/JIT-ish AD implementation for the interpreter, but realistically the interpreter is so slow that I don't know if it's even worth being smart.

athas commented 3 years ago

The forward-mode AD implementation is already powerful enough to run this raymarcher, which I previously implemented with dual numbers in library code. Performance remains the same, but it's much less code with jvp.

flip111 commented 3 years ago

For those of us who aren't deep into these things, could you leave a quick comment on why AD is such a cool feature?

athas commented 3 years ago

I'm actually not an expert; I just think it's cool tech and a fun challenge to implement. It's useful whenever you want to differentiate a function that would be impossibly awkward to differentiate by hand. Two interesting applications:

  1. The obvious one is function optimisation, where you have a cost function and want to adjust some of its inputs to minimise its result: essentially a fancy argmin. You can use gradient descent to do this. Most modern machine learning is built on backprop (reverse-mode AD), but I'm personally not all that interested in machine learning. Lots of other interesting uses of cost functions exist.

  2. The second: consider a ray marcher that makes use of signed distance functions. In such a program, we represent objects by functions that compute the distance between a point in space and the nearest point on the surface of the object. This lets you represent cool fractal-like surfaces that would be impossible to represent with conventional discrete geometric primitives. The downside is that when an object is defined by a signed distance function, computing surface normals (which you need to handle, e.g., reflection) becomes difficult. However, if you have a distance function, then the normal at a point is just the gradient of the distance function with respect to that position, normalised. It's very nice that you can write any kind of crazy distance function and get the surface normals for free.
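
Here is item 1 in miniature: a Python sketch of gradient descent in which the derivative comes from forward-mode AD (dual numbers) rather than being worked out by hand. The cost function, learning rate and step count are illustrative assumptions. For functions with many inputs, reverse mode (vjp) would typically be preferred because it gives the whole gradient in one pass.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Dual:
    """Forward-mode dual number: a value plus its derivative w.r.t. the input."""
    primal: float
    tangent: float

    @staticmethod
    def lift(v):
        # Constants have zero tangent.
        return v if isinstance(v, Dual) else Dual(float(v), 0.0)

    def __add__(self, other):
        o = Dual.lift(other)
        return Dual(self.primal + o.primal, self.tangent + o.tangent)

    def __sub__(self, other):
        o = Dual.lift(other)
        return Dual(self.primal - o.primal, self.tangent - o.tangent)

    def __mul__(self, other):
        o = Dual.lift(other)
        return Dual(self.primal * o.primal,
                    self.tangent * o.primal + self.primal * o.tangent)

def derivative(f, x):
    # One forward pass with tangent seed 1.0, i.e. a scalar jvp.
    return f(Dual(x, 1.0)).tangent

def cost(x):
    # Hypothetical cost function with its minimum at x = 3.
    return (x - 3.0) * (x - 3.0) + 1.0

def gradient_descent(f, x0, learning_rate=0.1, steps=100):
    x = x0
    for _ in range(steps):
        x -= learning_rate * derivative(f, x)
    return x

result = gradient_descent(cost, x0=0.0)
print(result)  # converges towards 3.0
```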

athas commented 3 years ago

Scalar reverse mode now works well enough to handle signed distance functions. Writing this kind of code is really cool:

let calcNormal (obj: Object) (pos: Position) : Direction =
  geometry.normalize (vjp (\p -> geometry.sdObject p obj) pos 1)
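
For comparison, here is the same idea as a self-contained Python sketch. It uses a hypothetical sphere SDF (sd_sphere) and a tiny tape-based reverse mode (Var, vjp). It illustrates that a function from a 3D position to a scalar distance needs only one reverse pass, seeded with 1, to obtain all three partial derivatives. Forward mode would instead need one jvp per axis.

```python
import math

class Var:
    """Scalar reverse-mode node: a value plus (parent, local partial) edges."""
    def __init__(self, value, parents=()):
        self.value, self.parents = value, parents

    def __add__(self, o):
        return Var(self.value + o.value, ((self, 1.0), (o, 1.0)))

    def __mul__(self, o):
        return Var(self.value * o.value, ((self, o.value), (o, self.value)))

    def minus_const(self, c):
        return Var(self.value - c, ((self, 1.0),))

def sqrt(v):
    r = math.sqrt(v.value)
    return Var(r, ((v, 0.5 / r),))

def vjp(f, xs, y_bar):
    """Reverse mode for f: R^n -> R. One backward pass gives all n partials."""
    inputs = [Var(x) for x in xs]
    y = f(inputs)
    order, seen = [], set()
    def visit(v):  # post-order DFS: parents before the nodes that use them
        if id(v) not in seen:
            seen.add(id(v))
            for p, _ in v.parents:
                visit(p)
            order.append(v)
    visit(y)
    adjoint = {id(v): 0.0 for v in order}
    adjoint[id(y)] = y_bar
    for v in reversed(order):
        for p, d in v.parents:
            adjoint[id(p)] += adjoint[id(v)] * d
    return [adjoint.get(id(x), 0.0) for x in inputs]

def sd_sphere(p, radius=1.0):
    # Hypothetical SDF: signed distance from p to a sphere centred at the origin.
    x, y, z = p
    return sqrt(x * x + y * y + z * z).minus_const(radius)

def normalize(v):
    n = math.sqrt(sum(c * c for c in v))
    return [c / n for c in v]

def calc_normal(pos):
    # Mirrors calcNormal above: normalise the vjp of the SDF with seed 1.
    return normalize(vjp(sd_sphere, pos, 1.0))

print(calc_normal([0.0, 3.0, 4.0]))  # approximately [0.0, 0.6, 0.8]
```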
maedoc commented 3 years ago

This looks like fun; will there be a way to define the vjp for primitives? For instance, the vjp of the FFT is the iFFT of the rescaled conjugate gradient (cf. Autograd's definition), so there'd be no need to do AD on the FFT algorithm itself.

athas commented 3 years ago

We know it will be necessary to support that, but we haven't put in any design or implementation effort yet. I expect it'll be done via attributes, where you can annotate functions with the name of a hand-written derivative.
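
Here is a Python sketch of the idea. A decorator-based registry stands in for the attribute mechanism; all names are hypothetical and none of this is a Futhark feature. It registers a hand-written VJP for a naive DFT. Since the DFT matrix F is symmetric, its (non-conjugating) transpose is F itself, so the pullback is just another DFT.

```python
import cmath

def dft(x):
    """Naive O(n^2) discrete Fourier transform (unnormalised, forward sign)."""
    n = len(x)
    return [sum(x[t] * cmath.exp(-2j * cmath.pi * k * t / n) for t in range(n))
            for k in range(n)]

# Hypothetical registry standing in for an attribute naming a hand-written VJP.
CUSTOM_VJP = {}

def custom_vjp(primal_fn):
    def register(vjp_fn):
        CUSTOM_VJP[primal_fn] = vjp_fn
        return vjp_fn
    return register

@custom_vjp(dft)
def dft_vjp(x, y_bar):
    # F[k][t] == F[t][k], so the transpose of the DFT is the DFT itself:
    # no AD through the body of dft is needed.
    return dft(y_bar)

def vjp(f, x, y_bar):
    if f in CUSTOM_VJP:
        return CUSTOM_VJP[f](x, y_bar)
    raise NotImplementedError("generic AD is out of scope for this sketch")
```

Libraries differ in how they treat complex cotangents (some conjugate them), so the exact formula depends on that convention. This sketch uses the plain transpose, which satisfies the bilinear identity sum(ȳ · F x) = sum((Fᵀ ȳ) · x).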

maedoc commented 2 years ago

annotate functions with the name of a hand-written derivative.

It would also be helpful to label a function as subject to jvp instead of vjp such that a call graph which is diff'ed with vjp still uses forward mode for functions annotated with jvp, because "Optimal Jacobian accumulation is NP-complete" (according to Google) and a user may want to inform the compiler based on experience or benchmarking. I have some functions where benchmarking shows jvp is faster, but for the overall problem I need to use vjp because of memory pressure. It would also be a neat idea akin to the current autotuning or multicore load balancing functionality to automatically choose a particular jvp vs vjp approach depending on problem size or runtime.
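
A rough illustration of the trade-off in Python, using hypothetical names. For f: Rⁿ → Rᵐ, building the full Jacobian takes n jvp calls (one column each) or m vjp calls (one row each). The sketch uses a linear f(x) = A x so that jvp and vjp can be written down directly. It shows only the pass counts, not the memory cost of reverse mode's saved intermediates, which is the other side of the trade-off mentioned above.

```python
def make_linear(A):
    """f(x) = A x, with jvp and vjp written out directly (hypothetical helpers)."""
    m, n = len(A), len(A[0])
    def jvp(x, x_dot):  # J x_dot, where J = A for a linear map
        return [sum(A[i][j] * x_dot[j] for j in range(n)) for i in range(m)]
    def vjp(x, y_bar):  # J^T y_bar
        return [sum(A[i][j] * y_bar[i] for i in range(m)) for j in range(n)]
    return m, n, jvp, vjp

def basis(k, size):
    return [1.0 if i == k else 0.0 for i in range(size)]

def jacobian_forward(m, n, jvp, x):
    cols = [jvp(x, basis(j, n)) for j in range(n)]  # n forward passes
    return [[cols[j][i] for j in range(n)] for i in range(m)], n

def jacobian_reverse(m, n, vjp, x):
    rows = [vjp(x, basis(i, m)) for i in range(m)]  # m reverse passes
    return rows, m

A = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]  # m = 2 outputs, n = 3 inputs
m, n, jvp, vjp = make_linear(A)
J_fwd, fwd_passes = jacobian_forward(m, n, jvp, [0.0] * n)
J_rev, rev_passes = jacobian_reverse(m, n, vjp, [0.0] * n)
print(fwd_passes, rev_passes)  # 3 2
```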

henglein commented 2 years ago

Naumann's NP-completeness result (minimizing the total number of additions and multiplications to compute the Jacobian of an analytic function given as a straight-line program [= term with sharing by let-expressions = computation graph]) expresses that there may be useful computations of partial sums-of-products in a computation graph that can be used multiple times; it is under the additional assumption that identities between partial derivatives of different intermediate results have already been discovered and can be exploited. It is equally applicable to forward mode and reverse mode AD. In particular, neither forward-mode AD nor reverse-mode AD can yield the minimum number of arithmetic operations since they sweep the computation graph from the inputs or outputs, respectively. Putting it positively: There's lots of opportunity for hacking (new modes of AD, including locally flipping between forward-mode and reverse-mode AD) -- and there'll always be a way of doing it with slightly fewer operations in some cases.
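
Naumann's result concerns general computation graphs. A much simpler special case already shows why neither sweep direction is optimal in general. When the program is a straight chain, its Jacobian is a product J = J_K ⋯ J_1 of local Jacobians. Forward mode accumulates it from the input side and reverse mode from the output side. The dense chain case reduces to the classic matrix-chain problem, which dynamic programming solves in polynomial time; the hardness result is about the general graph case. Here is a Python sketch under that simplification. It counts only scalar multiplications for full Jacobian accumulation, and the names are hypothetical.

```python
from functools import lru_cache

def chain_costs(dims):
    """dims = [d0, d1, ..., dK]; J_k has shape (d_k, d_{k-1}) and J = J_K ... J_1.
    Returns (forward cost, reverse cost, optimal cost) in scalar multiplications,
    counting p*q*r for a (p x q) @ (q x r) product."""
    K = len(dims) - 1
    # Forward accumulation: J_2 J_1, then J_3 (J_2 J_1), ...
    forward = sum(dims[k] * dims[k - 1] * dims[0] for k in range(2, K + 1))
    # Reverse accumulation: J_K J_{K-1}, then (J_K J_{K-1}) J_{K-2}, ...
    reverse = sum(dims[K] * dims[k] * dims[k - 1] for k in range(1, K))

    p = dims[::-1]  # product order J_K ... J_1 as a standard matrix chain

    @lru_cache(maxsize=None)
    def best(i, j):
        # Cheapest bracketing of matrices i..j (1-based) in product order.
        if i == j:
            return 0
        return min(best(i, s) + best(s + 1, j) + p[i - 1] * p[s] * p[j]
                   for s in range(i, j))

    return forward, reverse, best(1, K)

print(chain_costs([10, 10, 1, 10, 10]))  # (1200, 1200, 300)
```

With a one-dimensional bottleneck in the middle of the chain, both sweeps cost 1200 multiplications. The mixed bracketing (J_4 J_3)(J_2 J_1) costs only 300.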



tjpalmer commented 2 years ago

What's the status on this?

athas commented 2 years ago

AD in general? Under active development in the clean-ad branch. Performance on the benchmarks we have ported (mostly ADBench) is very good. A lot of the language is supported, but not all of it. The main missing part is complex sequential looping in reverse mode, which is a particularly difficult problem when you want to generate allocation-free code. Our AD benchmarks are in this repository, but the results are not. The interim results for ADBench are here. The most interesting graphs are the "jacobian ÷ objective" ones, as they show the cost of computing the full Jacobian relative to computing the objective function. (Comparing absolute runtimes diminishes the AD part and mostly measures how good the compilers are at low-level code generation and optimisation, although Futhark compares very well here as well).
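
To see why sequential loops are hard in reverse mode, consider the simplest case: a scalar loop. Reverse mode visits the iterations backwards, so the forward sweep has to keep (or recompute) every loop-carried value. The minimal Python sketch below (hypothetical loop_vjp, with an arbitrary loop body sin) keeps them all in a list. That storage grows with the iteration count, which makes it difficult to preallocate when the count is only known at run time. Checkpointing, which saves only some values and recomputes the rest, is the usual way to trade memory for recomputation.

```python
import math

def loop_vjp(x0, n, y_bar):
    """Reverse mode through the sequential loop x_{i+1} = sin(x_i), for i < n.

    The forward sweep saves every intermediate x_i because the backward
    sweep needs them in reverse order, so memory grows linearly with n.
    """
    tape = []
    x = x0
    for _ in range(n):
        tape.append(x)          # save the loop-carried value
        x = math.sin(x)
    x_bar = y_bar
    for x_i in reversed(tape):  # backward sweep over the saved values
        x_bar *= math.cos(x_i)  # chain rule: d sin(x_i) / d x_i
    return x, x_bar, len(tape)

y, x_bar, saved = loop_vjp(0.7, 5, 1.0)
```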

Here are some of the pertinent graphs:

[Graphs: interim ADBench results (four images)]

tjpalmer commented 2 years ago

Thanks much for the status! Again, you're more thorough than I expect. Looking forward to it landing on main.

zfnmxt commented 2 years ago

We've put a paper detailing AD in Futhark on arXiv.