tracel-ai / burn

Burn is a new comprehensive dynamic Deep Learning Framework built using Rust with extreme flexibility, compute efficiency and portability as its primary goals.
https://burn.dev
Apache License 2.0
8.9k stars, 440 forks

ndarray: could not broadcast array from shape: [2] to: [1] in item.loss.backward() #680

Closed L-M-Sherlock closed 1 year ago

L-M-Sherlock commented 1 year ago

Describe the bug

My model trains correctly when the batch size is 1, but a fatal error occurs when I set batch_size = 2.

=== PANIC ===
A fatal error happened, you can check the experiment logs here => './tmp/fsrs/experiment.log'
=============
thread 'training::test' panicked at 'ndarray: could not broadcast array from shape: [2] to: [1]', /Users/jarrettye/.cargo/registry/src/index.crates.io-6f17d22bba15001f/ndarray-0.15.6/src/lib.rs:1529:13
stack backtrace:
   0: rust_begin_unwind
             at /rustc/90c541806f23a127002de5b4038be731ba1458ca/library/std/src/panicking.rs:578:5
   1: core::panicking::panic_fmt
             at /rustc/90c541806f23a127002de5b4038be731ba1458ca/library/core/src/panicking.rs:67:14
   2: ndarray::ArrayBase<S,D>::broadcast_unwrap::broadcast_panic
             at /Users/jarrettye/.cargo/registry/src/index.crates.io-6f17d22bba15001f/ndarray-0.15.6/src/lib.rs:1529:13
   3: ndarray::ArrayBase<S,D>::broadcast_unwrap
             at /Users/jarrettye/.cargo/registry/src/index.crates.io-6f17d22bba15001f/ndarray-0.15.6/src/lib.rs:1538:21
   4: ndarray::impl_methods::<impl ndarray::ArrayBase<S,D>>::zip_mut_with
             at /Users/jarrettye/.cargo/registry/src/index.crates.io-6f17d22bba15001f/ndarray-0.15.6/src/impl_methods.rs:2456:33
   5: ndarray::impl_methods::<impl ndarray::ArrayBase<S,D>>::assign
             at /Users/jarrettye/.cargo/registry/src/index.crates.io-6f17d22bba15001f/ndarray-0.15.6/src/impl_methods.rs:2355:9
   6: burn_ndarray::ops::base::NdArrayOps<E>::slice_assign
             at /Users/jarrettye/.cargo/git/checkouts/burn-acfbee6a141c1b41/8808ee2/burn-ndarray/src/ops/base.rs:48:9
   7: burn_ndarray::ops::tensor::<impl burn_tensor::tensor::ops::tensor::TensorOps<burn_ndarray::backend::NdArrayBackend<E>> for burn_ndarray::backend::NdArrayBackend<E>>::slice_assign
             at /Users/jarrettye/.cargo/git/checkouts/burn-acfbee6a141c1b41/8808ee2/burn-ndarray/src/ops/tensor.rs:200:9
   8: <burn_autodiff::ops::tensor::<impl burn_tensor::tensor::ops::tensor::TensorOps<burn_autodiff::backend::ADBackendDecorator<B>> for burn_autodiff::backend::ADBackendDecorator<B>>::slice::Index<_> as burn_autodiff::ops::backward::Backward<B,_,1_usize>>::backward::{{closure}}
             at /Users/jarrettye/.cargo/git/checkouts/burn-acfbee6a141c1b41/8808ee2/burn-autodiff/src/ops/tensor.rs:636:21
   9: burn_autodiff::ops::backward::unary
             at /Users/jarrettye/.cargo/git/checkouts/burn-acfbee6a141c1b41/8808ee2/burn-autodiff/src/ops/backward.rs:78:20
  10: <burn_autodiff::ops::tensor::<impl burn_tensor::tensor::ops::tensor::TensorOps<burn_autodiff::backend::ADBackendDecorator<B>> for burn_autodiff::backend::ADBackendDecorator<B>>::slice::Index<_> as burn_autodiff::ops::backward::Backward<B,_,1_usize>>::backward
             at /Users/jarrettye/.cargo/git/checkouts/burn-acfbee6a141c1b41/8808ee2/burn-autodiff/src/ops/tensor.rs:634:17
  11: <burn_autodiff::ops::base::OpsStep<B,T,SB,_,_> as burn_autodiff::graph::base::Step>::step
             at /Users/jarrettye/.cargo/git/checkouts/burn-acfbee6a141c1b41/8808ee2/burn-autodiff/src/ops/base.rs:142:9
  12: burn_autodiff::graph::backward::execute_steps::{{closure}}::{{closure}}
             at /Users/jarrettye/.cargo/git/checkouts/burn-acfbee6a141c1b41/8808ee2/burn-autodiff/src/graph/backward.rs:35:61
  13: core::iter::traits::iterator::Iterator::for_each::call::{{closure}}
             at /rustc/90c541806f23a127002de5b4038be731ba1458ca/library/core/src/iter/traits/iterator.rs:854:29
  14: core::iter::traits::iterator::Iterator::fold
             at /rustc/90c541806f23a127002de5b4038be731ba1458ca/library/core/src/iter/traits/iterator.rs:2482:21
  15: core::iter::traits::iterator::Iterator::for_each
             at /rustc/90c541806f23a127002de5b4038be731ba1458ca/library/core/src/iter/traits/iterator.rs:857:9
  16: burn_autodiff::graph::backward::execute_steps::{{closure}}
             at /Users/jarrettye/.cargo/git/checkouts/burn-acfbee6a141c1b41/8808ee2/burn-autodiff/src/graph/backward.rs:35:27
  17: core::iter::traits::iterator::Iterator::for_each::call::{{closure}}
             at /rustc/90c541806f23a127002de5b4038be731ba1458ca/library/core/src/iter/traits/iterator.rs:854:29
  18: core::iter::traits::double_ended::DoubleEndedIterator::rfold
             at /rustc/90c541806f23a127002de5b4038be731ba1458ca/library/core/src/iter/traits/double_ended.rs:307:21
  19: <core::iter::adapters::rev::Rev<I> as core::iter::traits::iterator::Iterator>::fold
             at /rustc/90c541806f23a127002de5b4038be731ba1458ca/library/core/src/iter/adapters/rev.rs:64:9
  20: core::iter::traits::iterator::Iterator::for_each
             at /rustc/90c541806f23a127002de5b4038be731ba1458ca/library/core/src/iter/traits/iterator.rs:857:9
  21: burn_autodiff::graph::backward::execute_steps
             at /Users/jarrettye/.cargo/git/checkouts/burn-acfbee6a141c1b41/8808ee2/burn-autodiff/src/graph/backward.rs:33:5
  22: burn_autodiff::graph::backward::backward
             at /Users/jarrettye/.cargo/git/checkouts/burn-acfbee6a141c1b41/8808ee2/burn-autodiff/src/graph/backward.rs:11:5
  23: <burn_autodiff::backend::ADBackendDecorator<B> as burn_tensor::tensor::backend::base::ADBackend>::backward
             at /Users/jarrettye/.cargo/git/checkouts/burn-acfbee6a141c1b41/8808ee2/burn-autodiff/src/backend.rs:42:9
  24: burn_tensor::tensor::api::float::<impl burn_tensor::tensor::api::base::Tensor<B,_>>::backward
             at /Users/jarrettye/.cargo/git/checkouts/burn-acfbee6a141c1b41/8808ee2/burn-tensor/src/tensor/api/float.rs:287:9
  25: fsrs_optimizer_rs::training::<impl burn_train::learner::train_val::TrainStep<fsrs_optimizer_rs::dataset::FSRSBatch<B>,burn_train::learner::classification::ClassificationOutput<B>> for fsrs_optimizer_rs::model::Model<B>>::step
             at ./src/training.rs:56:32
  26: burn_train::learner::epoch::TrainEpoch<TI>::run
             at /Users/jarrettye/.cargo/git/checkouts/burn-acfbee6a141c1b41/8808ee2/burn-train/src/learner/epoch.rs:108:24
  27: burn_train::learner::train_val::<impl burn_train::learner::base::Learner<B,M,O,LR,TO,VO>>::fit
             at /Users/jarrettye/.cargo/git/checkouts/burn-acfbee6a141c1b41/8808ee2/burn-train/src/learner/train_val.rs:131:21
  28: fsrs_optimizer_rs::training::train
             at ./src/training.rs:131:29
  29: fsrs_optimizer_rs::training::test
             at ./src/training.rs:157:5
  30: fsrs_optimizer_rs::training::test::{{closure}}
             at ./src/training.rs:149:11
  31: core::ops::function::FnOnce::call_once
             at /rustc/90c541806f23a127002de5b4038be731ba1458ca/library/core/src/ops/function.rs:250:5
  32: core::ops::function::FnOnce::call_once
             at /rustc/90c541806f23a127002de5b4038be731ba1458ca/library/core/src/ops/function.rs:250:5
note: Some details are omitted, run with `RUST_BACKTRACE=full` for a verbose backtrace.
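
For readers following the backtrace: frames 5 to 8 show the backward pass of a slice operation calling slice_assign, which in turn calls ndarray's assign, and the panic is raised when the value being assigned (shape [2]) cannot be broadcast to the target region (shape [1]). The sketch below illustrates that rule in one dimension using only the standard library. It is a simplified illustration, not Burn's or ndarray's actual code; the names can_broadcast_1d and slice_backward_1d are hypothetical.

```rust
use std::ops::Range;

/// Returns true if a 1-D array of length `from` can be broadcast to length `to`.
/// Under the usual broadcasting rule, an axis is compatible when both lengths
/// are equal or the source length is 1.
fn can_broadcast_1d(from: usize, to: usize) -> bool {
    from == to || from == 1
}

/// Hypothetical 1-D stand-in for the backward pass of a slice operation:
/// allocate zeros with the input's length, then assign the incoming gradient
/// into the sliced region. Returns an error instead of panicking.
fn slice_backward_1d(
    input_len: usize,
    range: Range<usize>,
    grad: &[f32],
) -> Result<Vec<f32>, String> {
    if range.start > range.end || range.end > input_len {
        return Err(format!(
            "range {:?} is out of bounds for length {}",
            range, input_len
        ));
    }
    let region_len = range.len();
    if !can_broadcast_1d(grad.len(), region_len) {
        // Same wording as the panic message reported above.
        return Err(format!(
            "could not broadcast array from shape: [{}] to: [{}]",
            grad.len(),
            region_len
        ));
    }
    let mut out = vec![0.0; input_len];
    for (i, slot) in out[range].iter_mut().enumerate() {
        // A length-1 gradient is repeated across the region;
        // otherwise the gradient is copied element by element.
        *slot = if grad.len() == 1 { grad[0] } else { grad[i] };
    }
    Ok(out)
}
```

Under these assumptions, slice_backward_1d(4, 0..1, &[1.0, 2.0]) returns the same "from shape: [2] to: [1]" message, because a two-element gradient is being written into a one-element region. This matches the symptom that the failure only appears once the batch size exceeds 1.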

To Reproduce

My model is complicated, so I haven't found a minimal script that reproduces the error.

For details, please see: https://github.com/open-spaced-repetition/fsrs-optimizer-burn/pull/16#issuecomment-1689326326

The commit that reproduces the error: https://github.com/open-spaced-repetition/fsrs-optimizer-burn/pull/16/commits/23b8772eb2e8c56b1cda40cb66b08aa5d20772c8

nathanielsimard commented 1 year ago

@L-M-Sherlock does it work with other backends?

L-M-Sherlock commented 1 year ago

> @L-M-Sherlock does it work with other backends?

It works with burn_wgpu. I haven't tested burn_tch because building torch-sys takes too long.

L-M-Sherlock commented 1 year ago

I tested the tch backend. It was blocked at the loss calculation.

nathanielsimard commented 1 year ago

Fixed with https://github.com/burn-rs/burn/issues/686