tracel-ai / burn

Burn is a new comprehensive dynamic Deep Learning Framework built using Rust with extreme flexibility, compute efficiency and portability as its primary goals.
https://burn.dev

Feature Request: Partial Derivative #1942

Open loganbnielsen opened 3 months ago

loganbnielsen commented 3 months ago

Feature description

I would like to be able to take partial derivatives of a neural network.

In PyTorch like this: https://stackoverflow.com/a/66709533/10969548

In TensorFlow like this: https://stackoverflow.com/a/65968334/10969548

Feature motivation

This feature is useful whenever your model explicitly uses partials in its objective function (e.g., differential equation solvers).

antimora commented 3 months ago

Linking an existing ticket which was closed because additional information was missing:

https://github.com/tracel-ai/burn/issues/121

loganbnielsen commented 3 months ago

Is the missing information about what a mixed partial derivative is? Maybe we can work with a pretty simple example:

f(x,y) = x^2 + 3y + xy

Then the cross partial would be 1 (you take the partial with respect to x or y, and then the partial w.r.t. the other variable).
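Spelling that out (just the intermediate steps):

∂f/∂x = 2x + y
∂²f/∂y∂x = ∂(2x + y)/∂y = 1

and symmetrically, ∂f/∂y = 3 + x, so ∂²f/∂x∂y = 1 as well.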

Using @nathanielsimard's code from #121, the cross partial would be computed like this:

fn run<B: Backend>() {
    let a = ADTensor::<B, 2>::random(...);
    let b = ADTensor::<B, 2>::random(...);
    let y = some_function(&a, &b);

    let grads = y.backward();

    let grad_a = grads.wrt(&a); // d some_function / da
    let grad_b = grads.wrt(&b); // d some_function / db

    // extension of the provided code
    let grad_ab = grad_a.wrt(&b);
    let grad_ba = grad_b.wrt(&a);

    // grad_ab == grad_ba -- Young's theorem: https://en.wikipedia.org/wiki/Symmetry_of_second_derivatives
}

(I don't know if the new lines I added are legal code; I haven't done much with Burn yet. I'm presently working through the Burn Book MNIST classification example.)

I'm not sure about the details of how this is implemented efficiently in TensorFlow or PyTorch. Is this something I should research? Or how else can I be helpful?

loganbnielsen commented 2 months ago

@nathanielsimard could you point me to the docs for the wrt method you referenced in #121? For some reason I'm having a hard time finding it. There may have been some API changes since that post, since ADTensor doesn't appear to be a type anymore either.

nathanielsimard commented 2 months ago

@loganbnielsen They are now:

let mut grads = loss.backward(); // Compute the gradients.
let x_d = x.grad(&grads); // Can retrieve multiple times.
let x_d = x.grad_remove(&mut grads); // Can retrieve only one time.
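
To put that in context, here is a minimal end-to-end sketch of the current first-order flow (assuming the ndarray backend and a recent Burn release; the tensor x and the loss are just placeholders):

use burn::backend::{Autodiff, NdArray};
use burn::tensor::Tensor;

fn main() {
    // Autodiff is a backend decorator that records the graph needed by backward().
    type B = Autodiff<NdArray>;
    let device = Default::default();

    // Mark the tensor we want partial derivatives with respect to.
    let x = Tensor::<B, 1>::from_floats([2.0, 3.0], &device).require_grad();

    // loss = sum(x * x), so d loss / dx = 2x.
    let loss = (x.clone() * x.clone()).sum();

    let grads = loss.backward(); // Compute the gradients (reverse pass).
    let grad_x = x.grad(&grads); // d loss / dx = [4.0, 6.0]

    // The gradient comes back on the inner (non-autodiff) backend, so there is
    // no graph to call backward() on a second time -- which is what this issue asks for.
    println!("{:?}", grad_x.unwrap().into_data());
}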

loganbnielsen commented 1 month ago

@nathanielsimard I've spent some time reading about reverse accumulation, and I think a good starting point for me would be implementing second-order derivatives. It seems like we might either get cross partial derivatives for free or be close to it.

Has the addition of higher-order gradients (e.g., Hessian) been discussed anywhere, or how should I go about initiating a discussion on the best way to implement this? Would it be helpful for me to create a minimal example from scratch to demonstrate how it can be computed?

From a high-level perspective, I think the implementation would involve constructing a graph that keeps track of the computations done during the backward pass so that backward() can then be called on that graph.

One concern is that this approach might take up a lot of space, as it would compute the full Hessian, while users might only need specific elements from the higher-order derivatives. In PyTorch and TensorFlow, you can specify which variable(s) you want to differentiate, whereas calling backward() traverses all paths, as I understand it so far.

loganbnielsen commented 1 month ago

I wrote a simple autodiff in case it's helpful. It can do higher-order derivatives and cross derivatives for + and *. Here's a link.

I think the analogue in the Burn world would be keeping the history of the steps for a given gradient in the backward pass.
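
To make the idea concrete, here's a toy sketch in plain Rust (not the linked code and not Burn internals, just an illustration): if differentiating produces another expression graph, you can differentiate that graph again to get higher-order and cross partials.

use std::rc::Rc;

// Toy expression graph supporting only + and *.
enum Expr {
    Const(f64),
    Var(&'static str),
    Add(Rc<Expr>, Rc<Expr>),
    Mul(Rc<Expr>, Rc<Expr>),
}

use Expr::*;

// d/dv: differentiation produces *another* Expr, so it can be applied repeatedly.
fn diff(e: &Expr, v: &str) -> Expr {
    match e {
        Const(_) => Const(0.0),
        Var(name) => Const(if *name == v { 1.0 } else { 0.0 }),
        Add(a, b) => Add(Rc::new(diff(a, v)), Rc::new(diff(b, v))),
        // Product rule: (ab)' = a'b + ab'.
        Mul(a, b) => Add(
            Rc::new(Mul(Rc::new(diff(a, v)), b.clone())),
            Rc::new(Mul(a.clone(), Rc::new(diff(b, v)))),
        ),
    }
}

fn eval(e: &Expr, x: f64, y: f64) -> f64 {
    match e {
        Const(c) => *c,
        Var(name) => if *name == "x" { x } else { y },
        Add(a, b) => eval(a, x, y) + eval(b, x, y),
        Mul(a, b) => eval(a, x, y) * eval(b, x, y),
    }
}

fn main() {
    // f(x, y) = x*x + 3*y + x*y  (the example from above, with x^2 written as x*x).
    let f = Add(
        Rc::new(Add(
            Rc::new(Mul(Rc::new(Var("x")), Rc::new(Var("x")))),
            Rc::new(Mul(Rc::new(Const(3.0)), Rc::new(Var("y")))),
        )),
        Rc::new(Mul(Rc::new(Var("x")), Rc::new(Var("y")))),
    );

    let df_dx = diff(&f, "x");        // 2x + y, itself an Expr
    let d2f_dydx = diff(&df_dx, "y"); // cross partial, evaluates to 1 everywhere

    println!("d2f/dydx at (2, 3) = {}", eval(&d2f_dydx, 2.0, 3.0)); // prints 1
}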

I think forward-mode autodiff is used in PyTorch and TensorFlow for higher-order derivatives, since there are usually only a couple of gradients of interest. Maybe I'll start looking into that next.

nathanielsimard commented 1 month ago

@loganbnielsen

Since we currently perform autodiff using a backend decorator, there is nothing that prevents us from using another level of backend decorator to perform second-order derivatives. Some operations don't support second-order derivatives (backward operations), but most of them are actually implemented using other operations. Each level can have its own graph, and since we support gradient checkpointing, we can reduce memory usage quite a lot.
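
As a type-level sketch of what that nesting could look like (just an illustration of the idea, assuming Autodiff composes like any other backend decorator; this is not a claim that second-order gradients already work end to end):

use burn::backend::{Autodiff, NdArray};

// One Autodiff decorator per order of differentiation: the outer level would record
// the operations the inner level performs during its backward pass, so that
// backward() could be called on that recording as well. Each level keeps its own graph.
type FirstOrder = Autodiff<NdArray>;
type SecondOrder = Autodiff<Autodiff<NdArray>>;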