coreylowman / dfdx

Deep learning in Rust, with shape checked tensors and neural networks

Add QoL Error when dimensions between layers don't match #521

Open kstavro opened 1 year ago

kstavro commented 1 year ago

If, by accident, you mistype the dimensions of the layers in your model, the compiler throws an error at model.forward_mut() which does not hint at all at the actual problem (mismatched dimensions).

It would be great if the compiler hinted directly at the mismatched dims (not entirely sure yet how to implement this in Rust, but I would be willing to look into it).
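For illustration, something like this (a made-up model here; the exact layers don't matter, only the typo in one dimension):

type Mlp = (
    Linear<784, 511>, // typo: the next layer expects 512
    Linear<512, 128>,
    Linear<128, 10>,
);

// let model = dev.build_module::<Mlp, f32>();
// let y = model.forward_mut(x); // <- the confusing error only shows up here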

coreylowman commented 1 year ago

Yeah I see the confusion, good request

kstavro commented 1 year ago

I tried to see if it is possible to tackle this, but I have the feeling it might not be easy to get around. The reason is that it boils down to changing the compiler's suggestion that a trait bound is not satisfied.

More precisely, here is eg the implementation of Module for Linear:

impl<const I: usize, const O: usize, E: Dtype, D: Device<E>, T> Module<T> for Linear<I, O, E, D>
where
    T: SplitTape + TryMatMul<Tensor<Rank2<I, O>, E, D, T::Tape>> + HasErr<Err = D::Err>,
    T::Tape: Tape<E, D>,
    for<'a> Bias1D<'a, O, E, D>: Module<T::Output, Output = T::Output, Error = D::Err>,
{
...
}

When you have mismatched dimensions between layers, the TryMatMul trait from the first where condition is not satisfied, and the compiler suggests just that. You can still find the problem buried inside a rather big compiler error (at model.try_forward_mut()), in the form of required for dfdx::nn::modules::Linear<512, 128, f32, dfdx::tensor::Cpu> to implement Module<Tensor<(dfdx::shapes::Const<32>, dfdx::shapes::Const<511>), f32, dfdx::tensor::Cpu, OwnedTape<f32, dfdx::tensor::Cpu>>> (the mismatch is 512 vs 511), but still.

I tried to find a way to bypass the compiler suggestion for a missing trait bound, but from googling a bit, I didn't find anything relevant. Any ideas for an easy solution are welcome.

coreylowman commented 1 year ago

Yeah I suspect the error might be better if we move away from the TryMatMul bound, and instead do multiple impls of Module for fixed tensor sizes (like it was previously). I'd rather have better error messages than slightly less code.

Can you try to move to fixed impls and see if the error is any better?

impl Module<Tensor<Rank1<I>, ...>> for Linear<I, O, E, D> { }
impl Module<Tensor<(Batch, Const<I>), ...>> for Linear<I, O, E, D> { }
impl Module<Tensor<(Batch, Seq, Const<I>), ...>> for Linear<I, O, E, D> { }

coreylowman commented 1 year ago

There's a similarly arcane error message if you mess up your conv/bn/residuals in a conv net. That might also be addressed with fixed impls

kstavro commented 1 year ago

Ok, then if we are open to small architectural changes like the one you suggested, we might have more options.

I also considered creating a new DimError and/or a trait HasDimErr, but nothing could get around the compiler focusing on the trait bound not being satisfied (the idea being that such an error might also be relevant for other operations that need matching dimensions).

Will check it later tonight or tomorrow morning. Thanks for the suggestion!

coreylowman commented 1 year ago

Yeah for internal library code, ergonomics/usability is higher priority than minifying code. It's definitely a balance, but better error messages probably win every time TBH

kstavro commented 1 year ago

Didn't try the fixed implementations yet, but my feeling is the errors will be similar. The problem is that all current MatMul traits/impls involve only 3 dimensions, because they assume that the left column dimension will be equal to the right row dim. To avoid the "trait bound is not satisfied" error, I think the best solution is to associate 4 Dims with TryMatMul and check whether the desired dims match (it is a TryMatMul after all).

E.g., with the following relaxed impl (never mind the names, we can change them to your preference):

impl<M: Dim, LeftK: Dim, RightK: Dim, N: Dim, E: Dtype, D: MatMatKernel<E>, T, R>
    TryMatMul<Tensor<(RightK, N), E, D, R>> for Tensor<(M, LeftK), E, D, T>
where
    T: Tape<E, D> + Merge<R>,
    R: Tape<E, D>,
{
    type Output = Tensor<(M, N), E, D, T>;
    /// ```compile_fail
    /// # use dfdx::prelude::*;
    /// # let dev: Cpu = Default::default();
    /// let x: Tensor<Rank2<3, 2>, f32, _> = dev.zeros();
    /// let y: Tensor<Rank2<3, 4>, f32, _> = dev.zeros();
    /// let _: Tensor<Rank2<3, 4>, f32, _> = x.try_matmul(y);
    /// ```
    fn try_matmul(self, rhs: Tensor<(RightK, N), E, D, R>) -> Result<Self::Output, Self::Err> {
        assert_eq!(self.shape.1.size(), rhs.shape.0.size());
        try_binary_op(self, rhs, D::forward, D::backward)
    }
}

we can get the code to hit on the assert_eq! (runtime for now).

thread 'main' panicked at 'assertion failed: `(left == right)`
  left: `511`,
 right: `512`', 

If you don't think it is too much to relax the dims from 3 to 4, it should be possible to correctly check whether the dims match and make TryMatMul complain about the dimensions not matching (while otherwise mostly keeping it as is). Thoughts?

coreylowman commented 1 year ago

Interesting idea, and we could probably get this to happen at compile time with some clever post-monomorphization error checking:

trait AssertDimEq<Rhs> {
    const TYPE_CHECK: ();
    fn assert_dim_eq(&self, rhs: &Rhs);
}

impl<const M: usize, const N: usize> AssertDimEq<Const<N>> for Const<M> {
    const TYPE_CHECK: () = assert_eq!(M, N);
    fn assert_dim_eq(&self, rhs: &Rhs) {
        let _ = Self::TYPE_CHECK;
    }
}

// these would all do runtime checks
impl AssertDimEq<usize> for usize { ... }
impl AssertDimEq<Const<M>> for usize { ... }
impl AssertDimEq<usize> for Const<M> { ... }

plus this would allow you to do matmul with dims that are equal but not necessarily the same type!
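For reference, it would slot into the relaxed TryMatMul impl from above roughly like this (just a sketch):

impl<M: Dim, LeftK: Dim, RightK: Dim, N: Dim, E: Dtype, D: MatMatKernel<E>, T, R>
    TryMatMul<Tensor<(RightK, N), E, D, R>> for Tensor<(M, LeftK), E, D, T>
where
    T: Tape<E, D> + Merge<R>,
    R: Tape<E, D>,
    LeftK: AssertDimEq<RightK>, // <- the new bound
{
    type Output = Tensor<(M, N), E, D, T>;
    fn try_matmul(self, rhs: Tensor<(RightK, N), E, D, R>) -> Result<Self::Output, Self::Err> {
        // triggers the const TYPE_CHECK for Const dims, or a runtime check for usize dims
        self.shape.1.assert_dim_eq(&rhs.shape.0);
        try_binary_op(self, rhs, D::forward, D::backward)
    }
}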

kstavro commented 1 year ago

Interesting idea, and we could probably get this to happen at compile time with some clever post-monomorphization error checking:

trait AssertDimEq<Rhs> {
    const TYPE_CHECK: ();
    fn assert_dim_eq(&self, rhs: &Rhs);
}

impl<const M: usize, const N: usize> AssertDimEq<Const<N>> for Const<M> {
    const TYPE_CHECK: () = assert_eq!(M, N);
    fn assert_dim_eq(&self, rhs: &Rhs) {
        let _ = Self::TYPE_CHECK;
    }
}

// these would all do runtime checks
impl AssertDimEq<usize> for usize { ... }
impl AssertDimEq<Const<M>> for usize { ... }
impl AssertDimEq<usize> for Const<M> { ... }

Yeah, this is more or less what I had in mind, but I thought we might already be able to get compile-time checks (maybe I'm overestimating things, not sure yet). Your impls are definitely more thorough, though, and your names are better. One detail that I didn't mention is that this trait bound needs to be implemented inside the MatMatKernel (in the kernels in general) and not in TryMatMul; I already had to associate 4 Dim types in the example impl I posted in order for the trait bound not to hit, but this should hopefully keep the error message concise enough.

One more thing: the assert_eq! above inside try_matmul was already there; I just used it to argue that with the 4 dimensions we can bypass the weird trait bound and control the error elsewhere, and I wanted to see how you liked the idea before going and implementing it correctly with traits. However, this assert_eq! check does seem a bit redundant to me, as both the actual code and whatever changes come out of this issue will already have checked this. I have seen these dimension asserts in other ops as well. Is there a logic behind them that I'm missing? Or are they simply relics of previous code versions?

plus this would allow you to do matmul with dims that are equal but not necessarily the same type!

That does sound like it might make the stacking of layers/components more flexible/dynamic! Any easy examples where this might be relevant? Can't come up with one off the top of my head.

coreylowman commented 1 year ago

Is there a logic behind them that I'm missing?

It's all for usize dimensions, which have to be checked at runtime. Previously when only Const dimensions were supported, there weren't any assert_eqs on dimensions. So for any op that accepts more than 1 tensor with the same type of dimension, we have to make sure the actual values match.
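E.g. one of those runtime impls would look something like this (a sketch, not the exact code):

impl<const M: usize> AssertDimEq<usize> for Const<M> {
    // nothing to verify at compile time when one side is only known at runtime
    const TYPE_CHECK: () = ();
    fn assert_dim_eq(&self, rhs: &usize) {
        assert_eq!(M, *rhs, "dimension mismatch: {} vs {}", M, rhs);
    }
}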

That does sound like it might make the stacking of layers/components more flexible/dynamic! Any easy examples where this might be relevant? Can't come up with one off the top of my head.

Yeah, nothing off the top of my head. The only one maybe is if you've set one of the matrices to dynamic dimensions in a transformer?

kstavro commented 1 year ago

let _ = Self::TYPE_CHECK; doesn't compile (type annotations needed: cannot infer the value of const parameter N), and no matter what I googled, I couldn't make it work.

So I started getting creative and at some point also removed the assert_eq! for the various dimension matching assertions. Something very weird happened: MNIST trained without matching dimensions (still LeftK=511 and RightK=512). And it trained just fine, no panic, only very slight discrepancies in the losses. I tried it with different dimensions at all layers, and it still trains (with a bit worse losses)!!! Are you aware of this? I wonder if the unsafe part of the matrix multiply still works even if the dimensions don't match...

coreylowman commented 1 year ago

Yeah the matmul kernels just pull from one of the ks and use that as the dim. I guess you happened to pull the right one? And yeah the unsafe matmul stuff at the lowest level probably won't check on that 😀
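Schematically the kernel is doing something like this (not the actual code), which is why a wrong K can slip through:

// naive sketch: k is taken from the lhs shape only, so extra rhs rows are silently
// ignored (and fewer rows than k would read out of bounds in the unsafe code)
fn naive_matmul(lhs: &[f32], rhs: &[f32], m: usize, k: usize, n: usize, out: &mut [f32]) {
    for i in 0..m {
        for j in 0..n {
            let mut acc = 0.0;
            for p in 0..k {
                acc += lhs[i * k + p] * rhs[p * n + j];
            }
            out[i * n + j] = acc;
        }
    }
}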

Ahh yes I see the issue you had. This should work instead:

struct Const<const M: usize>;

trait AssertDimEq<Rhs> {
    const TYPE_CHECK: ();
    fn assert_dim_eq(&self, rhs: &Rhs);
}

impl<const M: usize, const N: usize> AssertDimEq<Const<N>> for Const<M> {
    const TYPE_CHECK: () = assert!(M == N);
    fn assert_dim_eq(&self, rhs: &Const<N>) {
        let _ = <Self as AssertDimEq<Const<N>>>::TYPE_CHECK;
    }
}

fn main() {
    Const::<5>.assert_dim_eq(&Const::<3>);
}

I believe this is because with Self::TYPE_CHECK, it doesn't know whether Rhs is Const<M> or Const<N>

kstavro commented 1 year ago

Self as AssertDimEq<Const<N>> did the trick, thanks! I tried Self as Const<M>, but that didn't work; such a Rust noob! Now the error looks as follows (will improve it, just wanted to see if we are on the right path):

PS C:\Users\...\dfdx> cargo run --release --features threaded-cpu --example 06-mnist -- .\examples\tmp\
   Compiling dfdx v0.10.0 (C:\Users\...\dfdx)
error[E0080]: evaluation of `<dfdx::shapes::Const<511> as dfdx::tensor_ops::matmul::AssertDimEq<dfdx::shapes::Const<512>>>::TYPE_CHECK` failed
   --> C:\Users\...\dfdx\src\tensor_ops\matmul\mod.rs:178:28
    |
178 |     const TYPE_CHECK: () = assert!(M == N);
    |                            ^^^^^^^^^^^^^^^ the evaluated program panicked at 'assertion failed: M == N', C:\Users\...\dfdx\src\tensor_ops\matmul\mod.rs:178:28
    |
    = note: this error originates in the macro `assert` (in Nightly builds, run with -Z macro-backtrace for more info)

note: the above error was encountered while instantiating `fn <dfdx::shapes::Const<511> as dfdx::tensor_ops::matmul::AssertDimEq<dfdx::shapes::Const<512>>>::assert_dim_eq`
   --> C:\Users\...\dfdx\src\tensor_ops\matmul\mod.rs:225:9
    |
225 |         self.shape.1.assert_dim_eq();
    |         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Both the error and the note sort of already contain the problem and there is no arcane compiler error any more. Now it is only a matter of improving the errors.

So, this error was an evaluation error, which was caught during compile time, but I still had to run the example to see it. This is already better than runtime, but it still doesn't make rust-analyzer complain, so I guess it is somewhere between compile and runtime? Or I have done some dumb thing inside the many changes I made, and it will go away after I clean things up.

After working on this, I realized it is not straightforward to make static assertions work, and why you had also let some cases panic at runtime. But I had the hope that at least when the dimensions are constant, we could get everything checked at compile time... I don't yet fully understand all the limitations of Rust's compile-time checks or whether something more is possible, but at least passing 4 associated dims in the matmuls, combined with your suggestion, seems to be a clear improvement in the errors.

coreylowman commented 1 year ago

Yeah you're right that rust-analyzer won't catch it unfortunately. The place I found this technique mentioned that this sort of type checking isn't suggested for this very reason (i.e. cargo check won't catch it).

I'm not actually sure matmul is the place where we want to catch this error after all. Not all users would necessarily know that Linear uses matmul under the hood, and to associate this matmul error with one of their linear layers.

kstavro commented 1 year ago

I'm not actually sure matmul is the place where we want to catch this error after all. Not all users would necessarily know that Linear uses matmul under the hood, and to associate this matmul error with one of their linear layers.

I see what you mean, though this is the current implementation of Linear:

impl<const I: usize, const O: usize, E: Dtype, D: Device<E>, T> Module<T> for Linear<I, O, E, D>
where
    T: SplitTape + TryMatMul<Tensor<Rank2<I, O>, E, D, T::Tape>> + HasErr<Err = D::Err>,
    T::Tape: Tape<E, D>,
    for<'a> Bias1D<'a, O, E, D>: Module<T::Output, Output = T::Output, Error = D::Err>,
{
    type Output = T::Output;
    type Error = D::Err;

    /// 1d forward using [matmul()] and [add()].
    fn try_forward(&self, x: T) -> Result<Self::Output, D::Err> {
        let o = x.try_matmul(self.weight.retaped::<T::Tape>().try_permute()?)?;
        Bias1D { beta: &self.bias }.try_forward(o)
    }
}

The highest level we can lift the typecheck to is TryMatMul<Tensor<Rank2<I, O>, E, D, T::Tape>>, or else we have to change the trait bounds of T, right? The question is whether a message like "dimensions don't match, check if your layers or tensor dimensions match" (which is what I had in mind as the bare minimum baseline solution) is an acceptable fallback. Will play around a bit with different options and report back on what seems to work best.

coreylowman commented 1 year ago

Yeah we can always split this single generic impl into multiple impls. For example, if I change the impl to be

impl<const I: usize, const O: usize, E: Dtype, D: Device<E>, T: Tape<E, D>>
    Module<Tensor<Rank1<I>, E, D, T>> for Linear<I, O, E, D>
{
    type Output = Tensor<Rank1<O>, E, D, T>;
    type Error = D::Err;

    /// 1d forward using [matmul()] and [add()].
    fn try_forward(&self, x: Tensor<Rank1<I>, E, D, T>) -> Result<Self::Output, D::Err> {
        let o = x.try_matmul(self.weight.retaped::<T>().try_permute()?)?;
        Bias1D { beta: &self.bias }.try_forward(o)
    }
}

then for wrong input size I get this error:

type Model = (Linear<5, 10>, Linear<10, 5>);
let m = dev.build_module::<Model, f32>();
let x = dev.sample_normal::<Rank1<6>>(); // wrong input size
let y = m.forward(x);

error:

error[E0308]: mismatched types
  --> examples\tmp.rs:9:23
   |
9  |     let y = m.forward(x);
   |               ------- ^ expected `5`, found `6`
   |               |
   |               arguments to this function are incorrect
   |
   = note: expected struct `Tensor<(Const<5>,), f32, _, _>`
              found struct `Tensor<(Const<6>,), _, _, NoneTape>`
note: associated function defined here
  --> C:\Users\clowm\Documents\programming\dfdx\src\nn\module.rs:20:8
   |
20 |     fn forward(&self, input: Input) -> Self::Output {

But there's still a pretty awful error message if I mess up an internal layer:

type Model = (Linear<5, 10>, Linear<11, 5>); // second linear has wrong input size
let m = dev.build_module::<Model, f32>();
let x = dev.sample_normal::<Rank1<5>>();
let y = m.forward(x);

error:

error[E0599]: the method `forward` exists for tuple `(dfdx::nn::modules::Linear<5, 10, f32, dfdx::tensor::Cpu>, dfdx::nn::modules::Linear<11, 5, f32, dfdx::tensor::Cpu>)`, but its trait bounds were not satisfied
  --> examples\tmp.rs:8:15
   |
8  |     let y = m.forward(x);
   |               ^^^^^^^ method cannot be called on `(dfdx::nn::modules::Linear<5, 10, f32, dfdx::tensor::Cpu>, dfdx::nn::modules::Linear<11, 5, f32, dfdx::tensor::Cpu>)` due to unsatisfied trait bounds
   |
  ::: C:\Users\clowm\Documents\programming\dfdx\src\nn\linear.rs:47:1
   |
47 | pub struct Linear<const I: usize, const O: usize, E: Dtype, D: DeviceStorage> {
   | -----------------------------------------------------------------------------
   | |
   | doesn't satisfy `<_ as dfdx::nn::Module<Tensor<(Const<10>,), f32, dfdx::tensor::Cpu, _>>>::Error = CpuError`
   | doesn't satisfy `_: dfdx::nn::Module<Tensor<(Const<10>,), f32, dfdx::tensor::Cpu, _>>`
   |
   = note: the following trait bounds were not satisfied:
           `<dfdx::nn::modules::Linear<11, 5, f32, dfdx::tensor::Cpu> as dfdx::nn::Module<Tensor<(Const<10>,), f32, dfdx::tensor::Cpu, _>>>::Error = CpuError` 
           which is required by `(dfdx::nn::modules::Linear<5, 10, f32, dfdx::tensor::Cpu>, dfdx::nn::modules::Linear<11, 5, f32, dfdx::tensor::Cpu>): dfdx::nn::Module<Tensor<(Const<5>,), f32, dfdx::tensor::Cpu, _>>`
           `dfdx::nn::modules::Linear<11, 5, f32, dfdx::tensor::Cpu>: dfdx::nn::Module<Tensor<(Const<10>,), f32, dfdx::tensor::Cpu, _>>`
           which is required by `(dfdx::nn::modules::Linear<5, 10, f32, dfdx::tensor::Cpu>, dfdx::nn::modules::Linear<11, 5, f32, dfdx::tensor::Cpu>): dfdx::nn::Module<Tensor<(Const<5>,), f32, dfdx::tensor::Cpu, _>>`

kstavro commented 1 year ago

I think you get the arcane error because you haven't yet associated 4 dim types in matmul?

I can split the generic type inside Module into concrete types. Let me see how the error in the above code looks in my branch.

kstavro commented 1 year ago

I am trying to lift the error to the Linear level, but I am struggling a bit to understand why the following is happening for the MNIST MLP where I mismatch one of the dimensions. I am checking shapes now in the same spirit as before (as always, never mind the names; I also haven't shared all the details to keep the wall of text smaller).

The trait implemented for Rank2

pub trait AssertLayerMatch<Rhs: Shape> {
    const TYPE_CHECK: ();
    fn assert_dim_eq(&self);
}

impl<const I: usize, const O: usize, const IN: usize, const OUT: usize>
    AssertLayerMatch<Rank2<I, O>> for Rank2<IN, OUT>
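{
    // body omitted here for brevity; it follows the same const-assert pattern,
    // with the message shown in the error output below
    const TYPE_CHECK: () = assert!(
        OUT == I,
        "You are trying to stack tensors/layers, whose outgoing and ingoing dimensions do not match",
    );
    fn assert_dim_eq(&self) {
        let _ = <Self as AssertLayerMatch<Rank2<I, O>>>::TYPE_CHECK;
    }
}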

Requiring the trait on T's shape to be able to check the shapes:

impl<const I: usize, const O: usize, E: Dtype, D: Device<E>, T> Module<T> for Linear<I, O, E, D>
where
    T: SplitTape
        // + AssertNextLayerMatch<T>
        + TryMatMul<Tensor<Rank2<I, O>, E, D, T::Tape>>
        + HasErr<Err = D::Err>
        + HasShape,
    T::Tape: Tape<E, D>,
    T::Shape: AssertLayerMatch<Rank2<I, O>>,  // added the shape check here
    for<'a> Bias1D<'a, O, E, D>: Module<T::Output, Output = T::Output, Error = D::Err>,

Messing up one dimension on purpose

type Mlp = (
    Linear<784, 511>, // here is the mismatch again
    Linear<512, 128>,
    Linear<128, 32>,
    Linear<32, 10>,
);

error[E0080]: evaluation of `<(dfdx::shapes::Const<32>, dfdx::shapes::Const<511>) as dfdx::nn::linear::AssertLayerMatch<(dfdx::shapes::Const<512>, dfdx::shapes::Const<128>)>>::TYPE_CHECK` failed
  --> C:\Users\dfdx\src\nn\linear.rs:37:28
   |
37 |       const TYPE_CHECK: () = assert!(
   |  ____________________________^
38 | |         OUT == I,
39 | |         "You are trying to stack tensors/layers, whose outgoing and ingoing dimensions do not match",
40 | |     );
   | |_____^ the evaluated program panicked at 'You are trying to stack tensors, whose outgoing and ingoing dimensions do not match', C:\Users\dfdx\src\nn\linear.rs:37:28
   |
   = note: this error originates in the macro `$crate::panic::panic_2021` which comes from the expansion of the macro `assert` (in Nightly builds, run with -Z macro-backtrace for more info)

note: the above error was encountered while instantiating `fn <(dfdx::shapes::Const<32>, dfdx::shapes::Const<511>) as dfdx::nn::linear::AssertLayerMatch<(dfdx::shapes::Const<512>, dfdx::shapes::Const<128>)>>::assert_dim_eq`
   --> C:\Users\dfdx\src\nn\linear.rs:162:9
    |
162 |         x.shape().assert_dim_eq();
    |         ^^^^^^^^^^^^^^^^^^^^^^^^^

For more information about this error, try `rustc --explain E0080`.

How is it possible that I get the error for (dfdx::shapes::Const<32>, dfdx::shapes::Const<511>)? There is no layer like that inside the model. I guess tensors of subsets of layers materialize into intermediate tensors but (32, 511) doesn't make sense from the structure of the model.

What do you think of this kind of error? Apart from the dims right now not making any sense at all and not being able to have the dimensions inside the assertion message (rust complains when I try to use I, O, etc inside the message), I kind of like it compared to the previous state. I wish there was a way to make it a true compile error...

coreylowman commented 1 year ago

Do you happen to be using a batch size of 32? Note

<
(dfdx::shapes::Const<32>, dfdx::shapes::Const<511>)
as
dfdx::nn::linear::AssertLayerMatch<(dfdx::shapes::Const<512>, dfdx::shapes::Const<128>)
>>::TYPE_CHECK` failed

is actually checking the Linear<512, 128> layer against input (32, 511)

coreylowman commented 1 year ago

I think we could just reverse the check so that the layer sizes appear first:

Rank2<I, O>: AssertLayerMatch<T::Shape>

But yes this is much improved over previous errors for sure! Nice progress!

kstavro commented 1 year ago

Thanks! I am still obsessing a bit about compile checks, but at least we might get to an acceptable fallback if nothing else works.

Do you happen to be using a batch size of 32? Note

<
(dfdx::shapes::Const<32>, dfdx::shapes::Const<511>)
as
dfdx::nn::linear::AssertLayerMatch<(dfdx::shapes::Const<512>, dfdx::shapes::Const<128>)
>>::TYPE_CHECK` failed

is actually checking the Linear<512, 128> layer against input (32, 511)

Yep, it was indeed the batch size, nice catch! So now I see what happens with the way I implemented the trait and the way that forward works.

Not sure how to really check only between layers. Not clear to me how to compare Linears (maybe I can compare their weight shapes?). Will give it some thought tomorrow.

Yep, I can surely reverse the check; I just thought checking the first layer Linear<784, 511> against the second Linear<512, 128> (well, sort of) and so on felt more natural, but the error is about shapes (of inputs), not layers, so you might be right.

kstavro commented 1 year ago

I was browsing through the changes of v0.11.0 and noticed #543. The arcane errors you tried to fix there looked similar to the initial errors with the mismatched dimensions. Did you manage 100% compile-time error checks there? Are they caught by rust-analyzer? Should I draw inspiration from what you did there?

coreylowman commented 1 year ago

Not sure how to really check only between layers. Not clear to me how to compare Linears (maybe I can compare their weight shapes?). Will give it some thought tomorrow.

Yeah I think this is hard, since the output of a layer depends on the input from all previous layers. Linear is a bit of a special case since it encodes the size in the type, but for Conv2D, for example, you need the input shape to properly type check.

I was browsing through the changes of v0.11.0 and noticed https://github.com/coreylowman/dfdx/pull/543.

Yeah the fix in 543 used the approach we looked at above, so cargo check & rust-analyzer don't error out like they should.

kstavro commented 1 year ago

Ok, thanks for clarifying. The commit read Moving Reshape to use stable compile time asserts, so I wasn't sure if it was the exact same thing we looked at here, as this approach isn't 100% compile time.

I take it this issue is still relevant, right? I have noticed a lot of changes have recently happened in the direction of compile vs runtime checks, and I'm not 100% sure this issue isn't already superseded/fixed by all the other related changes.

coreylowman commented 1 year ago

Ah yeah, so the type checking still happens at compile time (cargo build will fail as it should), but it's at a later phase, after cargo check/rust-analyzer. It's a weird distinction.

But yes this issue is definitely still relevant!

kstavro commented 1 year ago

Planning to finish this during the weekend.

kstavro commented 1 year ago

Small update from my side: I merged all the recent PRs into my branch, and now I am getting a lot of errors that look like this just by running the MNIST example (the compiler just checks the whole repo, I guess, so it finds mistakes in the linears of the transformers' multi-head attention as well):

error[E0277]: the trait bound `(S2, shape::Const<M>): AssertLayerMatch<(shape::Const<M>, shape::Const<V>)>` is not satisfied
   --> src\nn\transformer\mha.rs:116:38
    |
116 |         let v = self.w_v.try_forward(v.retaped::<T>())?;
    |                          ----------- ^^^^^^^^^^^^^^^^ the trait `AssertLayerMatch<(shape::Const<M>, shape::Const<V>)>` is not implemented for `(S2, shape::Const<M>)`
    |                          |
    |                          required by a bound introduced by this call
    |
note: required for `linear::Linear<M, V, E, D>` to implement `module::Module<tensor_impls::Tensor<(S2, shape::Const<M>), E, D, T>>`

This didn't happen before merging the new PRs, but...

On the one hand, this is quite good, because the way the trait is already implemented makes the compiler directly tell me all the relevant places where the trait needs to be implemented. I had planned to look into the linears of the transformers manually after taking care of the mismatched dimensions for the vanilla MLPs, but now the compiler tells me directly where all the dimension checks need to take place. Which makes me quite pleased, in that this PR seems to have been in the right direction, because it seems to have profited from the recent changes? And it will probably take care of many dim checks throughout the whole repo, because the compiler won't let me miss them.

On the other hand: is there now an overkill in the checks? I recall you also implemented some dim check inside the transformers in a recent PR, but not sure exactly what? Is the above dim check already covered by your check, hence unnecessary for the attention linears of the transformers?

After seeing in the errors all the relevant places where dim checks of linears now happen, I think I underestimated the amount of implementations and tests I need to write. So this will probably leak into tomorrow to be ready (I live in Europe, so most of the day is gone).

Finally, I think I might rename the trait from AssertLayerMatch into AssertDimsMatch. If you already have any name preferences or suggestions, let me know.

coreylowman commented 1 year ago

now I am getting a lot of errors that look like this just by running the MNIST example (the compiler just checks the whole repo, I guess, so it finds mistakes in the linears of the transformers' multi-head attention as well):

Is this just because you haven't implemented the trait for all the different input shapes for Linear?

is there now an overkill in the checks? I recall you also implemented some dim check inside the transformers in a recent PR, but not sure exactly what? Is the above dim check already covered by your check, hence unnecessary for the attention linears of the transformers?

That was for usize dimension checks at runtime if I'm thinking of the same thing. The other compile time check was related to two of the const generics on the transformer struct, so not covered by this check.

After seeing in the errors all the relevant places where dim checks of linears now happen, I think I underestimated the amount of implementations and tests I need to write. So this will probably leak into tomorrow to be ready (I live in Europe, so most of the day is gone).

No rush at all!

Finally, I think I might rename the trait from AssertLayerMatch into AssertDimsMatch. If you already have any name preferences or suggestions, let me know.

Hmm maybe LinearLayerCheck or LinearCheck or LinearTypeCheck?

kstavro commented 1 year ago

Is this just because you haven't implemented the trait for all the different input shapes for Linear?

I think so, yeah. Probably there were only const dims inside the linear layers of mha in transformers when I first created the branch, but now there are dynamic dims for them.

That was for usize dimension checks at runtime if I'm thinking of the same thing. The other compile time check was related to two of the const generics on the transformer struct, so not covered by this check.

Yeah, so possibly now there is a "transformers dim check", but the linears inside attention also got non-constant dims so there was probably a linear dim check already from before.

Hmm maybe LinearLayerCheck or LinearCheck or LinearTypeCheck?

Roger 👍

kstavro commented 1 year ago

I think I now have most implementations, if not all. There is only one point that doesn't work after finishing the implementations (it's basically the QK^T step):

error[E0599]: the method `try_matmul` exists for struct `Tensor<(B, usize, S1, usize), E, D, T>`, but its trait bounds were not satisfied
   --> src\nn\transformer\mha.rs:190:25
    |
190 |         let weights = q.try_matmul(k)?.try_mul(scalar)?;
    |                         ^^^^^^^^^^
    |
   ::: src\tensor\tensor_impls.rs:32:1
    |
32  | pub struct Tensor<S: Shape, E: Unit, D: DeviceStorage, T = NoneTape> {
    | --------------------------------------------------------------------
    | |
    | method `try_matmul` not found for this struct
    | doesn't satisfy `_: TryMatMul<Tensor<(B, usize, _, _), E, D, _>>`
    |
note: trait bound `(B, usize, S1, usize): MulDimCheck<(B, usize, _, _)>` was not satisfied
   --> src\tensor_ops\matmul\mod.rs:410:23
    |
405 |     TryMatMul<Tensor<(B, S, RightK, N), E, D, R>> for Tensor<(B, S, M, LeftK), E, D, T>
    |     ---------------------------------------------     ---------------------------------
...
410 |     (B, S, M, LeftK): MulDimCheck<(B, S, RightK, N)>,
    |                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ unsatisfied trait bound introduced here

Queries and Keys have dynamic dimensions, but not sure how to deal with them yet. Any insight why they have to be dynamic?

Changing the implementations of the kernels (and the respective TryMatMul) to have dynamic dims in the spirit of the following also doesn't seem to work. Rust doesn't complain about the trait and the implementation themselves, but it complains elsewhere, as it can't seem to infer the types of tensors with constant dimensions properly (unless I have done something terribly wrong).

Just so that you don't have to check it in detail: it's the standard implementation, just with dynamic dims.

 pub trait MatMatBatch4Kernel<E: Dtype>: DeviceStorage {
    fn forward<LeftK: Dim, RightK: Dim>(
        &self,
        lhs: &Tensor<(usize, usize, usize, LeftK), E, Self>,
        rhs: &Tensor<(usize, usize, RightK, usize), E, Self>,
    ) -> Result<Tensor<(usize, usize, usize, usize), E, Self>, Self::Err>
    where
        (usize, usize, usize, LeftK): MulDimCheck<(usize, usize, RightK, usize)>;

    fn backward<LeftK: Dim, RightK: Dim>(
        &self,
        lhs: &Tensor<(usize, usize, usize, LeftK), E, Self>,
        grad_lhs: &mut Self::Vec<E>,
        rhs: &Tensor<(usize, usize, RightK, usize), E, Self>,
        grad_rhs: &mut Self::Vec<E>,
        grad_out: &Self::Vec<E>,
    ) -> Result<(), Self::Err>
    where
        (usize, usize, usize, LeftK): MulDimCheck<(usize, usize, RightK, usize)>;
}

impl<LeftK: Dim, RightK: Dim, E: Dtype, D, T, R>
    TryMatMul<Tensor<(usize, usize, RightK, usize), E, D, R>>
    for Tensor<(usize, usize, usize, LeftK), E, D, T>
where
    D: MatMatBatch4Kernel<E>,
    T: Tape<E, D> + Merge<R>,
    R: Tape<E, D>,
    (usize, usize, usize, LeftK): MulDimCheck<(usize, usize, RightK, usize)>,
{
    type Output = Tensor<(usize, usize, usize, usize), E, D, T>;
    /// ```compile_fail
    /// # use dfdx::prelude::*;
    /// # let dev: Cpu = Default::default();
    /// let x: Tensor<Rank4<1, 5, 3, 2>, f32, _> = dev.zeros();
    /// let y: Tensor<Rank4<1, 5, 3, 4>, f32, _> = dev.zeros();
    /// let _: Tensor<Rank4<1, 5, 3, 4>, f32, _> = x.try_matmul(y);
    /// ```
    fn try_matmul(
        self,
        rhs: Tensor<(usize, usize, RightK, usize), E, D, R>,
    ) -> Result<Self::Output, Self::Err> {
        // assert_eq!(self.shape.0, rhs.shape.0);
        // assert_eq!(self.shape.1, rhs.shape.1);
        // assert_eq!(self.shape.3, rhs.shape.2);
        self.shape.assert_dim_eq();
        try_binary_op(self, rhs, D::forward, D::backward)
    }
}

Will pick it up tomorrow from here, if you have any quick hints, they are always welcome.

coreylowman commented 1 year ago

Hmm want to put up a WIP PR of your changes? Will probably be easier to suggest things if I see what the existing changes are

Queries and Keys have dynamic dimensions, but not sure how to deal with them yet. Any insight why they have to be dynamic?

It's so Transformers doesn't have to be on nightly. If we make them Const then we would be using generic const expressions when we reshape from V -> V / H. We know as long as V % H == 0 this holds, and that's what the existing TYPE_CHECK for MultiHeadAttention is for.
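i.e. roughly this kind of thing (a simplified sketch; the real MultiHeadAttention has more generics):

// sketch only, reusing the post-monomorphization const-assert trick from above
struct Mha<const V: usize, const H: usize>;

impl<const V: usize, const H: usize> Mha<V, H> {
    const TYPE_CHECK: () = assert!(V % H == 0, "V must be divisible by H");
}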

kstavro commented 1 year ago

Yep, will open the PR when I get back home. Thanks!

kstavro commented 1 year ago

I think I found a/the solution: probably the (only?) way to do it is by having Static and Dynamic MatMul Kernels as well as static and dynamic TryMatMuls. Same logic, only different associated types. This way we can know which dims are static and which not, so that we can check dimensions either statically or dynamically.

The only problem is that every possible combination of tensor dimensions might require its own trait because of conflicts, as I can't seem to pack the same behavior into the same trait, e.g. for Tensor(B, usize, S1, usize) x Tensor(B, usize, usize, S2) and Tensor(B, usize, S1, usize) x Tensor(B, usize, usize, usize). These are the same thing in terms of what or whether you check dims, but Rust doesn't let me implement both with the same trait due to conflicts. So it might be a bit more than double the code.

Maybe with an enum or a new type somehow signifying that we don't care if the dim is static or dynamic, in the sense of Tensor(B, DontCareifStaticOrDynamic, S1, usize) x Tensor(B, DontCareifStaticOrDynamic, usize, DontCareifStaticOrDynamic) (not sure if that is possible in rust).

I postponed the PR so that I can bring it to a state that makes a bit more sense. It will still be WIP even if I manage to make everything work, hehe. Will pick it up again tomorrow.

kstavro commented 1 year ago

Ok, with Static and Dynamic variations of the ops and the kernels, I at least made it work. The errors look like the ones above, which is nice.

The main problem is that the above approach sort of needs duplicate code for all the matmul kernels and traits:

pub trait StaticMatMatKernel<E: Dtype>: DeviceStorage {
    fn forward<M: Dim, LeftK: Dim, RightK: Dim, N: Dim>(
        &self,
        lhs: &Tensor<(M, LeftK), E, Self>,
        rhs: &Tensor<(RightK, N), E, Self>,
    ) -> Result<Tensor<(M, N), E, Self>, Self::Err>
    where
        (M, LeftK): MulStaticDimCheck<(RightK, N)>;

    fn backward<M: Dim, LeftK: Dim, RightK: Dim, N: Dim>(
        &self,
        lhs: &Tensor<(M, LeftK), E, Self>,
        grad_lhs: &mut Self::Vec<E>,
        rhs: &Tensor<(RightK, N), E, Self>,
        grad_rhs: &mut Self::Vec<E>,
        grad_out: &Self::Vec<E>,
    ) -> Result<(), Self::Err>
    where
        (M, LeftK): MulStaticDimCheck<(RightK, N)>;
}

impl<M: Dim, LeftK: Dim, RightK: Dim, N: Dim, E: Dtype, D: StaticMatMatKernel<E>, T, R>
    TryStaticMatMul<Tensor<(RightK, N), E, D, R>> for Tensor<(M, LeftK), E, D, T>
where
    T: Tape<E, D> + Merge<R>,
    R: Tape<E, D>,
    (M, LeftK): MulStaticDimCheck<(RightK, N)>,
{
    type Output = Tensor<(M, N), E, D, T>;
    /// ```compile_fail
    /// # use dfdx::prelude::*;
    /// # let dev: Cpu = Default::default();
    /// let x: Tensor<Rank2<3, 2>, f32, _> = dev.zeros();
    /// let y: Tensor<Rank2<3, 4>, f32, _> = dev.zeros();
    /// let _: Tensor<Rank2<3, 4>, f32, _> = x.try_matmul(y);
    /// ```
    fn try_matmul(self, rhs: Tensor<(RightK, N), E, D, R>) -> Result<Self::Output, Self::Err> {
        // assert_eq!(self.shape.1.size(), rhs.shape.0.size());
        self.shape.assert_dim_eq();
        // println!(
        //     "Left {:?} Right {:?}",
        //     self.shape.1.size(),
        //     rhs.shape.0.size()
        // );
        try_binary_op(self, rhs, D::forward, D::backward)
    }
}

// whereas for dynamic, practically the same code with different associated types

pub trait DynamicMatMatKernel<E: Dtype>: DeviceStorage { // the main difference is the types being usize
    fn forward<M: Dim, N: Dim>(
        &self,
        lhs: &Tensor<(M, usize), E, Self>,
        rhs: &Tensor<(usize, N), E, Self>,
    ) -> Result<Tensor<(M, N), E, Self>, Self::Err>
    where
        (M, usize): MulDynamicDimCheck<(usize, N)>;

    fn backward<M: Dim, N: Dim>(
        &self,
        lhs: &Tensor<(M, usize), E, Self>,
        grad_lhs: &mut Self::Vec<E>,
        rhs: &Tensor<(usize, N), E, Self>,
        grad_rhs: &mut Self::Vec<E>,
        grad_out: &Self::Vec<E>,
    ) -> Result<(), Self::Err>
    where
        (M, usize): MulDynamicDimCheck<(usize, N)>;    // and the dim check being dynamic
}

impl<M: Dim, N: Dim, E: Dtype, D: DynamicMatMatKernel<E>, T, R>
    TryDynamicMatMul<Tensor<(usize, N), E, D, R>> for Tensor<(M, usize), E, D, T>
where
    T: Tape<E, D> + Merge<R>,
    R: Tape<E, D>,
    (M, usize): MulDynamicDimCheck<(usize, N)>,
{
    type Output = Tensor<(M, N), E, D, T>;
    /// ```compile_fail
    /// # use dfdx::prelude::*;
    /// # let dev: Cpu = Default::default();
    /// let x: Tensor<Rank2<3, 2>, f32, _> = dev.zeros();
    /// let y: Tensor<Rank2<3, 4>, f32, _> = dev.zeros();
    /// let _: Tensor<Rank2<3, 4>, f32, _> = x.try_matmul(y);
    /// ```
    fn try_dynamic_matmul(
        self,
        rhs: Tensor<(usize, N), E, D, R>,
    ) -> Result<Self::Output, Self::Err> {
        // assert_eq!(self.shape.1.size(), rhs.shape.0.size());
        // self.shape.assert_dim_eq();
        // println!(
        //     "Left {:?} Right {:?}",
        //     self.shape.1.size(),
        //     rhs.shape.0.size()
        // );
        try_binary_op(self, rhs, D::forward, D::backward)
    }
}

MulStaticDimCheck is a compile-time check as we had it above, while MulDynamicDimCheck is a runtime check for usize dims.
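For reference, the check traits themselves follow the same pattern as AssertDimEq/AssertLayerMatch above, roughly (sketch, 2d case only):

pub trait MulStaticDimCheck<Rhs: Shape> {
    const TYPE_CHECK: ();
    fn assert_dim_eq(&self);
}

impl<const M: usize, const K1: usize, const K2: usize, const N: usize>
    MulStaticDimCheck<(Const<K2>, Const<N>)> for (Const<M>, Const<K1>)
{
    const TYPE_CHECK: () = assert!(K1 == K2, "matmul dimensions do not match");
    fn assert_dim_eq(&self) {
        let _ = <Self as MulStaticDimCheck<(Const<K2>, Const<N>)>>::TYPE_CHECK;
    }
}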

Actually, I even had to do a second Dynamic family of traits and kernels to make attention work, because of conflicts when trying to use the same trait:

impl<B: Dim, S1: Dim, S2: Dim, E: Dtype, D, T, R>
    TryDynamicMatMul<Tensor<(B, usize, usize, S2), E, D, R>>
    for Tensor<(B, usize, S1, usize), E, D, T>

impl<B: Dim, S1: Dim, S2: Dim, E: Dtype, D, T, R>
    TryDynamicMatMul1<Tensor<(B, usize, S2, usize), E, D, R>> // TryDynamicMatMul1 here, second trait
    for Tensor<(B, usize, S1, S2), E, D, T>

Not sure how much of this is me being bad at Rust or Rust being a bit too limiting, given that I need to define two variations of the dynamic trait to be able to deal with different combinations of types.

How do you feel about the above? Is there any cleverer way to implement this?

My other idea is having Dim as an enum instead of a trait, in the spirit of the following:

#[derive(Debug, Copy, Clone)]
pub enum Dim {
    Static,
    Dynamic,
}

While it might give more flexibility in knowing whether we use static or dynamic dims, it might require big structural changes (and I'm not even 100% sure it will work the way I picture it), so I at least wanted to see if the greedier approach of "duplicating" all the kernels and matmuls into static and dynamic variants works (it seems it does and gets the errors right, but it is not the prettiest thing I have written or seen).

coreylowman commented 1 year ago

Hmm yeah splitting traits is not ideal, but I'm sure we can get around having to do that! Want to open up a PR and I can review there?

kstavro commented 1 year ago

Yep, will clean things up a bit and push a PR tomorrow morning.