jafioti / luminal

Deep learning at the speed of light.
Apache License 2.0

Autograd add grad error #68

Open swfsql opened 1 week ago

swfsql commented 1 week ago

Hi, I've found a runtime error and made this minimized example:

#[test]
fn test_add_grad_decreasing_idx() {
    let mut cx = Graph::new();
    let a: GraphTensor<R1<2>> = cx.tensor();
    let a: GraphTensor<R3<1, 1, 2>> = a.expand::<_, LAxes2<0, 1>>();
    let a: GraphTensor<R3<2, 1, 1>> = a.permute::<_, LAxes3<2, 1, 0>>();
    // a.shape.fake = [false, true, true]
    // a.shape.indexes = [0, 2, 1] // note that the indexes aren't necessarily increasing (0, 1, 2)
    let b: GraphTensor<R3<2, 1, 1>> = cx.tensor();
    let weights = vec![a.id, b.id];

    let m: GraphTensor<R3<2, 1, 1>> = a * b;
    let loss: GraphTensor<R0> = m.sum_reduce();
    // Panics while compiling the autograd pass (backtrace below).
    let _grads = cx.compile(Autograd::new(weights, loss), ());
}

The error:

thread 'autograd::tests::test_add_grad_decreasing_idx' panicked at /usr/local/cargo/registry/src/index.crates.io-6f17d22bba15001f/tinyvec-1.6.0/src/arrayvec.rs:681:26:
index out of bounds: the len is 0 but the index is 0
stack backtrace:
   3: tinyvec::arrayvec::ArrayVec<A>::remove
             at /usr/local/cargo/registry/src/index.crates.io-6f17d22bba15001f/tinyvec-1.6.0/src/arrayvec.rs:681:26
   4: luminal::shape::tracker::ShapeTracker::remove_dim
             at /workspaces/coursera-deep-learning-specialization/luminal_original/src/shape/tracker.rs:58:21
   5: luminal_training::autograd::add_grad
             at ./src/autograd.rs:208:13
   6: <luminal_training::autograd::Autograd as luminal::compiler_utils::Compiler>::compile
             at ./src/autograd.rs:113:21

From what I've noticed, the add_grad function may assume that the forward tensor's shape.indexes are always increasing. In this case, because of how the a.permute axes were defined (LAxes3<2, 1, 0>), the indexes end up out of order: [0, 2, 1]. If the axes were defined as LAxes3<2, 0, 1> instead, the indexes would be [0, 1, 2] and there would be no error. Since this shape also has two fake axes (fake is [false, true, true]), removing the first fake axis shrinks the number of axes, so the last iteration indexes past the end and goes out of bounds.
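To make the index-shift argument concrete, here is a toy model in plain Rust. `MiniShape`, `fake_axes` and `remove_fake_axes` are hypothetical names made up for this illustration: they only mimic the `dims`/`indexes`/`fake` bookkeeping described above and are neither luminal's `ShapeTracker` API nor the draft fix from the PR. The model assumes that `expand` appends a physical dim and inserts it into `indexes`, and that removing an axis shifts every later position down by one. Under those assumptions it reproduces `indexes = [0, 2, 1]`, shows that the second fake position is already past the end once the first one is removed, and shows that visiting the fake positions from highest to lowest avoids the problem.

```rust
/// Toy, hypothetical model of shape bookkeeping (NOT luminal's ShapeTracker).
#[derive(Debug, Clone)]
struct MiniShape {
    dims: Vec<usize>,    // physical dims, in the order they were created
    indexes: Vec<usize>, // logical axis position -> physical dim
    fake: Vec<bool>,     // per physical dim: true if added by `expand`
}

impl MiniShape {
    fn new(dims: Vec<usize>) -> Self {
        let n = dims.len();
        MiniShape { dims, indexes: (0..n).collect(), fake: vec![false; n] }
    }

    /// Insert a new fake (broadcast) axis of `size` at logical position `axis`.
    fn expand(&mut self, axis: usize, size: usize) {
        let new_ind = self.dims.len();
        self.dims.push(size);
        self.indexes.insert(axis, new_ind);
        self.fake.push(true);
    }

    /// Reorder logical axes: new axis i is old axis `axes[i]`.
    fn permute(&mut self, axes: &[usize]) {
        self.indexes = axes.iter().map(|&a| self.indexes[a]).collect();
    }

    /// Remove logical axis `axis`; positions after it shift down by one.
    /// Panics (via Vec::remove) if `axis` is out of bounds.
    fn remove_dim(&mut self, axis: usize) {
        let phys = self.indexes.remove(axis);
        self.dims.remove(phys);
        self.fake.remove(phys);
        for i in self.indexes.iter_mut() {
            if *i > phys {
                *i -= 1;
            }
        }
    }

    /// Logical (visible) shape.
    fn shape(&self) -> Vec<usize> {
        self.indexes.iter().map(|&i| self.dims[i]).collect()
    }
}

/// Logical positions whose physical dim is fake, in ascending order.
fn fake_axes(s: &MiniShape) -> Vec<usize> {
    (0..s.indexes.len()).filter(|&i| s.fake[s.indexes[i]]).collect()
}

/// Remove every fake axis, highest position first, so earlier removals
/// never shift a position that is still waiting to be removed.
fn remove_fake_axes(s: &mut MiniShape) {
    for axis in fake_axes(s).into_iter().rev() {
        s.remove_dim(axis);
    }
}

fn main() {
    // The issue's shape: R1<2> -> expand at axes 0 and 1 -> permute <2, 1, 0>.
    let mut s = MiniShape::new(vec![2]);
    s.expand(0, 1);
    s.expand(1, 1);
    s.permute(&[2, 1, 0]);
    println!("indexes = {:?}, fake = {:?}, shape = {:?}", s.indexes, s.fake, s.shape());

    // Ascending removal: after dropping position 1 only two axes remain,
    // so position 2 is now out of bounds (Vec::remove would panic here).
    let axes = fake_axes(&s);
    let mut asc = s.clone();
    asc.remove_dim(axes[0]);
    println!("after first removal: {} axes left, next position is {}", asc.indexes.len(), axes[1]);

    // Descending removal keeps every remaining position valid.
    remove_fake_axes(&mut s);
    println!("shape after removing fake axes = {:?}", s.shape());
}
```

Removing from the highest position down is only one way to keep the remaining positions valid; recomputing the positions after each removal would also work. The toy tracks shapes only, so it ignores the reduction a real gradient needs over each broadcast axis.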

Relevant code section: https://github.com/jafioti/luminal/blob/f61d53f859293d4c5e4d57d264dfc423a98007e0/crates/luminal_training/src/autograd.rs#L202-L211

I'm not sure how to solve this, as I'm still learning how shapes work overall, but I'll try to experiment more. I've also linked a PR that shows the example and the error.

swfsql commented 1 week ago

On the PR, I've simplified the test, added another test, and then added a draft fix.