AnyDSL / thorin

The Higher-Order Intermediate Representation
https://anydsl.github.io
GNU Lesser General Public License v3.0

Simple gradient generation #101

Closed by danielspaniol 4 years ago

danielspaniol commented 4 years ago


This pull request enables simple automatic differentiation (AD) for Thorin programs.

Overview

The most important new feature is the op_grad axiom in World. A lambda can be passed to op_grad, and a new pass, GradGen, then replaces that op_grad application with a lambda that computes the gradients.
Currently this only works for functions that take some reals and return a single real. Conditionals, loops, and function calls are not supported yet; they will follow later. Also note that without -Othorin, the grad axiom stays unchanged in the Thorin code.
The implementation is mainly based on this paper about AD for programs in SSA form.
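The general idea of reverse-mode AD on a straight-line SSA program can be sketched in a few lines of Python: run the instructions forward, then walk them backwards, apply a per-operation pullback to each result's adjoint, and accumulate the adjoints of the operands. This is a minimal illustration under stated assumptions, not GradGen's code; `Instr`, `pullback` and `grad` are hypothetical names, and only add, mul and div are covered, matching the current restriction to code without branches, loops or calls. The pullbacks play the same role as the backpropagators B÷, B⁺ and B× in the generated Thorin code.

```python
from dataclasses import dataclass


@dataclass
class Instr:
    """One SSA instruction: dst = lhs <op> rhs (hypothetical representation)."""
    dst: str
    op: str  # "add", "mul" or "div"
    lhs: str
    rhs: str


def forward(op, x, y):
    """Evaluate a single binary operation."""
    if op == "add":
        return x + y
    if op == "mul":
        return x * y
    if op == "div":
        return x / y
    raise ValueError(f"unsupported op {op!r}")


def pullback(op, x, y, g):
    """Given z = x <op> y and the adjoint g of z, return (adjoint of x, adjoint of y)."""
    if op == "add":
        return g, g
    if op == "mul":
        return y * g, x * g
    if op == "div":
        return g / y, -x * g / (y * y)
    raise ValueError(f"unsupported op {op!r}")


def grad(params, body, result):
    """Turn a straight-line SSA function R^n -> R into one returning its n partials."""
    def gradient(*args):
        env = dict(zip(params, args))
        # Forward sweep: compute every SSA value.
        for ins in body:
            env[ins.dst] = forward(ins.op, env[ins.lhs], env[ins.rhs])
        # Reverse sweep: seed the result with 1 and propagate adjoints backwards.
        adj = {name: 0.0 for name in env}
        adj[result] = 1.0
        for ins in reversed(body):
            dx, dy = pullback(ins.op, env[ins.lhs], env[ins.rhs], adj[ins.dst])
            adj[ins.lhs] += dx  # += because a value may have several uses
            adj[ins.rhs] += dy
        return tuple(adj[p] for p in params)
    return gradient


# example(a, b) = a / (a + b * b), lowered by hand to SSA.
example_body = [
    Instr("t0", "mul", "b", "b"),
    Instr("t1", "add", "a", "t0"),
    Instr("t2", "div", "a", "t1"),
]
grad_example = grad(["a", "b"], example_body, "t2")
```

Because `b * b` names `b` twice, the mul pullback adds `b * g` to the adjoint of `b` twice, giving `2 * b * g`. This summing over uses is what the `ROp_add` nodes in the generated ∇example perform.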

Example

Here is an example (taken from the paper mentioned above) that already works:

fn example(a: f64, b: f64) -> f64 { a / (a + b * b) }

fn main() -> i32 {
    // If the grad-gen works, we get da=1/9 and db=4/9
    let (da, db) = grad(example)(-10.0, 5.0);
    (db / da) as i32
}
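The values in the comment can be checked by differentiating example by hand with the quotient rule:

```latex
f(a, b) = \frac{a}{a + b^2}, \qquad
\frac{\partial f}{\partial a} = \frac{(a + b^2) - a}{(a + b^2)^2} = \frac{b^2}{(a + b^2)^2}, \qquad
\frac{\partial f}{\partial b} = \frac{-2ab}{(a + b^2)^2}

% At (a, b) = (-10, 5) we have a + b^2 = 15:
\frac{\partial f}{\partial a} = \frac{25}{225} = \frac{1}{9}, \qquad
\frac{\partial f}{\partial b} = \frac{100}{225} = \frac{4}{9}, \qquad
\frac{\partial f / \partial b}{\partial f / \partial a} = 4
```

So db / da is exactly 4 in exact arithmetic, which is the value main is expected to return.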

The generated code for the gradients of example is:

∇example_4999: cn [mem, r64, r64, cn [mem, «2∷◦; r64»]]
    _5031: [mem, r64, r64, cn [mem, «2∷◦; r64»]] = param ∇example_4999
    ∇op_5490: «2∷◦; r64» = B÷_5186 1∷r64
    return_5994: cn [mem, «2∷◦; r64»] = extract _5031, 3₄
    _5999: mem = extract _5031, 0₄
    b_5080: r64 = extract _5031, 2₄
    ∂a_5619: r64 = extract ∇op_5490, 0₂
    ∂_5516: r64 = extract ∇op_5490, 1₂
    ∇op_5541: «2∷◦; r64» = B⁺_5130 ∂_5516
    ∂a_5567: r64 = extract ∇op_5541, 0₂
    ∂_5841: r64 = extract ∇op_5541, 1₂
    _5657: «2∷◦; r64» = (∂a_5567, ∂a_5619)
    ∇op_5866: «2∷◦; r64» = B×_5690 ∂_5841
    ∂a_5658: r64 = ROp_add ‹2∷◦; 64∷nat› _5657
    ∂b_5956: r64 = ROp_add ‹2∷◦; 64∷nat› ∇op_5866
    _6025: «2∷◦; r64» = (∂a_5658, ∂b_5956)
    result_6048: [mem, «2∷◦; r64»] = (_5999, _6025)
    _6049: ⊥∷★ = return_5994 result_6048

    B÷_5186: Πr64 -> «2∷◦; r64»
        _5092: «2∷◦; r64» = ‹2∷◦; b_5080›
        a_5048: r64 = extract _5031, 1₄
        ∂f_5206: r64 = param B÷_5186
        _5357: «2∷◦; r64» = (a_5048, ∂f_5206)
        ∂₀_5362: r64 = ROp_mul ‹2∷◦; 64∷nat› _5357
        _5099: r64 = ROp_mul (0∷nat, 64∷nat) _5092
        _5100: «2∷◦; r64» = (a_5048, _5099)
        _5372: «2∷◦; r64» = (-1∷r64, ∂₀_5362)
        _5105: r64 = ROp_add (0∷nat, 64∷nat) _5100
        ∂₀_5379: r64 = ROp_mul ‹2∷◦; 64∷nat› _5372
        _5285: «2∷◦; r64» = (∂f_5206, _5105)
        _5394: «2∷◦; r64» = ‹2∷◦; _5105›
        ∂a_5308: r64 = ROp_div ‹2∷◦; 64∷nat› _5285
        ∂₁_5422: r64 = ROp_mul ‹2∷◦; 64∷nat› _5394
        _5437: «2∷◦; r64» = (∂₀_5379, ∂₁_5422)
        ∂_5459: r64 = ROp_div ‹2∷◦; 64∷nat› _5437
        _5462: «2∷◦; r64» = (∂a_5308, ∂_5459)

    B×_5690: Πr64 -> «2∷◦; r64»
        ∂f_5710: r64 = param B×_5690
        _5747: «2∷◦; r64» = (b_5080, ∂f_5710)
        ∂b_5748: r64 = ROp_mul ‹2∷◦; 64∷nat› _5747
        _5790: «2∷◦; r64» = ‹2∷◦; ∂b_5748›
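As a reading aid, the dataflow of ∇example can be transcribed by hand into Python. This is a sketch, not output of GradGen: `nabla_example`, `B_div`, `B_add` and `B_mul` are hypothetical stand-ins for ∇example_4999, B÷_5186, B⁺_5130 and B×_5690. The body of B⁺_5130 is not printed in the dump, so passing the incoming adjoint unchanged to both operands is an assumption based on the derivative of addition. The Thorin code is in continuation-passing style and hands its result (together with mem) to the return continuation; the sketch simply returns the pair.

```python
def nabla_example(a: float, b: float) -> tuple[float, float]:
    """Hand transcription of the gradient of a / (a + b * b) shown above."""
    d = a + b * b  # denominator, _5105 in B÷_5186

    def B_div(df):
        # Pullback of a / d: (df / d, -a * df / d^2), i.e. (∂a_5308, ∂_5459).
        return df / d, -a * df / (d * d)

    def B_add(df):
        # Assumed pullback of a + (b * b): the adjoint flows to both operands.
        return df, df

    def B_mul(df):
        # Pullback of b * b: each operand gets b * df, as in ∂b_5748.
        return b * df, b * df

    da_div, dd = B_div(1.0)   # ∇op_5490 = B÷ 1, seeding the output adjoint
    da_add, dt = B_add(dd)    # ∇op_5541 = B⁺ ∂_5516
    db_lhs, db_rhs = B_mul(dt)  # ∇op_5866 = B× ∂_5841
    # a is used twice (numerator and sum), b twice (b * b): add their adjoints,
    # mirroring ∂a_5658 and ∂b_5956.
    return da_add + da_div, db_lhs + db_rhs
```

Evaluating `nabla_example(-10.0, 5.0)` yields approximately (1/9, 4/9), in line with the comment in main.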

In fact, the inliner is able to reduce main to the following code (note how db / da is replaced by the constant 4):

main_7691: cn [mem, cn [mem, s32]]
    _7707: [mem, cn [mem, s32]] = param main_7691
    mem_7712: mem = extract _7707, 0₂
    _7714: [mem, s32] = (mem_7712, 4∷s32)
    _7818: ⊥∷★ = return_7800 _7714

    return_7800: cn [mem, s32]
        return_7710: cn [mem, s32] = extract _7707, 1₂
        _7801: [mem, s32] = param return_7800
        return_7802: ⊥∷★ = return_7710 _7801

Future features

This pull request contains only the most basic AD functionality. Features I plan to implement next are: