fangwei123456 / spikingjelly

SpikingJelly is an open-source deep learning framework for Spiking Neural Network (SNN) based on PyTorch.
https://spikingjelly.readthedocs.io

Find the optimal mini batch size for seq_to_ann_container #183

Closed (fangwei123456 closed this issue 2 years ago)

fangwei123456 commented 2 years ago

We use T * N as the batch size, but a batch size that is too large may not be the fastest:

import torch
import torch.nn as nn
import torch.nn.functional as F
from spikingjelly.clock_driven.cu_kernel_opt import cal_fun_t

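def forward_backward(x, net: nn.Module, optimizer: torch.optim.Optimizer, bs):
    # One training step in which the T * N samples in x are fed to net in mini batches of size bs
    assert net.training
    optimizer.zero_grad()
    y = []
    for i in range(x.shape[0] // bs):
        y.append(net(x[i * bs: (i + 1) * bs]))
    y = torch.cat(y)
    # bs must divide x.shape[0]; otherwise the trailing samples are skipped and this check fails
    assert y.shape[0] == x.shape[0]
    y.sum().backward()
    optimizer.step()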

device = 'cuda:0'
conv = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False)
conv.to(device)
optimizer = torch.optim.SGD(conv.parameters(), lr=1e-6)
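repeats = 32
# 8192 samples stand for a merged T * N batch of [C, H, W] = [128, 32, 32] inputs
x = torch.rand([8192, 128, 32, 32], device=device)
# Time one training step for each mini batch size (all powers of 2 that divide 8192)
for bs in [2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]:
    print(bs, cal_fun_t(repeats, device, forward_backward, x, conv, optimizer, bs))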
Output (bs, time reported by cal_fun_t):
2 0.7070100581868246
4 0.4419617693752116
8 0.3370353510631503
16 0.3003204499373169
32 0.28969705309509663
64 0.2803983489379789
128 0.2786591427784515
256 0.27968712543815855
512 0.2802094193416451
1024 0.585363641591357
2048 0.5728413845918112
4096 0.5698013973433262
8192 0.5855700465936025
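The manual chunking in forward_backward above could also live inside a seq_to_ann_container-style wrapper, so that a stateless layer splits its merged T * N batch by itself. The following is a minimal sketch, assuming an input of shape [T, N, ...] and a stateless module such as a convolution; the class name `ChunkedSeqToANN` and its `mini_batch_size` argument are hypothetical and not part of SpikingJelly's API.

```python
import torch
import torch.nn as nn


class ChunkedSeqToANN(nn.Module):
    """Hypothetical sketch: apply a stateless module to a [T, N, ...] sequence
    by merging T and N into one batch of size T * N, then feeding that batch
    through the module in mini batches of `mini_batch_size` samples."""

    def __init__(self, module: nn.Module, mini_batch_size: int = 0):
        super().__init__()
        self.module = module
        # 0 (or any value >= T * N) means "use the whole T * N batch at once"
        self.mini_batch_size = mini_batch_size

    def forward(self, x_seq: torch.Tensor) -> torch.Tensor:
        T, N = x_seq.shape[0], x_seq.shape[1]
        x = x_seq.flatten(0, 1)  # [T * N, ...]
        bs = self.mini_batch_size
        if bs <= 0 or bs >= x.shape[0]:
            y = self.module(x)
        else:
            # torch.split also handles a last chunk smaller than bs
            y = torch.cat([self.module(chunk) for chunk in torch.split(x, bs)])
        # Restore the separate T and N dimensions
        return y.reshape(T, N, *y.shape[1:])
```

Note that chunking only gives identical results for layers whose output on one sample does not depend on the rest of the batch (e.g. convolution, or BatchNorm in eval mode). BatchNorm in training mode computes its statistics per chunk, so its outputs change with the mini batch size.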
fangwei123456 commented 2 years ago
import torch
import torch.nn as nn
import torch.nn.functional as F
from spikingjelly.clock_driven.model import spiking_resnet
from spikingjelly.clock_driven import neuron, functional
from spikingjelly.clock_driven.cu_kernel_opt import cal_fun_t
from spikingjelly import configure

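def forward_backward(x, net: nn.Module, optimizer: torch.optim.Optimizer):
    # One full training step on the multi-step network
    optimizer.zero_grad()
    y = net(x)
    y.sum().backward()
    optimizer.step()
    # Reset the neurons' states so that every timed call starts from the same state
    functional.reset_net(net)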

device = 'cuda:0'
net = spiking_resnet.multi_step_spiking_resnet50(T=4, multi_step_neuron=neuron.MultiStepIFNode, backend='cupy')
net.to(device)
optimizer = torch.optim.SGD(net.parameters(), lr=1e-6)
repeats = 64
x = torch.rand([16, 3, 224, 224], device=device)

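# Enable the per-layer mini batch size benchmark for the conv/bn layers
configure.mini_batch_benchmark = True
# Warm-up call before timing
forward_backward(x, net, optimizer)
print(cal_fun_t(repeats, device, forward_backward, x, net, optimizer))
# Disable it: every layer processes the whole T * N batch at once
configure.mini_batch_benchmark = False
print(cal_fun_t(repeats, device, forward_backward, x, net, optimizer))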
Output (first with mini_batch_benchmark = True, then with mini_batch_benchmark = False):
365.29833030700684
331.5680847167969

I found that using the optimized mini_batch_size for each conv/bn layer makes the whole network slower (365.30 with the per-layer benchmark vs. 331.57 without it). So I have to revert the commit https://github.com/fangwei123456/spikingjelly/commit/94e8fc9b319f681fca989bda2f0d558f4962b2a3 .
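Because per-layer optima did not add up to a faster network, one possible alternative (not what was implemented here) is to choose a single mini batch size by timing the whole training step end to end. A minimal sketch, assuming the caller can build a step closure for any candidate size; `time_step`, `pick_global_mini_batch` and `make_step` are hypothetical names.

```python
import time
from typing import Callable, Dict, Iterable, Tuple

import torch


def time_step(step: Callable[[], None], repeats: int, device: torch.device) -> float:
    """Mean wall-clock seconds of step() over `repeats` calls, after one warm-up call."""
    def sync():
        # CUDA kernels run asynchronously, so wait for them before reading the clock
        if device.type == 'cuda':
            torch.cuda.synchronize(device)

    step()  # warm-up (memory allocation, cuDNN algorithm selection, ...)
    sync()
    t0 = time.perf_counter()
    for _ in range(repeats):
        step()
    sync()
    return (time.perf_counter() - t0) / repeats


def pick_global_mini_batch(make_step: Callable[[int], Callable[[], None]],
                           candidates: Iterable[int], repeats: int,
                           device: torch.device) -> Tuple[int, Dict[int, float]]:
    """Time the whole network once per candidate size and return the fastest one."""
    times = {bs: time_step(make_step(bs), repeats, device) for bs in candidates}
    best = min(times, key=times.get)
    return best, times
```

Here `make_step(bs)` is expected to return a closure that runs one complete forward/backward/optimizer step with every chunked layer using mini batch size `bs`, so that interactions between layers are included in the measurement.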