tracel-ai / burn

Burn is a new comprehensive dynamic Deep Learning Framework built using Rust with extreme flexibility, compute efficiency and portability as its primary goals.
https://burn.dev
Apache License 2.0

Feature Request: Partial Derivative #1942

Open loganbnielsen opened 5 months ago

loganbnielsen commented 5 months ago

Feature description

I would like to be able to take partial derivatives of a neural network.

In PyTorch like this: https://stackoverflow.com/a/66709533/10969548

In TensorFlow like this: https://stackoverflow.com/a/65968334/10969548

Feature motivation

This feature is useful whenever your model explicitly uses partials in its objective function (e.g., differential equation solvers).

antimora commented 5 months ago

Linking an existing ticket which was closed because additional information was missing:

https://github.com/tracel-ai/burn/issues/121

loganbnielsen commented 5 months ago

Is the missing information about what a mixed partial derivative is? Maybe we can work with a pretty simple example:

f(x,y) = x^2 + 3y + xy

Then the cross partial would be 1. (You take the partial with respect to x or y, and then the partial w.r.t. the other.)
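Concretely: df/dx = 2x + y, and d/dy of that is 1; likewise df/dy = 3 + x, and d/dx of that is also 1.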

Using @nathanielsimard's code from #121, the cross partial would be the same as:

fn run<B: Backend>() {
    let a = ADTensor::<B, 2>::random(...);
    let b = ADTensor::<B, 2>::random(...);
    let y = some_function(&a, &b);

    let grads = y.backward();

    let grad_a = grads.wrt(&a); // d some_function / da
    let grad_b = grads.wrt(&b); // d some_function / db

    // extension of provided code
    let grad_ab = grad_a.wrt(&b);
    let grad_ba = grad_b.wrt(&a);

    // grad_ab == grad_ba -- Young's Theorem: https://en.wikipedia.org/wiki/Symmetry_of_second_derivatives
}

(I don't know if the new lines I added are legal code; I haven't done much with Burn yet. Presently working through the Burn Book MNIST classification example.)

I'm not sure about the details of how this is implemented efficiently in TensorFlow or PyTorch. Is this something I should research? Or how else can I be helpful?

loganbnielsen commented 4 months ago

@nathanielsimard could you point me to the docs for the wrt method you referenced in #121? For some reason I'm having a hard time finding it. There may have been some API changes since that post, since ADTensor doesn't appear to be a type anymore either.

nathanielsimard commented 4 months ago

@loganbnielsen They are now:

let mut grads = loss.backward(); // Compute the gradients.
let x_d = x.grad(&grads); // Can retrieve multiple times.
let x_d = x.grad_remove(&mut grads); // Can retrieve only one time.
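
For context, a minimal end-to-end sketch of how these fit together (the element-wise * and sum calls and the Wgpu backend here are assumptions for illustration, not taken from any particular example):

use burn::backend::wgpu::WgpuDevice;
use burn::backend::{Autodiff, Wgpu};
use burn::tensor::Tensor;

fn first_order() {
    type B = Autodiff<Wgpu>;
    let device = WgpuDevice::default();

    let x: Tensor<B, 1> = Tensor::from_data([2.0, 3.0], &device).require_grad();
    let loss = (x.clone() * x.clone()).sum(); // loss = x_0^2 + x_1^2

    let mut grads = loss.backward();
    if let Some(dx) = x.grad(&grads) {
        println!("{dx}"); // expect [4.0, 6.0], i.e. 2x
    }
    let _dx = x.grad_remove(&mut grads); // takes ownership of the stored gradient
}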

loganbnielsen commented 3 months ago

@nathanielsimard I've spent some time reading about reverse accumulation, and I think a good starting point for me would be implementing second-order derivatives. It seems like we might either get cross partial derivatives for free or be close to it.

Has the addition of higher-order gradients (e.g., Hessian) been discussed anywhere, or how should I go about initiating a discussion on the best way to implement this? Would it be helpful for me to create a minimal example from scratch to demonstrate how it can be computed?

From a high-level perspective, I think the implementation would involve constructing a graph that keeps track of the computations done during the backward pass so that backward() can then be called on that graph.

One concern is that this approach might take up a lot of space, as it would compute the full Hessian, while users might only need specific elements from the higher-order derivatives. In PyTorch and TensorFlow, you can specify which variable(s) you want to differentiate, whereas calling backward() traverses all paths, as I understand it so far.

loganbnielsen commented 3 months ago

Wrote a simple autodiff in case it's helpful. It can do higher order derivatives and cross derivatives for + and *. Here's a link.
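
The gist of it, as a rough self-contained sketch of the same idea (not the actual code at the link): derivatives come back as expressions, so they can be differentiated again for second-order and cross partials.

use std::rc::Rc;

enum Expr {
    Const(f64),
    Var(usize), // index into the input slice
    Add(Rc<Expr>, Rc<Expr>),
    Mul(Rc<Expr>, Rc<Expr>),
}

impl Expr {
    fn eval(&self, inputs: &[f64]) -> f64 {
        match self {
            Expr::Const(c) => *c,
            Expr::Var(i) => inputs[*i],
            Expr::Add(a, b) => a.eval(inputs) + b.eval(inputs),
            Expr::Mul(a, b) => a.eval(inputs) * b.eval(inputs),
        }
    }

    // d(self)/d(var), returned as a new expression that can be diffed again.
    fn diff(&self, var: usize) -> Expr {
        match self {
            Expr::Const(_) => Expr::Const(0.0),
            Expr::Var(i) => Expr::Const(if *i == var { 1.0 } else { 0.0 }),
            Expr::Add(a, b) => Expr::Add(Rc::new(a.diff(var)), Rc::new(b.diff(var))),
            // product rule: (ab)' = a'b + ab'
            Expr::Mul(a, b) => Expr::Add(
                Rc::new(Expr::Mul(Rc::new(a.diff(var)), Rc::clone(b))),
                Rc::new(Expr::Mul(Rc::clone(a), Rc::new(b.diff(var)))),
            ),
        }
    }
}

fn main() {
    // f(x, y) = x*x + 3*y + x*y, with x = Var(0) and y = Var(1)
    let x = Rc::new(Expr::Var(0));
    let y = Rc::new(Expr::Var(1));
    let f = Expr::Add(
        Rc::new(Expr::Add(
            Rc::new(Expr::Mul(Rc::clone(&x), Rc::clone(&x))),
            Rc::new(Expr::Mul(Rc::new(Expr::Const(3.0)), Rc::clone(&y))),
        )),
        Rc::new(Expr::Mul(Rc::clone(&x), Rc::clone(&y))),
    );

    let df_dx = f.diff(0); // 2x + y, as an expression
    let d2f_dxdy = df_dx.diff(1); // cross partial
    println!("d2f/dxdy at (2, 5) = {}", d2f_dxdy.eval(&[2.0, 5.0])); // prints 1
}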

I think the analogue in the Burn world would be keeping the history of the steps for a given gradient in the backward pass.

I think forward-mode autodiff is used in PyTorch and TensorFlow for higher-order derivatives, since there are only a couple of grads of interest. Maybe I'll start looking into that next.

nathanielsimard commented 3 months ago

@loganbnielsen

Since we currently perform autodiff using a backend decorator, there is nothing that prevents us from using another level of backend decorator to perform second-order derivatives. Some operations don't support second-order derivatives (backward operations), but most of them are actually implemented using other operations. Each level can have its own graph, and since we support gradient checkpointing, we can reduce memory usage quite a lot.
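
Roughly the shape of that layering, as a toy sketch (not the actual Burn traits; just illustrating that each level keeps its own graph):

use std::cell::RefCell;

trait Backend {
    type Tensor: Clone;
    fn mul(&self, a: &Self::Tensor, b: &Self::Tensor) -> Self::Tensor;
}

struct Inner;
impl Backend for Inner {
    type Tensor = f64;
    fn mul(&self, a: &f64, b: &f64) -> f64 {
        a * b
    }
}

struct Autodiff<B: Backend> {
    inner: B,
    tape: RefCell<Vec<&'static str>>, // this level's own graph of recorded ops
}

impl<B: Backend> Backend for Autodiff<B> {
    type Tensor = B::Tensor;
    fn mul(&self, a: &Self::Tensor, b: &Self::Tensor) -> Self::Tensor {
        // Record the op at this level, then delegate to the wrapped backend.
        // Because the backward of most ops is itself written in terms of
        // backend ops, an outer Autodiff would record the inner level's
        // backward pass on its own tape, which is what second order needs.
        self.tape.borrow_mut().push("mul");
        self.inner.mul(a, b)
    }
}

So Autodiff<Autodiff<SomeBackend>> would give one graph per derivative order.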

loganbnielsen commented 1 month ago

Could you provide a couple code pointers and a dummy example of the API we're hoping to have?

nathanielsimard commented 1 month ago

type SecondOrderBackend = Autodiff<Autodiff<Wgpu>>;

let loss = ...;
let gradients = loss.backward();

Does that make sense?

loganbnielsen commented 3 weeks ago

I think the Autodiff decorator needs to be aware of its level so it knows which graph to register its operations on. One way to do this would be:

impl<B: Backend> Backend for Autodiff<B> {
    const GRAD_LEVEL: usize = B::GRAD_LEVEL + 1;
    // ...
}

But it'd require GRAD_LEVEL to be a const defined on all backends.

How do you think the decorator should know which graph it belongs to?

loganbnielsen commented 2 weeks ago

And also, what API do we want for retrieving the second-order gradients? Right now it's:

let grads = loss.backward();
y.grad(&grads);

Which presently grabs first-order derivatives. Do we want to try something like

~~y.grad(&y.grad(&grads))~~

For second order derivatives?

EDIT

Actually I think the syntax will end up being:

let grad_y = y.grad(&grads);
let grads_grad_y = grad_y.backward();
let grad_yy = grad_y.get(&grads_grad_y);

loganbnielsen commented 2 weeks ago

One more note: the second-order derivative graph should only record operations done by the backend during the backward pass. This is different from the present behavior, which records only the forward pass. (For example, for y = x * x, the backward pass computes grad_x = 2 * x * grad_y; those are the ops the second-level graph needs, so that differentiating grad_x again w.r.t. x yields 2.)

loganbnielsen commented 2 weeks ago

I think the backend would need to be part of executing the steps here, since this is where the gradients are computed, so that an autodiff backend could track those executions:

    fn execute_steps(
        tape: Vec<Vec<StepBoxed>>,
        mut grads: Gradients,
        mut checkpointer: Checkpointer,
    ) -> Gradients {
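        // Walk the tape from the most recent steps back to the roots,
        // running each recorded backward step to accumulate gradients into `grads`.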
        tape.into_iter().rev().for_each(|steps| {
            steps
                .into_iter()
                .for_each(|step| step.step(&mut grads, &mut checkpointer))
        });

        #[cfg(feature = "export_tests")]
        // For checkpointing tests
        assert!(checkpointer.is_empty());
        grads
    }

nathanielsimard commented 1 week ago

Actually I think the syntax will end up being:

let grad_y = y.grad(&grads);
let grads_grad_y = grad_y.backward();
let grad_yy = grad_y.get(&grads_grad_y);

Not sure about the syntax, but this is how I see it being implemented underneath.

loganbnielsen commented 1 week ago

What's an untracked ops step and why do we have them?

    #[test]
    fn trivial() {
        type Backend = Autodiff<Wgpu>;
        let device = WgpuDevice::default();
        let x: Tensor<Backend, 1> = Tensor::from_data([5.0], &device).require_grad();
        let grads = x.backward();
        let maybe_x_grad = x.grad(&grads);
        if let Some(x_grad) = maybe_x_grad {
            println!("{}", x_grad);
        } else {
            print!("No grad.");
        }
    }

---- test::trivial stdout ----
from_data B::GRAD_LEVEL 1
require grad: Tensor {
  data:
[5.0],
  shape:  [1],
  device:  DefaultDevice,
  backend:  "autodiff<fusion<jit<wgpu<wgsl>>>>",
  kind:  "Float",
  dtype:  "f32",
}
REGISTER on GRAD_LEVEL 0
BACKWARD on GRAD_LEVEL 0
EXISTING STEPS: [{
    NodeID {
        value: 0,
    }: RootStep {
        node: Node {
            parents: [],
            order: 0,
            id: NodeID {
                value: 0,
            },
            requirement: Grad,
            properties: Ambiguous,
            client: MutexClient,
        },
    },
    NodeID {
        value: 1,
    }: UntrackedOpsStep {
        ops: Ops {
            parents: [
                None,
            ],
            node: Node {
                parents: [],
                order: 1,
                id: NodeID {
                    value: 1,
                },
                requirement: None,
                properties: ComputeBound,
                client: MutexClient,
            },
            state: (),
        },
    },
}]
tape is: [[]]
Tensor {
  data:
[1.0],
  shape:  [1],
  device:  DefaultDevice,
  backend:  "fusion<jit<wgpu<wgsl>>>",
  kind:  "Float",
  dtype:  "f32",
}

successes:
    test::trivial

loganbnielsen commented 1 week ago

I'm also a bit confused about how GRAD_LEVEL is 0 when I thought GRAD_LEVEL was > 0 for a backend capable of autodiff.

diff --git a/crates/burn-autodiff/src/backend.rs b/crates/burn-autodiff/src/backend.rs
index b1d597ad..d9b126dd 100644
--- a/crates/burn-autodiff/src/backend.rs
+++ b/crates/burn-autodiff/src/backend.rs
@@ -22,6 +22,8 @@ pub struct Autodiff<B, C = NoCheckpointing> {
 }

 impl<B: Backend, C: CheckpointStrategy> Backend for Autodiff<B, C> {
+    const GRAD_LEVEL: u8 = B::GRAD_LEVEL + 1;
+
     type Device = B::Device;

     type FullPrecisionBridge = AutodiffBridge<B::FullPrecisionBridge>;
diff --git a/crates/burn-autodiff/src/runtime/client.rs b/crates/burn-autodiff/src/runtime/client.rs
index 13c07c73..f0b69ec7 100644
--- a/crates/burn-autodiff/src/runtime/client.rs
+++ b/crates/burn-autodiff/src/runtime/client.rs
@@ -9,7 +9,7 @@ use burn_tensor::backend::Backend;
 /// Client used to communicate with the autodiff server.
 pub trait AutodiffClient: Send + Clone {
     /// Register a new step.
-    fn register(&self, node_id: NodeRefCount, step: StepBoxed, actions: CheckpointerBuilder);
+    fn register(&self, node_id: NodeRefCount, step: StepBoxed, actions: CheckpointerBuilder, grad_level: u8);
     /// Call backpropagation from the given tensor.
     fn backward<B: Backend>(&self, tensor: AutodiffTensor<B>) -> Gradients;
 }
diff --git a/crates/burn-autodiff/src/runtime/mspc.rs b/crates/burn-autodiff/src/runtime/mspc.rs
index c128f34b..9c2309e2 100644
--- a/crates/burn-autodiff/src/runtime/mspc.rs
+++ b/crates/burn-autodiff/src/runtime/mspc.rs
@@ -63,7 +63,7 @@ impl ChannelClient {
 }

 impl AutodiffClient for ChannelClient {
-    fn register(&self, node_id: NodeRefCount, step: StepBoxed, actions: CheckpointerBuilder) {
+    fn register(&self, node_id: NodeRefCount, step: StepBoxed, actions: CheckpointerBuilder, grad_level: u8) {
         self.sender
             .send(Message::Register {
                 node_id,
diff --git a/crates/burn-autodiff/src/runtime/mutex.rs b/crates/burn-autodiff/src/runtime/mutex.rs
index 1968fcc3..d72f7070 100644
--- a/crates/burn-autodiff/src/runtime/mutex.rs
+++ b/crates/burn-autodiff/src/runtime/mutex.rs
@@ -17,10 +17,24 @@ impl core::fmt::Debug for MutexClient {
 }

 static SERVER: spin::Mutex<Option<AutodiffServer>> = spin::Mutex::new(None);
+static SERVER2: spin::Mutex<Option<AutodiffServer>> = spin::Mutex::new(None);

 impl AutodiffClient for MutexClient {
-    fn register(&self, node_id: NodeRefCount, step: StepBoxed, actions: CheckpointerBuilder) {
-        let mut server = SERVER.lock();
+    fn register(
+        &self,
+        node_id: NodeRefCount,
+        step: StepBoxed,
+        actions: CheckpointerBuilder,
+        grad_level: u8,
+    ) {
+        println!("REGISTER on GRAD_LEVEL {}", grad_level);
+        let mut server = if grad_level == 0 {
+            SERVER.lock()
+        } else if grad_level == 1 {
+            SERVER2.lock()
+        } else {
+            panic!("UNEXPECTED GRAD LEVEL")
+        };

         if let Some(server) = server.as_mut() {
             server.register(node_id, step, actions);
@@ -31,8 +45,16 @@ impl AutodiffClient for MutexClient {
         server_new.register(node_id, step, actions);
         *server = Some(server_new);
     }
+
     fn backward<B: Backend>(&self, root: AutodiffTensor<B>) -> Gradients {
-        let mut server = SERVER.lock();
+        println!("BACKWARD on GRAD_LEVEL {}", B::GRAD_LEVEL);
+        let mut server = if B::GRAD_LEVEL == 0 {
+            SERVER.lock()
+        } else if B::GRAD_LEVEL == 1 {
+            SERVER2.lock()
+        } else {
+            panic!("UNEXPECTED GRAD LEVEL {}", B::GRAD_LEVEL)
+        };
         let node_id = root.node.id;
         let grads = Gradients::new::<B>(root.node, root.primitive);

diff --git a/crates/burn-autodiff/src/runtime/server.rs b/crates/burn-autodiff/src/runtime/server.rs
index 11894aa7..0aa9a11b 100644
--- a/crates/burn-autodiff/src/runtime/server.rs
+++ b/crates/burn-autodiff/src/runtime/server.rs
@@ -16,7 +16,12 @@ pub struct AutodiffServer {
 }

 impl AutodiffServer {
-    pub fn register(&mut self, rc: NodeRefCount, step: StepBoxed, actions: CheckpointerBuilder) {
+    pub fn register(
+        &mut self,
+        rc: NodeRefCount,
+        step: StepBoxed,
+        actions: CheckpointerBuilder,
+    ) {
         let parents = step.parents();
         let node_id = *rc.as_ref();

@@ -27,6 +32,7 @@ impl AutodiffServer {
     }

     pub fn backward(&mut self, grads: Gradients, node_id: NodeID) -> Gradients {
+        println!("EXISTING STEPS: [{:#?}]", self.steps);
         let step = self.steps.remove(&node_id).expect(
             "Node should have a step registered, did you forget to call \
              `Tensor::register_grad` on the tensor where you need gradients?",
@@ -36,6 +42,7 @@ impl AutodiffServer {
         let (tape, builder) = self.build_tape(node_id, step, builder);
         let checkpointer = builder.build(&self.steps);

+        println!("tape is: [{:#?}]", tape);
         let gradients = Self::execute_steps(tape, grads, checkpointer);

         // Cleanup
diff --git a/crates/burn-autodiff/src/tensor.rs b/crates/burn-autodiff/src/tensor.rs
index b51d75d0..9fb3cc2a 100644
--- a/crates/burn-autodiff/src/tensor.rs
+++ b/crates/burn-autodiff/src/tensor.rs
@@ -147,6 +147,7 @@ impl<B: Backend> AutodiffTensor<B> {
             self.rc.clone(),
             Box::new(step_that_created_the_tensor),
             actions,
+            B::GRAD_LEVEL
         );
         self
     }
diff --git a/crates/burn-tensor/src/tensor/api/base.rs b/crates/burn-tensor/src/tensor/api/base.rs
index 2eecb5c5..1cbf9611 100644
--- a/crates/burn-tensor/src/tensor/api/base.rs
+++ b/crates/burn-tensor/src/tensor/api/base.rs
@@ -927,6 +927,7 @@ where
     where
         T: Into<TensorData>,
     {
+        println!("from_data B::GRAD_LEVEL {}", B::GRAD_LEVEL);
         let data = data.into();
         check!(TensorCheck::creation_ops::<D>(
             "From Data",
diff --git a/crates/burn-tensor/src/tensor/backend/base.rs b/crates/burn-tensor/src/tensor/backend/base.rs
index a4808336..c443b5e7 100644
--- a/crates/burn-tensor/src/tensor/backend/base.rs
+++ b/crates/burn-tensor/src/tensor/backend/base.rs
@@ -66,6 +66,9 @@ pub trait Backend:
     + core::fmt::Debug
     + 'static
 {
+    /// Level of gradient being tracked.
+    const GRAD_LEVEL: u8 = 0;
+
     /// Device type.
     type Device: DeviceOps;

nathanielsimard commented 1 week ago

What's an untracked ops step and why do we have them?

You are using from_data; there is no backward function for that!

Not sure about the other question 😅