tracel-ai / burn

Burn is a new comprehensive dynamic Deep Learning Framework built using Rust with extreme flexibility, compute efficiency and portability as its primary goals.
https://burn.dev
Apache License 2.0

Retrieving gradients in different threads panics. #1142

Closed · giucesar closed 7 months ago

giucesar commented 7 months ago

Hi.

If a tensor is created in the main thread, retrieving its gradients panics when the operations happen in a different thread.

Example:

use std::{thread::sleep, time::Duration};

use burn::{backend::Autodiff, tensor::Tensor};

type B = Autodiff<burn::backend::NdArray>;

fn obj(x: Tensor<B, 1>) -> Tensor<B, 1> {
    (x.clone() * x.clone()).sum().sqrt()
}

fn main() {
    let x = Tensor::<B, 1>::from_floats([1.0, 2.0, 3.0]).set_require_grad(true);
    let op = move || {
        // Compute the objective, backpropagate, and retrieve the gradient of x.
        let value = obj(x.clone()).sum();
        let grads = value.backward();
        // This is the line that fails in the threaded runs (see the error below).
        let g_x = Tensor::<B, 1>::from_inner(x.clone().grad(&grads).unwrap());

        let x_1 = x.clone() - g_x;
        let value = obj(x_1.clone()).sum();
        println!("value: {}", value.into_scalar());
    };

    op.clone()();

    let thread_1 = std::thread::spawn(op.clone());
    // sleep(Duration::from_millis(100));
    let thread_2 = std::thread::spawn(op.clone());
    thread_1.join().unwrap();
    thread_2.join().unwrap();
}

If you run this small program several times, it fails most of the time on the line that retrieves the gradients of x. The error message is:

Root node should have a step registered, did you forget to call `Tensor::register_grad` on the tensor where you need gradients?

Interestingly, if you comment out thread_2, it works as expected. Also, inserting

sleep(Duration::from_millis(100));

between the two thread spawns avoids the panic.

Here is the full backtrace:

value: 2.7416575
thread '<unnamed>' panicked at src/main.rs:16:69:
called `Option::unwrap()` on a `None` value
value: 2.7416575
stack backtrace:
   0: rust_begin_unwind
             at /rustc/82e1608dfa6e0b5569232559e3d385fea5a93112/library/std/src/panicking.rs:645:5
   1: core::panicking::panic_fmt
             at /rustc/82e1608dfa6e0b5569232559e3d385fea5a93112/library/core/src/panicking.rs:72:14
   2: core::panicking::panic
             at /rustc/82e1608dfa6e0b5569232559e3d385fea5a93112/library/core/src/panicking.rs:127:5
   3: core::option::Option<T>::unwrap
             at /rustc/82e1608dfa6e0b5569232559e3d385fea5a93112/library/core/src/option.rs:931:21
   4: test_burn::main::{{closure}}
             at ./src/main.rs:16:46
   5: core::ops::function::FnOnce::call_once
             at /rustc/82e1608dfa6e0b5569232559e3d385fea5a93112/library/core/src/ops/function.rs:250:5
note: Some details are omitted, run with `RUST_BACKTRACE=full` for a verbose backtrace.
thread 'main' panicked at src/main.rs:28:21:
called `Result::unwrap()` on an `Err` value: Any { .. }
stack backtrace:
   0: rust_begin_unwind
             at /rustc/82e1608dfa6e0b5569232559e3d385fea5a93112/library/std/src/panicking.rs:645:5
   1: core::panicking::panic_fmt
             at /rustc/82e1608dfa6e0b5569232559e3d385fea5a93112/library/core/src/panicking.rs:72:14
   2: core::result::unwrap_failed
             at /rustc/82e1608dfa6e0b5569232559e3d385fea5a93112/library/core/src/result.rs:1653:5
   3: core::result::Result<T,E>::unwrap
             at /rustc/82e1608dfa6e0b5569232559e3d385fea5a93112/library/core/src/result.rs:1077:23
   4: test_burn::main
             at ./src/main.rs:28:5
   5: core::ops::function::FnOnce::call_once
             at /rustc/82e1608dfa6e0b5569232559e3d385fea5a93112/library/core/src/ops/function.rs:250:5
note: Some details are omitted, run with `RUST_BACKTRACE=full` for a verbose backtrace.
giucesar commented 7 months ago

To add: searching for "register_grad" in the source code only returns the string that defines this error message.

nathanielsimard commented 7 months ago

Reviewing the provided code, the failure is expected with the current implementation. The root graph node is created only once, before op is generated, so at that point there is a single graph. When you execute op twice, the graph is not split into two distinct entities; both operations share the same graph and modify it. It's important to note that you can only invoke the backward operation once on a graph, since the graph is consumed in the process. While the same graph can safely be modified by multiple threads, it should not be consumed by both threads simultaneously.

The solution to this problem is to ensure that each thread has its own graph. This can easily be done with the detach method.

use std::{thread::sleep, time::Duration};

use burn::{backend::Autodiff, tensor::Tensor};

type B = Autodiff<burn::backend::NdArray>;

fn obj(x: Tensor<B, 1>) -> Tensor<B, 1> {
    (x.clone() * x.clone()).sum().sqrt()
}

fn main() {
    let x = Tensor::<B, 1>::from_floats([1.0, 2.0, 3.0]).set_require_grad(true);
    let op = move || {
        // detach will create a new graph here.
        let x = x.detach();
        // Same as before.
        let value = obj(x.clone()).sum();
        let grads = value.backward();
        let g_x = Tensor::<B, 1>::from_inner(x.clone().grad(&grads).unwrap());

        let x_1 = x.clone() - g_x;
        let value = obj(x_1.clone()).sum();
        println!("value: {}", value.into_scalar());
    };

    op.clone()();

    let thread_1 = std::thread::spawn(op.clone());
    // sleep(Duration::from_millis(100));
    let thread_2 = std::thread::spawn(op.clone());
    thread_1.join().unwrap();
    thread_2.join().unwrap();
}

Note that your code may sometimes work as intended because graph creation is actually lazy: when a graph doesn't exist, we create a new one. This is why you can call your code multiple times sequentially and it will work.
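
For illustration, here is a minimal sketch of that sequential case. It is not code from this thread; it only reuses the B alias, obj, and the same API calls shown in the examples above. The first backward consumes the graph, and the second forward pass lazily recreates one, so both passes succeed.

use burn::{backend::Autodiff, tensor::Tensor};

type B = Autodiff<burn::backend::NdArray>;

fn obj(x: Tensor<B, 1>) -> Tensor<B, 1> {
    (x.clone() * x.clone()).sum().sqrt()
}

fn main() {
    let x = Tensor::<B, 1>::from_floats([1.0, 2.0, 3.0]).set_require_grad(true);

    // First pass: build a graph, consume it with backward, read the gradient of x.
    let value = obj(x.clone()).sum();
    let grads = value.backward();
    let _g_x = Tensor::<B, 1>::from_inner(x.clone().grad(&grads).unwrap());

    // Second pass: the previous graph was consumed, but this forward pass
    // lazily creates a fresh one, so backward and grad succeed again.
    let value = obj(x.clone()).sum();
    let grads = value.backward();
    let _g_x = Tensor::<B, 1>::from_inner(x.clone().grad(&grads).unwrap());
}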

giucesar commented 7 months ago

It makes sense, thanks. That is pretty hard to identify without knowing the internals of the library. Honestly, I had an incorrect understanding of the detach operation.

nathanielsimard commented 7 months ago

To be fair, I believe we should improve the error message and add documentation on how to use multithreading with autodiff. We could then link to that documentation from the error message to make it more useful.

giucesar commented 7 months ago

A different scenario for a similar problem: now the parameter is passed to op as a mutable reference. For some reason, calling detach alone doesn't work anymore; require_grad also needs to be called on it. Meanwhile, as expected, the code works without any problem when run without parallelism. I still see inconsistencies in the gradients for threaded training, but I am struggling to build a small example that reproduces them.

use burn::{backend::Autodiff, tensor::Tensor};
use rayon::iter::{IntoParallelRefMutIterator, ParallelIterator};

type B = Autodiff<burn::backend::NdArray>;

fn obj(x: Tensor<B, 1>) -> Tensor<B, 1> {
    (x.clone() * x.clone()).sum().sqrt()
}

fn main() {
    let mut models = (0..100)
        .map(|_| {
            Tensor::<B, 1>::random([3], burn::tensor::Distribution::Normal(0., 1.))
                .set_require_grad(true)
        })
        .collect::<Vec<_>>();

    let op = |model: &mut Tensor<B, 1>| {
        // detach alone fails in the second (parallel) pass; the commented
        // variant with require_grad works.
        let md = model.clone().detach();
        // let md = model.clone().detach().require_grad();
        let value = obj(md.clone()).sum();
        let grads = value.backward();
        let g_x = Tensor::<B, 1>::from_inner(md.clone().grad(&grads).unwrap());

        *model = md.clone() - g_x;
        let value = obj(model.clone()).sum();
        println!("value: {}", value.into_scalar());
    };

    models.iter_mut().for_each(op);
    models.par_iter_mut().for_each(op);
}
nathanielsimard commented 7 months ago

You are updating each tensor in the list; however, the newly updated tensors aren't tagged as requiring gradients. Therefore, in the second iteration, you need to call require_grad again. Keep in mind that the mutation doesn't just update the tensor value but all of the tensor state that comes with it. It's like creating new parameters at each iteration; this is why a new require_grad is necessary.
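
For reference, here is a sketch of the previous example with that fix applied. It is simply the commented-out line above made active, with everything else unchanged, plus comments on why the extra require_grad is needed.

use burn::{backend::Autodiff, tensor::Tensor};
use rayon::iter::{IntoParallelRefMutIterator, ParallelIterator};

type B = Autodiff<burn::backend::NdArray>;

fn obj(x: Tensor<B, 1>) -> Tensor<B, 1> {
    (x.clone() * x.clone()).sum().sqrt()
}

fn main() {
    let mut models = (0..100)
        .map(|_| {
            Tensor::<B, 1>::random([3], burn::tensor::Distribution::Normal(0., 1.))
                .set_require_grad(true)
        })
        .collect::<Vec<_>>();

    let op = |model: &mut Tensor<B, 1>| {
        // Detach to get an independent graph, and re-tag the detached tensor
        // as requiring gradients, since the tag is not carried over once
        // `*model` is overwritten below.
        let md = model.clone().detach().require_grad();
        let value = obj(md.clone()).sum();
        let grads = value.backward();
        let g_x = Tensor::<B, 1>::from_inner(md.clone().grad(&grads).unwrap());

        // This assignment creates a brand-new tensor; without the require_grad
        // call above, the next pass over the models would fail.
        *model = md.clone() - g_x;
        let value = obj(model.clone()).sum();
        println!("value: {}", value.into_scalar());
    };

    models.iter_mut().for_each(op);
    models.par_iter_mut().for_each(op);
}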

giucesar commented 7 months ago

Understood. I incorrectly assumed it kept the require_grad tag. My code was previously written in pure tch, which has a += operator that keeps the gradient tag.

nathanielsimard commented 7 months ago

@giucesar, I think this is going to be a common mistake; most tensor libraries aren't perfectly stateless in their behavior. We created an issue to better document this in the introduction, and I think we should show some gotcha examples. Issue: #1153