```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from spikingjelly.clock_driven.model import spiking_resnet
from spikingjelly.clock_driven import neuron, functional
from spikingjelly.clock_driven.cu_kernel_opt import cal_fun_t
from spikingjelly import configure


def forward_backward(x, net: nn.Module, optimizer: torch.optim.Optimizer):
    # one full training step: forward, backward, parameter update
    optimizer.zero_grad()
    y = net(x)
    y.sum().backward()
    optimizer.step()
    # reset the membrane potentials of all stateful neurons before the next run
    functional.reset_net(net)


device = 'cuda:0'
net = spiking_resnet.multi_step_spiking_resnet50(T=4, multi_step_neuron=neuron.MultiStepIFNode, backend='cupy')
net.to(device)
optimizer = torch.optim.SGD(net.parameters(), lr=1e-6)
repeats = 64
x = torch.rand([16, 3, 224, 224], device=device)

# warm up once, then time the training step with the per-layer
# mini-batch optimization enabled and disabled
configure.mini_batch_benchmark = True
forward_backward(x, net, optimizer)
print(cal_fun_t(repeats, device, forward_backward, x, net, optimizer))

configure.mini_batch_benchmark = False
print(cal_fun_t(repeats, device, forward_backward, x, net, optimizer))
```
The output (first with `configure.mini_batch_benchmark = True`, then with it disabled):

```
365.29833030700684
331.5680847167969
```
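`cal_fun_t` averages the runtime of a function over `repeats` calls on the given device. As a rough reference, a minimal CUDA-event-based helper with the same call shape could look like the hypothetical `time_fun` below (a sketch, not the actual spikingjelly implementation):

```python
import torch

def time_fun(repeats: int, device: str, f, *args):
    # hypothetical stand-in for cal_fun_t: average the wall-clock time of
    # f(*args) over `repeats` runs, using CUDA events so GPU work is counted
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize(device)
    start.record()
    for _ in range(repeats):
        f(*args)
    end.record()
    torch.cuda.synchronize(device)
    return start.elapsed_time(end) / repeats  # milliseconds per call
```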
I find that using an individually optimized `mini_batch_size` for each conv/BN layer actually reduces the speed of the whole network. So, I have to revert the commit https://github.com/fangwei123456/spikingjelly/commit/94e8fc9b319f681fca989bda2f0d558f4962b2a3 .
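The idea behind that optimization is roughly to run a layer on smaller chunks of the flattened batch when chunking was measured to be faster for that layer. A hypothetical sketch of the splitting itself (`forward_in_mini_batches` is an illustration, not the reverted code):

```python
import torch
import torch.nn as nn

def forward_in_mini_batches(layer: nn.Module, x: torch.Tensor, mini_batch_size: int):
    # apply `layer` to chunks of at most `mini_batch_size` samples and
    # concatenate the outputs along the batch dimension
    return torch.cat([layer(chunk) for chunk in x.split(mini_batch_size, dim=0)], dim=0)
```

Autograd handles `split` and `cat`, so backward still works; but as the timings above show, the per-layer chunking overhead outweighs any per-layer gain for the whole network.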
We use `T * N` as the batch size, but an overly large batch size may not be the fastest.
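To illustrate: in multi-step mode, an input sequence of shape `[T, N, C, H, W]` is flattened to a `[T * N, C, H, W]` batch before being fed to a stateless conv layer. A small self-contained check (the layer and shapes below are arbitrary assumptions, not taken from the repository) can compare the full `T * N` batch against chunked execution:

```python
import torch
import torch.nn as nn
from torch.utils import benchmark

T, N = 4, 16
conv = nn.Conv2d(64, 64, kernel_size=3, padding=1).cuda()
x_seq = torch.rand([T, N, 64, 56, 56], device='cuda')
x = x_seq.flatten(0, 1)  # merge time and batch dims: [T * N, 64, 56, 56]

# time the full T * N batch against processing it in chunks of N
full = benchmark.Timer(stmt='conv(x)', globals={'conv': conv, 'x': x})
chunked = benchmark.Timer(
    stmt='torch.cat([conv(c) for c in x.split(N, dim=0)])',
    globals={'torch': torch, 'conv': conv, 'x': x, 'N': N})
print('full batch :', full.timeit(64).mean)     # seconds per call
print('chunks of N:', chunked.timeit(64).mean)  # seconds per call
```

Which variant wins depends on the layer, the shapes, and the cuDNN algorithms available, which is why a fixed `T * N` batch is not guaranteed to be optimal.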