tracel-ai / burn

Burn is a new comprehensive dynamic Deep Learning Framework built using Rust with extreme flexibility, compute efficiency and portability as its primary goals.
https://burn.dev
Apache License 2.0

Add a PixelShuffle implementation like in pytorch #2464

Open · Rick-29 opened 1 week ago

Rick-29 commented 1 week ago

In PyTorch there is a module called PixelShuffle, which rearranges a tensor of shape (N, C * r^2, H, W) into (N, C, H * r, W * r) for an upscale factor r. I created a small implementation of it that currently only supports 4D tensors (I tried to follow the implementation format of the library) and wanted to share it to check whether it could someday be added to the crate. This is the full code:

use burn::{config::Config, module::Module, prelude::Backend, tensor::Tensor};

#[derive(Config, Debug)]
pub struct PixelShuffleConfig {
    #[config(default = "2")]
    upscale_factor: usize
}

#[derive(Module, Debug, Clone)]
pub struct PixelShuffle {
    upscale_factor: usize
}

impl PixelShuffleConfig {
    pub fn init(&self) -> PixelShuffle {
        PixelShuffle { upscale_factor: self.upscale_factor }
    }
}

impl PixelShuffle {
    /// Rearranges a tensor of shape `[n, c * r^2, h, w]` into `[n, c, h * r, w * r]`,
    /// where `r` is the upscale factor, matching PyTorch's PixelShuffle.
    pub fn forward<B: Backend>(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
        let [n, c, h, w] = input.dims();
        let factor_squared = self.upscale_factor * self.upscale_factor;
        if c % factor_squared != 0 {
            panic!("pixel_shuffle expects its input's 'channel' dimension to be divisible by the square of upscale_factor, but input.size(-3)={c} is not divisible by {factor_squared}")
        }
        let oc = c / factor_squared;
        let oh = h * self.upscale_factor;
        let ow = w * self.upscale_factor;

        // Split the channel dimension into (oc, r, r), then interleave the two
        // r-axes with the spatial axes: [n, oc, r, r, h, w] -> [n, oc, h, r, w, r].
        let x = input.reshape([n, oc, self.upscale_factor, self.upscale_factor, h, w]);
        let x = x.permute([0, 1, 4, 2, 5, 3]);

        // Collapsing (h, r) and (w, r) yields the upscaled spatial dimensions.
        x.reshape([n, oc, oh, ow])
    }
}
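
For reference, a minimal usage sketch of how the module would be constructed through its config (this assumes the new() / with_upscale_factor() builders that Burn's #[derive(Config)] generates for fields with defaults):

fn usage_example() {
    use burn::backend::Wgpu;

    // Build the module from its config, overriding the default factor of 2.
    let shuffle = PixelShuffleConfig::new().with_upscale_factor(3).init();

    // [1, 9, 2, 2] -> [1, 1, 6, 6]: each group of 9 channels folds into a 3x3 spatial block.
    let input = Tensor::<Wgpu, 4>::zeros([1, 9, 2, 2], &Default::default());
    let output = shuffle.forward(input);
    assert_eq!([1, 1, 6, 6], output.dims());
}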

And here are some tests I wrote comparing the outputs with those of the PyTorch implementation:

#[cfg(test)]
mod tests {
    use burn::backend::Wgpu;

    use super::*;

    #[test]
    fn test_pixel_shuffle() {
        let shuffle = PixelShuffle { upscale_factor: 3 };
        let tensor1 = Tensor::<Wgpu, 4>::random([1, 9, 4, 4], burn::tensor::Distribution::Default, &Default::default());
        let tensor1_shuffle = shuffle.forward(tensor1);
        dbg!(tensor1_shuffle.dims());
        assert_eq!([1, 1, 12, 12], tensor1_shuffle.dims());

        let tensor2 = Tensor::<Wgpu, 4>::random([1, 18, 7, 5], burn::tensor::Distribution::Default, &Default::default());
        let tensor2_shuffle = shuffle.forward(tensor2);
        dbg!(tensor2_shuffle.dims());
        assert_eq!([1, 2, 21, 15], tensor2_shuffle.dims());

        let tensor3 = Tensor::<Wgpu, 4>::random([128, 45, 33, 7], burn::tensor::Distribution::Default, &Default::default());
        let tensor3_shuffle = shuffle.forward(tensor3);
        dbg!(tensor3_shuffle.dims());
        assert_eq!([128, 5, 99, 21], tensor3_shuffle.dims());
    }

    #[test]
    #[should_panic]
    fn test_pixel_shuffle_panic() {
        let shuffle = PixelShuffle { upscale_factor: 3 };
        let tensor1 = Tensor::<Wgpu, 4>::random([128, 46, 33, 7], burn::tensor::Distribution::Default, &Default::default());
        let tensor1_shuffle = shuffle.forward(tensor1);
        dbg!(tensor1_shuffle.dims());
    }
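
    // Sketch of a value-level check: the shape tests above verify dimensions
    // only, so this additionally compares element placement against what
    // torch.nn.functional.pixel_shuffle produces for a tiny hand-computed
    // input. It assumes burn's TensorData::to_vec for reading values back.
    #[test]
    fn test_pixel_shuffle_values() {
        // With upscale_factor = 2, an input of shape [1, 4, 1, 1] holding
        // [0, 1, 2, 3] along the channel axis becomes a [1, 1, 2, 2] output
        // laid out as [[0, 1], [2, 3]].
        let shuffle = PixelShuffle { upscale_factor: 2 };
        let input = Tensor::<Wgpu, 1>::from_floats([0.0, 1.0, 2.0, 3.0], &Default::default())
            .reshape([1, 4, 1, 1]);
        let output = shuffle.forward(input);
        assert_eq!([1, 1, 2, 2], output.dims());

        // Flattened row-major, the output reads back in channel order.
        let values = output.into_data().to_vec::<f32>().unwrap();
        assert_eq!(vec![0.0, 1.0, 2.0, 3.0], values);
    }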
}

The original C++ implementation is here

laggui commented 1 week ago

We're always open to PRs! Feel free to open one to add this to the modules in burn-core 🙂

Rick-29 commented 1 week ago

How do I do that? Sorry, I'm really new to creating pull requests.

laggui commented 1 week ago

No worries! This can be done from a fork of the repository.

Check out the details here: https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request-from-a-fork