Open klepp0 opened 6 months ago
Hey @klepp0! First of all, once again I would like to apologize for not getting back to you earlier. The semester turned out quite busy, and I had to put all the projects on pause. But I do have some extra time now, so I would be happy to get back to business.
I sat down to review the code this evening, and the implementation you provided is very impressive. I am happy you dug into it and made a well-functioning reverse mode autodiff.
However, I am now wondering whether going with operator overloading in Rust is generally the right idea. I am doubtful about this approach primarily because it forces a trade-off between code legibility and performance.
For example, the current implementation copies and clones operands a lot, which feels somewhat wasteful. On the other hand, if you implement the std::ops traits for references (to avoid copying), the usage becomes ugly, e.g.
let a = Variable::new(...);
let b = Variable::new(...);
println!("{:?}", &a + &b); // avoids copying and consuming a and b, but looks weird.
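To make the trade-off concrete, here is a minimal sketch of what implementing std::ops::Add for references could look like. The Variable struct below is a hypothetical stand-in with only a value and a derivative, not the struct from the PR, and the owned impl that forwards to the reference impl is just one possible convention.

```rust
use std::ops::Add;

// Hypothetical stand-in for the project's Variable (assumption: two f64 fields).
#[derive(Debug, Clone, PartialEq)]
struct Variable {
    value: f64,
    der: f64,
}

impl Variable {
    fn new(value: f64, der: f64) -> Self {
        Variable { value, der }
    }
}

// Add on references: both operands are only borrowed, so nothing is moved or cloned.
impl<'a, 'b> Add<&'b Variable> for &'a Variable {
    type Output = Variable;

    fn add(self, rhs: &'b Variable) -> Variable {
        Variable::new(self.value + rhs.value, self.der + rhs.der)
    }
}

// Owned variant that forwards to the reference impl, so `a + b` also works
// when consuming the operands is acceptable.
impl Add for Variable {
    type Output = Variable;

    fn add(self, rhs: Variable) -> Variable {
        &self + &rhs
    }
}

// Usage: a and b stay usable after the first addition.
fn example() -> (Variable, Variable) {
    let a = Variable::new(2.0, 1.0);
    let b = Variable::new(3.0, 0.0);
    let first = &a + &b;
    let second = &a + &b;
    (first, second)
}
```

The cost is visible at every call site: each operand needs an explicit & and longer expressions need parentheses or temporaries, which is the legibility problem mentioned above.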
Maybe it would be wiser to just implement a simple set of procedures that would always operate on a pair of values e.g.
// Shared references are enough here: neither operand is modified.
fn add(a: &Value, b: &Value) -> Value {
    Value {
        value: a.value + b.value,
        der: a.der + b.der,
    }
}
fn mul(a: &Value, b: &Value) -> Value {
    Value {
        value: a.value * b.value,
        // product rule
        der: b.value * a.der + a.value * b.der,
    }
}
// etc ...
The above is, of course, an example for the tangent (forward) mode, but it can also be extended to the reverse mode. This should also allow us to introduce the tape in a more organic fashion and hopefully remove the reliance on the name parameter of Variable for insertion into the tape.
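As a rough illustration of how the procedural style might carry over to the reverse mode, here is a sketch in which the tape is passed explicitly and variables are identified by their index on the tape rather than by a name. All names here (Tape, Node, Var, push, var, grad) are hypothetical and chosen for this example only.

```rust
/// One recorded operation: indices of its two parents and the local partial
/// derivatives of the result with respect to each parent.
struct Node {
    parents: [usize; 2],
    partials: [f64; 2],
}

/// The tape owns every node; a variable is identified by its position on it.
struct Tape {
    nodes: Vec<Node>,
}

/// Lightweight handle. It is Copy, so operands are neither cloned nor consumed.
#[derive(Clone, Copy)]
struct Var {
    index: usize,
    value: f64,
}

impl Tape {
    fn new() -> Self {
        Tape { nodes: Vec::new() }
    }

    fn push(&mut self, parents: [usize; 2], partials: [f64; 2], value: f64) -> Var {
        let index = self.nodes.len();
        self.nodes.push(Node { parents, partials });
        Var { index, value }
    }

    /// Input variable: it has no real parents, so it points at itself
    /// with zero partials.
    fn var(&mut self, value: f64) -> Var {
        let index = self.nodes.len();
        self.push([index, index], [0.0, 0.0], value)
    }

    /// Reverse sweep: returns d(output)/d(node) for every node on the tape.
    fn grad(&self, output: Var) -> Vec<f64> {
        let mut adjoints = vec![0.0; self.nodes.len()];
        adjoints[output.index] = 1.0;
        for i in (0..=output.index).rev() {
            let node = &self.nodes[i];
            let adjoint = adjoints[i];
            for k in 0..2 {
                adjoints[node.parents[k]] += node.partials[k] * adjoint;
            }
        }
        adjoints
    }
}

fn add(tape: &mut Tape, a: Var, b: Var) -> Var {
    tape.push([a.index, b.index], [1.0, 1.0], a.value + b.value)
}

fn mul(tape: &mut Tape, a: Var, b: Var) -> Var {
    tape.push([a.index, b.index], [b.value, a.value], a.value * b.value)
}

// Usage: z = x * y + x at x = 2, y = 3.
fn example() -> (f64, f64, f64) {
    let mut tape = Tape::new();
    let x = tape.var(2.0);
    let y = tape.var(3.0);
    let xy = mul(&mut tape, x, y);
    let z = add(&mut tape, xy, x);
    let gradients = tape.grad(z);
    (z.value, gradients[x.index], gradients[y.index])
}
```

Because each handle carries its tape index, no name is needed to look a variable up, and the gradient with respect to any input can be read from the returned vector by index.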
So far we have only worked on forward mode autodiff. However, to implement backprop and build some simple machine learning applications, we'll need to implement reverse mode autodiff.
For now I'd like to create a new module following this example.
The implementation requires a Variable and a TapeEntry class. Additionally, some functions need to be implemented:
reset_tape: reset the gradient tape
grad: calculate gradients
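One possible shape for these pieces, as a rough sketch only: the field layout of Variable and TapeEntry, the thread-local global tape, the record helper, and the signatures of reset_tape and grad are all assumptions made for illustration, not the interface from the referenced example.

```rust
use std::cell::RefCell;

// Assumed shape of a tape entry: indices of the two inputs and the local
// partial derivatives of the output with respect to each input.
struct TapeEntry {
    inputs: [usize; 2],
    partials: [f64; 2],
}

thread_local! {
    // One global tape per thread (an assumption of this sketch).
    static TAPE: RefCell<Vec<TapeEntry>> = RefCell::new(Vec::new());
}

#[derive(Clone, Copy, Debug)]
struct Variable {
    index: usize,
    value: f64,
}

// Hypothetical helper: appends an entry and returns a handle to it.
// Leaf variables pass `None` and are recorded as pointing at themselves.
fn record(inputs: Option<[usize; 2]>, partials: [f64; 2], value: f64) -> Variable {
    TAPE.with(|tape| {
        let mut tape = tape.borrow_mut();
        let index = tape.len();
        tape.push(TapeEntry {
            inputs: inputs.unwrap_or([index, index]),
            partials,
        });
        Variable { index, value }
    })
}

impl Variable {
    fn new(value: f64) -> Self {
        record(None, [0.0, 0.0], value)
    }

    fn add(self, other: Variable) -> Variable {
        record(Some([self.index, other.index]), [1.0, 1.0], self.value + other.value)
    }

    fn mul(self, other: Variable) -> Variable {
        record(
            Some([self.index, other.index]),
            [other.value, self.value],
            self.value * other.value,
        )
    }
}

/// Clears the tape; handles created before the reset become invalid.
fn reset_tape() {
    TAPE.with(|tape| tape.borrow_mut().clear());
}

/// Walks the tape backwards and returns d(output)/d(entry) for every entry.
fn grad(output: Variable) -> Vec<f64> {
    TAPE.with(|tape| {
        let tape = tape.borrow();
        let mut adjoints = vec![0.0; tape.len()];
        adjoints[output.index] = 1.0;
        for i in (0..=output.index).rev() {
            let adjoint = adjoints[i];
            for k in 0..2 {
                adjoints[tape[i].inputs[k]] += tape[i].partials[k] * adjoint;
            }
        }
        adjoints
    })
}

// Usage: f = (x + y) * y at x = 1, y = 2.
fn example() -> (f64, f64, f64) {
    reset_tape();
    let x = Variable::new(1.0);
    let y = Variable::new(2.0);
    let f = x.add(y).mul(y);
    let gradients = grad(f);
    (f.value, gradients[x.index], gradients[y.index])
}
```

A global tape keeps call sites short, at the price of hidden state: forgetting reset_tape between computations lets the tape grow without bound.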