tracel-ai / burn

Burn is a new comprehensive dynamic Deep Learning Framework built using Rust with extreme flexibility, compute efficiency and portability as its primary goals.
https://burn.dev
Apache License 2.0

Cpu/Cuda conversion issue in Candle backend during batchnorm layer #1065

Closed michael-temp closed 8 months ago

michael-temp commented 10 months ago

Hi there,

Using the Candle backend with this dependency in Cargo.toml:

burn = { version = "0.11.1", default-features = false, features = ["train", "ndarray", "cuda", "candle"] }

gave this error:

thread 'main' panicked at /root/.cargo/registry/src/index.crates.io-6f17d22bba15001f/burn-tensor-0.11.1/src/tensor/api/numeric.rs:41:9:
=== Tensor Operation Error ===
  Operation: 'Add'
  Reason:
    1. The provided tensors are not on the same device. Lhs tensor device Cpu, Rhs tensor device Cuda(0). 

stack backtrace:
   0: rust_begin_unwind
             at /rustc/bf9a1c8a193fc373897196321215794c8bebbeec/library/std/src/panicking.rs:597:5
   1: core::panicking::panic_fmt
             at /rustc/bf9a1c8a193fc373897196321215794c8bebbeec/library/core/src/panicking.rs:72:14
   2: core::panicking::panic_display
             at /rustc/bf9a1c8a193fc373897196321215794c8bebbeec/library/core/src/panicking.rs:178:5
   3: burn_tensor::tensor::api::float::<impl burn_tensor::tensor::api::base::Tensor<B,_>>::matmul::panic_cold_display
             at /rustc/bf9a1c8a193fc373897196321215794c8bebbeec/library/core/src/panic.rs:99:13
   4: burn_tensor::tensor::api::numeric::<impl burn_tensor::tensor::api::base::Tensor<B,_,K>>::add
             at /root/.cargo/registry/src/index.crates.io-6f17d22bba15001f/burn-tensor-0.11.1/src/tensor/api/numeric.rs:41:9
   5: burn_core::nn::norm::batch::BatchNorm<B,_>::forward_train
             at /root/.cargo/registry/src/index.crates.io-6f17d22bba15001f/burn-core-0.11.1/src/nn/norm/batch.rs:133:28
   6: burn_core::nn::norm::batch::BatchNorm<B,_>::forward
             at /root/.cargo/registry/src/index.crates.io-6f17d22bba15001f/burn-core-0.11.1/src/nn/norm/batch.rs:85:21
   ...

The panic seems to be raised by this statement in BatchNorm::forward_train (burn-core-0.11.1/src/nn/norm/batch.rs:133):

let running_mean = running_mean.mul_scalar(1.0 - self.momentum).add(
    mean.clone()
        .detach()
        .mul_scalar(self.momentum)
        .reshape([channels]),
);

louisfd commented 10 months ago

This looks like the kind of bug that occurs when some methods always use the default device instead of the tensor's actual device. Default device handling is under heavy refactoring in #1081, so I suggest we wait for it to merge.
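
To make the suspected failure mode concrete, here is a minimal toy sketch in plain Rust. It does not use Burn or Candle: `Device`, `Tensor`, `update_running_mean`, and `update_running_mean_fixed` are hypothetical stand-ins written for illustration only. It assumes the running statistic was created on the default device (CPU) while the batch mean lives on the GPU, which is one plausible reading of the error message, not a confirmed diagnosis.

```rust
// Toy model of the failure mode. None of these types are Burn's real API.

#[derive(Clone, Copy, Debug, PartialEq)]
enum Device {
    Cpu,
    Cuda(usize),
}

#[derive(Clone, Debug, PartialEq)]
struct Tensor {
    data: Vec<f32>,
    device: Device,
}

impl Tensor {
    fn new(data: Vec<f32>, device: Device) -> Self {
        Tensor { data, device }
    }

    /// Returns a copy of this tensor tagged with `device`.
    fn to_device(&self, device: Device) -> Tensor {
        Tensor { data: self.data.clone(), device }
    }

    /// Multiplies every element by `s`, keeping the device unchanged.
    fn mul_scalar(&self, s: f32) -> Tensor {
        Tensor {
            data: self.data.iter().map(|x| x * s).collect(),
            device: self.device,
        }
    }

    /// Element-wise add that refuses operands on different devices.
    fn add(&self, rhs: &Tensor) -> Result<Tensor, String> {
        if self.device != rhs.device {
            return Err(format!(
                "Lhs tensor device {:?}, Rhs tensor device {:?}",
                self.device, rhs.device
            ));
        }
        let data = self.data.iter().zip(&rhs.data).map(|(a, b)| a + b).collect();
        Ok(Tensor { data, device: self.device })
    }
}

/// Running-mean update with the same shape as the quoted statement:
/// running * (1 - momentum) + batch_mean * momentum.
fn update_running_mean(
    running: &Tensor,
    batch_mean: &Tensor,
    momentum: f32,
) -> Result<Tensor, String> {
    running
        .mul_scalar(1.0 - momentum)
        .add(&batch_mean.mul_scalar(momentum))
}

/// Hypothetical workaround: move the running statistic to the batch's
/// device before combining the two tensors.
fn update_running_mean_fixed(running: &Tensor, batch_mean: &Tensor, momentum: f32) -> Tensor {
    let running = running.to_device(batch_mean.device);
    update_running_mean(&running, batch_mean, momentum)
        .expect("operands share a device after to_device")
}
```

With `running` on `Device::Cpu` and `batch_mean` on `Device::Cuda(0)`, `update_running_mean` returns an error naming both devices, while `update_running_mean_fixed` succeeds and yields a tensor on `Cuda(0)`. Whether the real fix belongs in the layer's initialization or in the update step is a design question this sketch does not settle.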

antimora commented 8 months ago

https://github.com/tracel-ai/burn/pull/1081 has been merged, and I assume it fixes this issue. Closing for now. Let us know if the problem still occurs.