jafioti / luminal

Deep learning at the speed of light.
https://luminalai.com
Apache License 2.0

[Question] implementing Conv3d #50

Closed NewBornRustacean closed 4 months ago

NewBornRustacean commented 5 months ago

Hello! Thanks for this awesome project! @jafioti

Recently I've been working on Conv3d (#28, #29), and I have a question:

Is it OK to define Axes7 and R7? If we add them, I suspect it would mean shotgun surgery across the codebase.

My draft is here:

impl<
    const CHANNELS_IN: usize,
    const CHANNELS_OUT: usize,
    const KERNELX: usize,
    const KERNELY: usize,
    const KERNELZ: usize,
    const STRIDEX: usize,
    const STRIDEY: usize,
    const STRIDEZ: usize,
    const DILATIONX: usize,
    const DILATIONY: usize,
    const DILATIONZ: usize,
    const CHANNELS_IN_TIMES_KERNELX_KERNELY_KERNELZ: usize,
> Conv3D<
    CHANNELS_IN,
    CHANNELS_OUT,
    KERNELX,
    KERNELY,
    KERNELZ,
    STRIDEX,
    STRIDEY,
    STRIDEZ,
    DILATIONX,
    DILATIONY,
    DILATIONZ,
    CHANNELS_IN_TIMES_KERNELX_KERNELY_KERNELZ,
>
{
    pub fn forward<
        const DIMX_IN: usize,
        const DIMY_IN: usize,
        const DIMZ_IN: usize,
        const DIMX_OUT: usize,
        const DIMY_OUT: usize,
        const DIMZ_OUT: usize,
        const DIMX_TIMES_DIMY_DIMZ_OUT: usize,
    >(
        &self,
        input: GraphTensor<R4<CHANNELS_IN, DIMX_IN, DIMY_IN, DIMZ_IN>>,
    ) -> GraphTensor<R4<CHANNELS_OUT, DIMX_OUT, DIMY_OUT, DIMZ_OUT>> {
        // Unfold the input via repeated pool_last_dim + permute (im2col):
        // each pool extracts sliding windows along the current last axis,
        // and the permute rotates the next spatial axis into last position.
        let input_pooled = input
            .pool_last_dim::<R5<CHANNELS_IN, DIMX_IN, DIMY_OUT, DIMZ_OUT, KERNELY>>(
                KERNELY.into(),
                STRIDEY.into(),
                DILATIONY
            )
            .permute::<_, Axes5<0, 2, 3, 4, 1>>()
            .pool_last_dim::<R6<CHANNELS_IN, DIMY_OUT, DIMZ_OUT, KERNELY, DIMX_OUT, KERNELX>>(
                KERNELX.into(),
                STRIDEX.into(),
                DILATIONX
            )
            .permute::<_, Axes6<0, 5, 2, 3, 4, 1>>()
            .pool_last_dim::<R7<CHANNELS_IN, DIMZ_OUT, KERNELZ, DIMX_OUT, KERNELX, DIMY_OUT, KERNELY>>(
                KERNELZ.into(),
                STRIDEZ.into(),
                DILATIONZ
            )
            .permute::<_, Axes7<0, 6, 2, 3, 4, 5, 1>>()
            // Collapse to 2D: (CHANNELS_IN * KERNELX * KERNELY * KERNELZ) rows by
            // (DIMX_OUT * DIMY_OUT * DIMZ_OUT) columns, ready for a single matmul.
            .reshape::<R2<CHANNELS_IN_TIMES_KERNELX_KERNELY_KERNELZ, DIMX_TIMES_DIMY_DIMZ_OUT>>();

        // weight has shape (CHANNELS_OUT, CHANNELS_IN * KERNELX * KERNELY * KERNELZ),
        // so the matmul yields (CHANNELS_OUT, DIMX_OUT * DIMY_OUT * DIMZ_OUT).
        self.weight
            .matmul(input_pooled)
            .reshape::<R4<CHANNELS_OUT, DIMX_OUT, DIMY_OUT, DIMZ_OUT>>()
    }
}
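For anyone following along, the reshape-then-matmul trick above is the classic im2col formulation of convolution. Here is a dependency-free sketch of the same idea in plain Rust (stride only, dilation omitted for brevity; all names here are hypothetical and not part of luminal's API):

```rust
// Im2col 3D convolution sketch: unfold the input into a
// (c_in * kx * ky * kz) x (ox * oy * oz) matrix, then one matmul with the
// (c_out) x (c_in * kx * ky * kz) weight produces the conv output.
fn conv3d_im2col(
    input: &[f32],  // shape [c_in, dx, dy, dz], row-major
    weight: &[f32], // shape [c_out, c_in * kx * ky * kz], row-major
    c_in: usize,
    c_out: usize,
    (dx, dy, dz): (usize, usize, usize),
    (kx, ky, kz): (usize, usize, usize),
    (sx, sy, sz): (usize, usize, usize),
) -> Vec<f32> {
    // Output spatial dims (no padding, no dilation).
    let (ox, oy, oz) = ((dx - kx) / sx + 1, (dy - ky) / sy + 1, (dz - kz) / sz + 1);
    let cols = ox * oy * oz;
    let rows = c_in * kx * ky * kz;

    // im2col: every output position becomes one column of input patches.
    let mut unfolded = vec![0.0f32; rows * cols];
    for c in 0..c_in {
        for fx in 0..kx {
            for fy in 0..ky {
                for fz in 0..kz {
                    let row = ((c * kx + fx) * ky + fy) * kz + fz;
                    for px in 0..ox {
                        for py in 0..oy {
                            for pz in 0..oz {
                                let col = (px * oy + py) * oz + pz;
                                let idx = ((c * dx + px * sx + fx) * dy + py * sy + fy) * dz
                                    + pz * sz
                                    + fz;
                                unfolded[row * cols + col] = input[idx];
                            }
                        }
                    }
                }
            }
        }
    }

    // weight (c_out x rows) * unfolded (rows x cols) -> output (c_out x cols).
    let mut out = vec![0.0f32; c_out * cols];
    for o in 0..c_out {
        for col in 0..cols {
            let mut acc = 0.0;
            for r in 0..rows {
                acc += weight[o * rows + r] * unfolded[r * cols + col];
            }
            out[o * cols + col] = acc;
        }
    }
    out
}
```

The pool_last_dim/permute chain in the draft builds the same unfolded matrix lazily on the graph instead of materializing it.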

I'm not sure I'm going in the right direction, so any comments would be appreciated!

Have a nice day :)

jafioti commented 5 months ago

Great work! It looks correct to me on first pass. Is there any way to implement it without requiring a 7D tensor? Like maybe merging some of the dimensions before doing the last pool_last_dim? Ideally we keep the max tensor dims at 6 (in order to keep the shapetracker small because it's stored on the stack).

If not, I'll look into adding a 7th dimension
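On the merging idea: for a row-major buffer, fusing two adjacent axes of sizes B and C into one axis of size B * C is a pure view change (the element at (b, c) lands at b * C + c in the fused axis), which is why a would-be rank-7 intermediate can often be expressed at rank 6 by fusing neighboring output axes before the final pool. A tiny sanity check of that indexing identity, in plain Rust with hypothetical helper names:

```rust
// Flat offset of (a, b, c) in a row-major tensor of shape (A, B, C).
fn flat_index_3d(shape: (usize, usize, usize), idx: (usize, usize, usize)) -> usize {
    let (_, b, c) = shape;
    (idx.0 * b + idx.1) * c + idx.2
}

// Same tensor viewed with the last two axes fused into one axis of size B * C.
// The flat offset is identical, so the fused view needs no data movement.
fn flat_index_merged(shape: (usize, usize, usize), idx: (usize, usize, usize)) -> usize {
    let (_, b, c) = shape;
    idx.0 * (b * c) + (idx.1 * c + idx.2)
}
```

So pooling over a fused (DIMY_OUT * DIMZ_OUT) axis reads exactly the same elements as pooling over the two separate axes, as long as they stay adjacent and in order.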

NewBornRustacean commented 5 months ago

Thanks! I'll look into it! (I'm quite a newbie, just like my nickname suggests; this is really helpful!)