fangwei123456 / spikingjelly

SpikingJelly is an open-source deep learning framework for Spiking Neural Network (SNN) based on PyTorch.
https://spikingjelly.readthedocs.io

Example code for semantic segmentation? #177

Closed ChidanandKumarKS closed 2 years ago

ChidanandKumarKS commented 2 years ago

@fangwei123456 Have any experiments relevant to semantic segmentation been carried out?

Regards K S Chidanand Kumar

fangwei123456 commented 2 years ago

Hi, I am not familiar with semantic segmentation, so there is no code in SpikingJelly for it. But I believe the following papers could be reproduced with SpikingJelly:

- Beyond Classification: Directly Training Spiking Neural Networks for Semantic Segmentation
- A Spiking Neural Network for Image Segmentation
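For reference, the building block those directly-trained SNN papers rely on is the Integrate-and-Fire (IF) neuron, which SpikingJelly exposes as `neuron.IFNode`. A minimal pure-Python sketch of its per-timestep dynamics (a standalone illustration, not the SpikingJelly implementation):

```python
def if_neuron_step(v, x, v_threshold=1.0, v_reset=0.0):
    """One timestep of an Integrate-and-Fire neuron.

    v: membrane potential before this step
    x: input current at this step
    Returns (spike, new_v); spike is 1.0 if the neuron fired, else 0.0.
    """
    v = v + x                # integrate the input into the potential
    if v >= v_threshold:     # fire when the threshold is crossed
        return 1.0, v_reset  # hard reset after the spike
    return 0.0, v

# Constant input of 0.4 for 5 timesteps: the potential reaches 1.2 on
# step 3, so the neuron fires once and resets.
v = 0.0
spikes = []
for _ in range(5):
    s, v = if_neuron_step(v, 0.4)
    spikes.append(s)
print(spikes)  # → [0.0, 0.0, 1.0, 0.0, 0.0]
```

During training, the non-differentiable threshold is replaced by a surrogate gradient (e.g. `surrogate.ATan()` in SpikingJelly), which is what makes direct backpropagation through such neurons possible.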

fangwei123456 commented 2 years ago

StereoSpike: Depth Learning with a Spiking Neural Network is a similar work. It is implemented with SpikingJelly, and its code is available at https://github.com/urancon/StereoSpike .

ChidanandKumarKS commented 2 years ago

@fangwei123456 Thanks for sharing the relevant info. I actually referred to the PriyaPanda paper; it seems they are not using your repo, and their code is badly written. I also tried to use StereoSpike, which uses your repo, to do image reconstruction from an event camera rather than optical flow, but I am unable to reconstruct the image using event camera input.

Requesting you to help me in this regard.

Below are the encoder and decoder written with the SpikingJelly repo (assumed imports: torch.nn as nn, neuron and surrogate from spikingjelly.clock_driven, plus NeuromorphicNet, MultiplyBy, SEWResBlock and NNConvUpsampling2 from the StereoSpike repo):

class StereoSpike(NeuromorphicNet):
    """
    Baseline model, with which we report state-of-the-art performances in the second version of our paper.

    - all neuron potentials must be reset at each timestep
    - predict_depth layers do have biases, but it is equivalent to remove them and reset output I-neurons to the sum
      of all 4 biases, instead of 0.
    """
    def __init__(self, surrogate_function=surrogate.ATan(), detach_reset=True, v_threshold=1.0, v_reset=0.0, multiply_factor=1.):
        super().__init__(surrogate_function=surrogate_function, detach_reset=detach_reset)

        # bottom layer, preprocessing the input spike frame without downsampling
        self.bottom = nn.Sequential(
            nn.Conv2d(in_channels=5, out_channels=32, kernel_size=5, stride=1, padding=2, bias=False),
            MultiplyBy(multiply_factor),
            neuron.IFNode(v_threshold=self.v_th, v_reset=self.v_rst, surrogate_function=self.surrogate_fct, detach_reset=True),
        )

        # encoder layers (downsampling)
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, stride=2, padding=2, bias=False),
            MultiplyBy(multiply_factor),
            neuron.IFNode(v_threshold=self.v_th, v_reset=self.v_rst, surrogate_function=self.surrogate_fct, detach_reset=True),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=5, stride=2, padding=2, bias=False),
            MultiplyBy(multiply_factor),
            neuron.IFNode(v_threshold=self.v_th, v_reset=self.v_rst, surrogate_function=self.surrogate_fct, detach_reset=True),
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=5, stride=2, padding=2, bias=False),
            MultiplyBy(multiply_factor),
            neuron.IFNode(v_threshold=self.v_th, v_reset=self.v_rst, surrogate_function=self.surrogate_fct, detach_reset=True),
        )
        self.conv4 = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=5, stride=2, padding=2, bias=False),
            MultiplyBy(multiply_factor),
            neuron.IFNode(v_threshold=self.v_th, v_reset=self.v_rst, surrogate_function=self.surrogate_fct, detach_reset=True),
        )

        # residual layers
        self.bottleneck = nn.Sequential(
            SEWResBlock(512, v_threshold=self.v_th, v_reset=self.v_rst, connect_function='ADD', multiply_factor=multiply_factor),
            SEWResBlock(512, v_threshold=self.v_th, v_reset=self.v_rst, connect_function='ADD', multiply_factor=multiply_factor),
        )

        # decoder layers (upsampling)
        self.deconv4 = nn.Sequential(
            NNConvUpsampling2(in_channels=512, out_channels=256, kernel_size=3, scale_factor=2),
            MultiplyBy(multiply_factor),
            neuron.IFNode(v_threshold=self.v_th, v_reset=self.v_rst, surrogate_function=self.surrogate_fct, detach_reset=True),
        )
        self.deconv3 = nn.Sequential(
            NNConvUpsampling2(in_channels=256, out_channels=128, kernel_size=3, scale_factor=2),
            MultiplyBy(multiply_factor),
            neuron.IFNode(v_threshold=self.v_th, v_reset=self.v_rst, surrogate_function=self.surrogate_fct, detach_reset=True),
        )
        self.deconv2 = nn.Sequential(
            NNConvUpsampling2(in_channels=128, out_channels=64, kernel_size=3, scale_factor=2),
            MultiplyBy(multiply_factor),
            neuron.IFNode(v_threshold=self.v_th, v_reset=self.v_rst, surrogate_function=self.surrogate_fct, detach_reset=True),
        )
        self.deconv1 = nn.Sequential(
            NNConvUpsampling2(in_channels=64, out_channels=32, kernel_size=3, scale_factor=2),
            MultiplyBy(multiply_factor),
            neuron.IFNode(v_threshold=self.v_th, v_reset=self.v_rst, surrogate_function=self.surrogate_fct, detach_reset=True),
        )

        # these layers output depth maps at different scales, where depth is represented by the potential of IF neurons
        # that do not fire ("I-neurons"), i.e., with an infinite threshold.
        self.predict_depth1 = nn.Sequential(
            NNConvUpsampling2(in_channels=32, out_channels=1, kernel_size=3, scale_factor=1, bias=True),
            MultiplyBy(multiply_factor),
        )

        self.Ineurons = neuron.IFNode(v_threshold=float('inf'), v_reset=0.0, surrogate_function=self.surrogate_fct)
        self.sigmoid = nn.Sigmoid()
        self.num_encoders = 4

    def forward(self, x, pred):  # note: `pred` is currently unused in this forward pass

        # x must be of shape [batch_size, num_frames_per_depth_map, 4 (2 cameras - 2 polarities), W, H]
        frame = x

        # data is fed in through the bottom layer
        out_bottom = self.bottom(frame)

        # pass through encoder layers
        out_conv1 = self.conv1(out_bottom)
        out_conv2 = self.conv2(out_conv1)
        out_conv3 = self.conv3(out_conv2)
        out_conv4 = self.conv4(out_conv3)

        # pass through residual blocks
        out_rconv = self.bottleneck(out_conv4)

        # gradually upsample while concatenating and passing through skip connections
        out_deconv4 = self.deconv4(out_rconv)
        out_add4 = out_deconv4 + out_conv3
        # self.Ineurons(self.predict_depth4(out_add4))

        out_deconv3 = self.deconv3(out_add4)
        out_add3 = out_deconv3 + out_conv2
        # self.Ineurons(self.predict_depth3(out_add3))

        out_deconv2 = self.deconv2(out_add3)
        out_add2 = out_deconv2 + out_conv1
        # self.Ineurons(self.predict_depth2(out_add2))

        out_deconv1 = self.deconv1(out_add2)
        out_add1 = out_deconv1 + out_bottom
        self.Ineurons(self.predict_depth1(out_add1))
        img = self.sigmoid(self.Ineurons.v)

        return {'image': img}

    def set_init_depths_potentials(self, depth_prior):
        self.Ineurons.v = depth_prior
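The `Ineurons` layer above (an `IFNode` with `v_threshold=float('inf')`) never fires; its membrane potential simply accumulates the decoder's output across timesteps, and the sigmoid maps that potential to pixel intensities. A minimal pure-Python sketch of this readout, independent of SpikingJelly (the `INeuron` class here is a hypothetical stand-in):

```python
import math

class INeuron:
    """Non-firing integrator: an IF neuron with infinite threshold.

    Its potential `v` accumulates every input and is read out directly,
    mirroring StereoSpike's `Ineurons` layer.
    """
    def __init__(self):
        self.v = 0.0

    def __call__(self, x):
        self.v += x  # integrate only; infinite threshold means no spikes

def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))

# Accumulate decoder outputs over 3 timesteps, then read the "pixel" value.
n = INeuron()
for x in (0.5, -0.2, 0.3):
    n(x)
pixel = sigmoid(n.v)  # potential is 0.6; sigmoid maps it into (0, 1)
```

This also explains `set_init_depths_potentials`: assigning to `v` seeds the integrator with a prior before the sequence starts, rather than starting from 0.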
fangwei123456 commented 2 years ago

Do you use the MVSEC dataset, or a new dataset?

ChidanandKumarKS commented 2 years ago

I am using the MVSEC dataset and also the IJRR dataset. Kindly suggest how to proceed.

fangwei123456 commented 2 years ago

So you used the code from https://github.com/urancon/StereoSpike and could not produce the correct results?

ChidanandKumarKS commented 2 years ago

I tried to reuse the StereoSpike code, but the results are far from the actual grayscale images, not even close. Kindly suggest.

fangwei123456 commented 2 years ago

@urancon

ChidanandKumarKS commented 2 years ago

Results (attached images)

ChidanandKumarKS commented 2 years ago

https://github.com/urancon Left is the target image and right is the predicted image. I am using LPIPS and temporal consistency loss. Kindly suggest and help.
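For context on the loss setup: a minimal sketch of how an LPIPS-style perceptual term and a temporal consistency term are often combined. The weights and the per-frame `perceptual_terms` input below are hypothetical placeholders; real LPIPS values would come from a pretrained perceptual network such as the `lpips` package:

```python
def temporal_consistency(preds):
    """Mean absolute difference between consecutive predictions.

    preds: list of per-timestep predictions, each a flat list of pixel
    values. Penalizes flicker between consecutive reconstructed frames.
    """
    total, count = 0.0, 0
    for a, b in zip(preds, preds[1:]):
        total += sum(abs(x - y) for x, y in zip(a, b))
        count += len(a)
    return total / count

def combined_loss(perceptual_terms, preds, w_lpips=1.0, w_tc=0.5):
    """Weighted sum of average perceptual loss and the temporal term."""
    lpips_term = sum(perceptual_terms) / len(perceptual_terms)
    return w_lpips * lpips_term + w_tc * temporal_consistency(preds)

# Toy example: two 4-pixel frames, with pre-computed per-frame
# perceptual losses of 0.30 and 0.20.
preds = [[0.1, 0.2, 0.3, 0.4], [0.1, 0.4, 0.3, 0.2]]
loss = combined_loss([0.30, 0.20], preds)  # 0.25 perceptual + 0.5 * 0.1 temporal
```

If reconstructions are far off, it is usually worth checking each term in isolation first (e.g. set `w_tc=0` and verify the perceptual term alone drives the output toward the target).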

fangwei123456 commented 2 years ago

https://github.com/urancon/StereoSpike/issues/3

ChidanandKumarKS commented 2 years ago

The issue was a bug in my code.