hidal00p / rad

AD in Rust

Introduce reverse mode autodiff #4

Open klepp0 opened 6 months ago

klepp0 commented 6 months ago

So far we have only worked on forward-mode autodiff. However, to implement backprop and build some simple machine learning applications, we'll need reverse-mode autodiff.

For now, I'd like to create a new module that follows this example.

The implementation requires a Variable, a TapeEntry, and a GradientTape type:

classDiagram
    Variable ..> TapeEntry
    GradientTape ..> TapeEntry

    class Variable {
         +String name
         +f32 value
         +add(Variable) Variable
         +mul(Variable) Variable
    }
    class TapeEntry {
        +Vec[&Variable] inputs
        +Vec[&Variable] outputs
        +Fn propagate
    }
    class GradientTape {
        +Vec<TapeEntry>
        +add_entry(TapeEntry)
        +clear()
     }
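
As a rough illustration, the diagram above might translate to Rust along the following lines. This is only a sketch under several assumptions that are not part of the issue: `TapeEntry` stores variable names instead of `&Variable` references (to avoid lifetime parameters), `propagate` is assumed to map output gradients to input gradients, the field name `entries` and the helper `record_mul` are hypothetical, and the generated names such as `(a * b)` are purely illustrative.

```rust
/// A named scalar, mirroring the `Variable` box in the diagram.
#[derive(Debug, Clone, PartialEq)]
pub struct Variable {
    pub name: String,
    pub value: f32,
}

impl Variable {
    pub fn new(name: &str, value: f32) -> Self {
        Variable { name: name.to_string(), value }
    }

    /// Returns a new variable holding the sum; operands are only borrowed.
    pub fn add(&self, other: &Variable) -> Variable {
        Variable::new(&format!("({} + {})", self.name, other.name), self.value + other.value)
    }

    /// Returns a new variable holding the product; operands are only borrowed.
    pub fn mul(&self, other: &Variable) -> Variable {
        Variable::new(&format!("({} * {})", self.name, other.name), self.value * other.value)
    }
}

/// One recorded operation. Names stand in for the `&Variable`
/// references of the diagram (assumption, to keep the sketch lifetime-free).
pub struct TapeEntry {
    pub inputs: Vec<String>,
    pub outputs: Vec<String>,
    /// Assumed contract: maps gradients of the outputs to gradients of the inputs.
    pub propagate: Box<dyn Fn(&[f32]) -> Vec<f32>>,
}

/// Ordered list of recorded operations.
#[derive(Default)]
pub struct GradientTape {
    pub entries: Vec<TapeEntry>,
}

impl GradientTape {
    pub fn add_entry(&mut self, entry: TapeEntry) {
        self.entries.push(entry);
    }

    pub fn clear(&mut self) {
        self.entries.clear();
    }
}

/// Hypothetical helper: multiplies two variables and records the
/// local derivatives of the product on the tape.
pub fn record_mul(tape: &mut GradientTape, a: &Variable, b: &Variable) -> Variable {
    let out = a.mul(b);
    let (av, bv) = (a.value, b.value);
    tape.add_entry(TapeEntry {
        inputs: vec![a.name.clone(), b.name.clone()],
        outputs: vec![out.name.clone()],
        // d(a*b)/da = b and d(a*b)/db = a, each scaled by the incoming gradient.
        propagate: Box::new(move |out_grads: &[f32]| {
            vec![bv * out_grads[0], av * out_grads[0]]
        }),
    });
    out
}
```

With `a = 2.0` and `b = 3.0`, `record_mul` yields a variable with value `6.0` and leaves one entry on the tape whose `propagate` closure turns an output gradient of `1.0` into the input gradients `[3.0, 2.0]`.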

Additionally, some functions need to be implemented:

[!NOTE] This is only a rough outline for now. I will refine the issue as the development proceeds.

hidal00p commented 3 months ago

Hey @klepp0! First of all, once again I would like to apologize for not getting back to you earlier. The semester turned out quite busy, and I had to put all the projects on pause. But I do have some extra time now, so I would be happy to get back to business.

I sat down to review the code this evening, and the implementation you provided is very impressive. I am happy you dug into it and built a well-functioning reverse-mode autodiff.

However, I am now wondering whether operator overloading in Rust is the right idea in general. My main doubt is that it forces a trade-off between code legibility and performance. For example, the current implementation copies and clones operands a lot, which feels somewhat wasteful. On the other hand, if you implement the std::ops traits for references (to avoid copying), the usage becomes ugly, e.g.:

let a = Variable::new(...);
let b = Variable::new(...);
println!("{:?}", &a + &b); // avoids copying and consuming a and b, but looks weird.
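
For context, implementing a `std::ops` trait for references could look roughly like the sketch below. The `Variable` here is a stripped-down stand-in with an assumed `new(name, value)` constructor (the snippet above elides the real arguments), so this is not the project's actual type.

```rust
use std::ops::Add;

/// Simplified stand-in for the project's `Variable` (assumption).
#[derive(Debug, Clone, PartialEq)]
pub struct Variable {
    pub name: String,
    pub value: f32,
}

impl Variable {
    /// Hypothetical constructor signature.
    pub fn new(name: &str, value: f32) -> Self {
        Variable { name: name.to_string(), value }
    }
}

/// By-reference addition: both operands are only borrowed,
/// so they remain usable after `&a + &b`.
impl Add<&Variable> for &Variable {
    type Output = Variable;

    fn add(self, rhs: &Variable) -> Variable {
        Variable::new(
            &format!("({} + {})", self.name, rhs.name),
            self.value + rhs.value,
        )
    }
}

/// By-value addition: consumes both operands and forwards
/// to the by-reference implementation.
impl Add for Variable {
    type Output = Variable;

    fn add(self, rhs: Variable) -> Variable {
        &self + &rhs
    }
}
```

With both impls in place, `&a + &b` leaves `a` and `b` available for later use, while `a + b` moves them. Covering the mixed cases (`&a + b`, `a + &b`) would need two further impls, which is part of the boilerplate cost of this approach.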

Maybe it would be wiser to just implement a simple set of procedures that always operate on a pair of values, e.g.:

fn add(a: &mut Value, b: &mut Value) -> Value {
    Value {
        value: a.value + b.value,
        der: a.der + b.der,
    }
}

fn mul(a: &mut Value, b: &mut Value) -> Value {
    Value {
        value: a.value * b.value,
        der: b.value * a.der + a.value * b.der,
    }
}

// etc ...

The above is, of course, an example for tangent mode. However, it can also be extended to reverse mode. It should also allow the tape to be introduced in a more organic fashion, and hopefully remove the reliance on the name parameter of Variable for insertion into the tape.
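
One way the procedural style might extend to reverse mode is sketched below. Everything here is an assumption made for illustration: the `Tape` and `Node` types, the `var` and `gradient` methods, and a `Value` that carries a tape index in place of the `der` field. The index is what replaces the name lookup. Operands are passed by value as small `Copy` handles, which differs from the `&mut Value` signatures above.

```rust
/// Handle to a node on the tape. The `index` replaces a name-based
/// lookup (assumption, not the project's actual design).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Value {
    pub value: f32,
    pub index: usize,
}

/// One recorded node: two parent indices and the local partial
/// derivatives of this node with respect to each parent.
struct Node {
    parents: [usize; 2],
    weights: [f32; 2],
}

#[derive(Default)]
pub struct Tape {
    nodes: Vec<Node>,
}

impl Tape {
    /// Registers an input variable (a leaf with zero weights).
    pub fn var(&mut self, value: f32) -> Value {
        let index = self.nodes.len();
        self.push(value, [index, index], [0.0, 0.0])
    }

    fn push(&mut self, value: f32, parents: [usize; 2], weights: [f32; 2]) -> Value {
        let index = self.nodes.len();
        self.nodes.push(Node { parents, weights });
        Value { value, index }
    }

    /// Reverse sweep: returns d(output)/d(node) for every node on the tape.
    pub fn gradient(&self, output: Value) -> Vec<f32> {
        let mut adjoints = vec![0.0; self.nodes.len()];
        adjoints[output.index] = 1.0;
        // Nodes are stored in creation order, so walking backwards
        // visits every node after all of its consumers.
        for i in (0..=output.index).rev() {
            let node = &self.nodes[i];
            let adjoint = adjoints[i];
            for k in 0..2 {
                adjoints[node.parents[k]] += node.weights[k] * adjoint;
            }
        }
        adjoints
    }
}

/// d(a + b)/da = 1, d(a + b)/db = 1
pub fn add(tape: &mut Tape, a: Value, b: Value) -> Value {
    tape.push(a.value + b.value, [a.index, b.index], [1.0, 1.0])
}

/// d(a * b)/da = b, d(a * b)/db = a
pub fn mul(tape: &mut Tape, a: Value, b: Value) -> Value {
    tape.push(a.value * b.value, [a.index, b.index], [b.value, a.value])
}
```

For f = x * y + x at x = 2 and y = 3, building `xy = mul(&mut tape, x, y)` and `f = add(&mut tape, xy, x)` gives `f.value == 8.0`, and `tape.gradient(f)` holds 4.0 at `x.index` and 2.0 at `y.index`. Passing the tape explicitly keeps the recording visible at every call site, at the cost of noisier expressions than operator overloading.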