nitin-rathi / hybrid-snn-conversion

Training spiking networks with hybrid ann-snn conversion and spike-based backpropagation
https://openreview.net/forum?id=B1xSperKvH

Forward Function #5

Open tehreemnaqvi opened 4 years ago

tehreemnaqvi commented 4 years ago

Hi, I am referring to your code and have some questions regarding the forward function.

I am trying to train a VGG11-style network on the CIFAR-10 dataset.

I apply an average pooling layer after every two convolutional layers.

But in the forward function, where I feed the spikes of the first conv layer into the second conv layer and then apply the average pooling layer to the spikes of the second conv layer, my code throws an error and produces no output.
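To make the intended wiring concrete, here is a standalone shape check of the first block outside the spiking loop (a minimal sketch; the layer definitions mirror the ones in my code below):

import torch
import torch.nn as nn

# conv -> conv -> avg pool: each pool halves the spatial size,
# so four such blocks take a 32x32 CIFAR-10 input down to 2x2.
conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=False)
pool1 = nn.AvgPool2d(kernel_size=2)

x = torch.zeros(1, 3, 32, 32)   # one CIFAR-10-sized input
x = pool1(conv2(conv1(x)))
print(x.shape)                  # torch.Size([1, 128, 16, 16])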

I don't know what's wrong with this forward function and haven't been able to fix it.

Can you please tell me what I should do?

Your help will be appreciated.
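For reference, the mem_update my code calls follows the usual LIF-style update from the spiking-CNN tutorial code I am adapting. A minimal sketch is below; the threshold and decay values and the surrogate-gradient ActFun are assumptions on my side, not code from this repository:

import torch

thresh = 0.5  # firing threshold (assumed value)
decay = 0.2   # membrane decay constant (assumed value)

class ActFun(torch.autograd.Function):
    # Heaviside spike in the forward pass, rectangular surrogate gradient in the backward pass
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input.gt(thresh).float()

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        # Pass gradient only near the threshold (surrogate derivative)
        return grad_output * (abs(input - thresh) < 0.5).float()

act_fun = ActFun.apply

def mem_update(ops, x, mem, spike):
    # Leak the membrane, reset where a spike occurred, then integrate new input
    mem = mem * decay * (1. - spike) + ops(x)
    spike = act_fun(mem)
    return mem, spike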

Below is my code:

import torch
import torch.nn as nn

class SCNN(nn.Module):
    def __init__(self):
        super(SCNN, self).__init__()
        # in_planes, out_planes, stride, padding, kernel_size = cfg_cnn[0]  # cfg_cnn is not defined in this snippet, so kept commented out

        # Feature extractor: four conv-conv-avgpool blocks (VGG-style)
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1, bias=False)
        self.avg_pool1 = nn.AvgPool2d(kernel_size=2)

        self.conv3 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv4 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False)
        self.avg_pool2 = nn.AvgPool2d(kernel_size=2)

        self.conv5 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv6 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1, bias=False)
        self.avg_pool3 = nn.AvgPool2d(kernel_size=2)

        self.conv7 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv8 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1, bias=False)
        self.avg_pool4 = nn.AvgPool2d(kernel_size=2)

        # Classifier: a 32x32 input is halved four times -> 2x2 spatial, so 512 * 2 * 2 = 2048 features
        self.fc0 = nn.Linear(512 * 2 * 2, 4096, bias=False)
        self.fc1 = nn.Linear(4096, 4096, bias=False)
        self.fc2 = nn.Linear(4096, 10)

    def forward(self, input, time_window=20):
        device = input.device
        batch_size = input.size(0)  # batch_size was undefined before, which raises a NameError

        # Membrane potential and spike state for every layer, all initialized to zero.
        # torch.zeros already defaults to requires_grad=False, so the deprecated Variable wrapper is unnecessary.
        c1_mem = c1_spike = torch.zeros(batch_size, 64, 32, 32, device=device)
        c2_mem = c2_spike = torch.zeros(batch_size, 128, 32, 32, device=device)

        c3_mem = c3_spike = torch.zeros(batch_size, 256, 16, 16, device=device)
        c4_mem = c4_spike = torch.zeros(batch_size, 256, 16, 16, device=device)

        c5_mem = c5_spike = torch.zeros(batch_size, 512, 8, 8, device=device)
        c6_mem = c6_spike = torch.zeros(batch_size, 512, 8, 8, device=device)

        c7_mem = c7_spike = torch.zeros(batch_size, 512, 4, 4, device=device)
        c8_mem = c8_spike = torch.zeros(batch_size, 512, 4, 4, device=device)

        h1_mem = h1_spike = h1_sumspike = torch.zeros(batch_size, 4096, device=device)
        h2_mem = h2_spike = h2_sumspike = torch.zeros(batch_size, 4096, device=device)
        h3_mem = h3_spike = h3_sumspike = torch.zeros(batch_size, 10, device=device)

        for step in range(time_window):  # simulation time steps
            # Rate-coded input: pixel intensity = probability of a spike at this step
            x = (input > torch.rand(input.size(), device=device)).float()

            c1_mem, c1_spike = mem_update(self.conv1, x, c1_mem, c1_spike)
            c2_mem, c2_spike = mem_update(self.conv2, c1_spike, c2_mem, c2_spike)
            x = self.avg_pool1(c2_spike)

            c3_mem, c3_spike = mem_update(self.conv3, x, c3_mem, c3_spike)
            c4_mem, c4_spike = mem_update(self.conv4, c3_spike, c4_mem, c4_spike)
            x = self.avg_pool2(c4_spike)

            c5_mem, c5_spike = mem_update(self.conv5, x, c5_mem, c5_spike)
            c6_mem, c6_spike = mem_update(self.conv6, c5_spike, c6_mem, c6_spike)
            x = self.avg_pool3(c6_spike)

            c7_mem, c7_spike = mem_update(self.conv7, x, c7_mem, c7_spike)
            c8_mem, c8_spike = mem_update(self.conv8, c7_spike, c8_mem, c8_spike)
            x = self.avg_pool4(c8_spike)

            x = x.view(x.size(0), -1)  # flatten to (batch_size, 512 * 2 * 2)

            h1_mem, h1_spike = mem_update(self.fc0, x, h1_mem, h1_spike)
            h1_sumspike += h1_spike
            h2_mem, h2_spike = mem_update(self.fc1, h1_spike, h2_mem, h2_spike)
            h2_sumspike += h2_spike
            h3_mem, h3_spike = mem_update(self.fc2, h2_spike, h3_mem, h3_spike)
            h3_sumspike += h3_spike

        outputs = h3_sumspike / time_window  # average firing rate per output neuron
        return outputs
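And this is how I drive the model (a minimal sketch; the random batch standing in for CIFAR-10 and the batch size are just placeholders):

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SCNN().to(device)

# Random stand-in for a normalized CIFAR-10 batch: 8 images of 3x32x32, values in [0, 1]
images = torch.rand(8, 3, 32, 32, device=device)

outputs = model(images, time_window=20)  # shape (8, 10): average firing rate per class
predictions = outputs.argmax(dim=1)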
nitin-rathi commented 4 years ago

The code you have posted is not part of this repository.