fangwei123456 / spikingjelly

SpikingJelly is an open-source deep learning framework for Spiking Neural Network (SNN) based on PyTorch.
https://spikingjelly.readthedocs.io

[Update] Implementation of the new base class, neurons, and encoders #82

Closed · fangwei123456 closed this issue 3 years ago

fangwei123456 commented 3 years ago

Opened a new branch to implement a new memory-aware base class, with the neurons and encoders slightly modified on top of it. This issue will be closed once the documentation and tutorials are updated.

fangwei123456 commented 3 years ago

The main feature of the new memory-aware base class is that stateful (memory) variables are stored in a dedicated dictionary:

def register_memory(self, name: str, value):
    # current value of the stateful variable
    self._memory[name] = value
    # value used to restore the variable when reset() is called
    self._memory_rv[name] = value

As a result, subclasses no longer need to write their own reset function; reset is handled by the base class:

def reset(self):
    # restore every stateful variable to its registered reset value
    for key in self._memory.keys():
        self._memory[key] = self._memory_rv[key]

Members can still be accessed directly by name, so a neuron can read its membrane potential as self.v rather than self._memory['v'], although the latter also works:

def __getattr__(self, name: str):
    # called only when normal attribute lookup fails;
    # stateful variables are looked up in the memory dictionary
    if name in self._memory:
        return self._memory[name]
    else:
        return super().__getattr__(name)
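
Putting these pieces together, a toy subclass might look like the sketch below. This is only an illustration: the base class is assembled here from the snippets above under the assumed name MemoryModule, and ToyIFNode is a made-up integrate-and-fire example, not part of the library.

import torch
import torch.nn as nn

# Hypothetical assembly of the memory-aware base class from the snippets above
# (the actual class name and details in the branch may differ).
class MemoryModule(nn.Module):
    def __init__(self):
        super().__init__()
        self._memory = {}
        self._memory_rv = {}

    def register_memory(self, name: str, value):
        self._memory[name] = value
        self._memory_rv[name] = value

    def reset(self):
        for key in self._memory.keys():
            self._memory[key] = self._memory_rv[key]

    def __getattr__(self, name: str):
        if name in self._memory:
            return self._memory[name]
        return super().__getattr__(name)

# A made-up integrate-and-fire neuron: it registers its membrane potential
# and never has to implement reset() itself.
class ToyIFNode(MemoryModule):
    def __init__(self, v_threshold=1.0):
        super().__init__()
        self.v_threshold = v_threshold
        self.register_memory('v', 0.)       # reset value 0. is remembered by the base class

    def forward(self, x: torch.Tensor):
        self._memory['v'] = self.v + x                           # charge (reads go through __getattr__)
        spike = (self.v >= self.v_threshold).to(x)               # fire
        self._memory['v'] = self.v - spike * self.v_threshold    # soft reset
        return spike

node = ToyIFNode()
node(torch.rand(4))
print(node.v)       # membrane potential, read as a plain attribute
node.reset()
print(node.v)       # back to the registered reset value 0.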

Previously, when a module was moved to a GPU with to(device), some of its internal state could not be moved along with it and had to be set manually. This problem is also solved now:

def _apply(self, fn):
    # also apply fn (e.g. the device/dtype conversion used by to()) to the
    # tensors stored in the memory dictionaries
    for key, value in self._memory.items():
        if isinstance(value, torch.Tensor):
            self._memory[key] = fn(value)

    for key, value in self._memory_rv.items():
        if isinstance(value, torch.Tensor):
            self._memory_rv[key] = fn(value)
    return super()._apply(fn)

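Assuming this _apply override is added to the hypothetical MemoryModule from the sketch above, the tensors stored in the memory dictionaries follow the module when it is moved, for example:

# Continuing the ToyIFNode sketch above, with the _apply override
# added to the hypothetical MemoryModule.
node = ToyIFNode()
node(torch.zeros(4))        # after one step, node.v is a CPU tensor
node.to('cuda:0')           # _apply also applies the device conversion to _memory / _memory_rv
print(node.v.device)        # cuda:0
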
fangwei123456 commented 3 years ago

@Yanqi-Chen should the previous adaptive-threshold neuron be written back in? That neuron was removed in the last neuron update; things should be stable now, so it can be added back.

Yanqi-Chen commented 3 years ago

> @Yanqi-Chen should the previous adaptive-threshold neuron be written back in? That neuron was removed in the last neuron update; things should be stable now, so it can be added back.

Sure, it can be added back. Do you only need the Python version, or the CUDA one as well?

fangwei123456 commented 3 years ago

> > @Yanqi-Chen should the previous adaptive-threshold neuron be written back in? That neuron was removed in the last neuron update; things should be stable now, so it can be added back.
>
> Sure, it can be added back. Do you only need the Python version, or the CUDA one as well?

The Python version is enough.

fangwei123456 commented 3 years ago

Merged with the monitor branch: #84

@Grasshlw when you have time, please update the encoder tutorial on this branch. The API has changed slightly, but the impact is small, so the workload should be minimal.

fangwei123456 commented 3 years ago

This time I also plan to write some CUDA math functions to give the Python neurons a slight speedup. The plan is to implement them with CuPy, which keeps the development cost very low.

fangwei123456 commented 3 years ago

On the compatibility between CuPy and PyTorch, and on ways to implement CUDA kernels:

import cupy
import torch
from spikingjelly.cext import cal_fun_t

# 1) CuPy ElementwiseKernel called directly on PyTorch tensors
fun1 = cupy.ElementwiseKernel(
    'T x, T y',
    'T z',
    'z = x*x + 3. * y + x*y + x/y',
    'fun1',
    options=('--use_fast_math',)
)

# 2) CuPy's recommended interop: convert tensors to CuPy arrays and back
def fun2(x, y):
    return torch.as_tensor(fun1(cupy.asarray(x), cupy.asarray(y)))

# 3) cupy.fuse, which should compile the Python expression into a CUDA kernel
@cupy.fuse()
def fun3(x, y):
    return x ** 2 + 3. * y + x * y + x / y

# 4) hand-written RawKernel operating on raw device pointers
fun4 = cupy.RawKernel(r'''
extern "C" __global__
void fun4(const float* x, const float* y, float* z, const int n) {
    const int index = blockIdx.x * blockDim.x + threadIdx.x;
    if (index < n)
    {
        z[index] = x[index] * x[index] + 3 * y[index] + x[index] * y[index] + x[index] / y[index];
    }
}
''', 'fun4', options=('--use_fast_math',))

# 5) plain PyTorch as the baseline
def fun5(x, y):
    return x ** 2 + 3. * y + x * y + x / y

device = 'cuda:0'

x = torch.rand([1, 1], device=device)
y = torch.rand([1, 1], device=device)

fun1(x, y)  # warm up / compile the kernel before timing

print('ElementwiseKernel', cal_fun_t(128, device, fun1, x, y))
print('standard pytorch-cupy', cal_fun_t(128, device, fun2, x, y))
print('fuse', cal_fun_t(128, device, fun3, x, y))

z = torch.zeros_like(x)

threads = 1024
blocks = (x.numel() + threads - 1) // threads

print('RawKernel', cal_fun_t(128, device, fun4, (blocks,), (threads,), (x.data_ptr(), y.data_ptr(), z.data_ptr(), z.numel())))
print('pytorch', cal_fun_t(128, device, fun5, x, y))

The output is:

ElementwiseKernel 8.068828124998227e-05
standard pytorch-cupy 0.00012701484374998823
fuse 0.00011616718749999505
RawKernel 5.038359374998774e-05
pytorch 0.00011785546875000066

The hand-written RawKernel is the fastest. "standard pytorch-cupy" is CuPy's recommended way of interoperating with PyTorch, but converting data back and forth between the two frameworks costs too much time, making it the slowest. In theory, fuse compiles the Python code into CUDA code, but it appears no faster than plain PyTorch. Using ElementwiseKernel carries some risk: officially it only supports NumPy and CuPy arrays, and although PyTorch tensors happen to run without errors, it is hard to guarantee that the results are correct.

So going forward, RawKernel can be used to help with some simple operations. One thing to note: we pass pointers to data on the GPU via a tensor's data_ptr(), but the underlying data is not necessarily contiguous, so contiguous() must be called before passing the pointer to make sure the memory layout is contiguous. The existing CUDA code has always done this, for example:

https://github.com/fangwei123456/spikingjelly/blob/56982956a6113229a6ff08de98b30e4a3b7114f0/spikingjelly/cext/csrc/neuron/neuron_def.h#L2
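
A small sketch of that point, reusing fun4 from the snippet above (the shapes here are arbitrary): a non-contiguous view has to be made contiguous before its data_ptr() is handed to the kernel.

import torch

x = torch.rand(64, 64, device='cuda:0').t()   # .t() returns a non-contiguous view
y = torch.rand(64, 64, device='cuda:0')
z = torch.zeros(64, 64, device='cuda:0')

x = x.contiguous()                            # pack the data so data_ptr() matches the logical layout
threads = 1024
blocks = (x.numel() + threads - 1) // threads
fun4((blocks,), (threads,), (x.data_ptr(), y.data_ptr(), z.data_ptr(), z.numel()))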

fangwei123456 commented 3 years ago

Experiments show that a CUDA kernel for single-step neurons is not much faster than plain PyTorch. The multi-step case shows a clear speedup, so we can consider writing CUDA kernels only for the multi-step neurons.
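
A rough illustration of why, in plain PyTorch rather than the library's actual kernels: with a single-step neuron the loop over T stays in Python, so every time step pays kernel-launch overhead, while a multi-step kernel runs the whole loop inside a single launch. The LIF update below is the standard charge/fire/hard-reset form, written only to show where that loop sits.

import torch

def lif_multi_step_reference(x_seq: torch.Tensor, v: torch.Tensor, tau: float = 2.0, v_th: float = 1.0):
    # x_seq has shape [T, N]; a single-step kernel would still be invoked once per
    # iteration of this Python loop, which is what a fused multi-step kernel avoids.
    spikes = []
    for t in range(x_seq.shape[0]):
        v = v + (x_seq[t] - v) / tau       # charge
        spike = (v >= v_th).to(x_seq)      # fire
        v = v * (1. - spike)               # hard reset
        spikes.append(spike)
    return torch.stack(spikes), v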

fangwei123456 commented 3 years ago

> Merged with the monitor branch: #84
>
> @Grasshlw when you have time, please update the encoder tutorial on this branch. The API has changed slightly, but the impact is small, so the workload should be minimal.

All the places affected by the API changes have basically been updated now.

fangwei123456 commented 3 years ago

A comparison of the neurons written with CuPy against the previous pure-CUDA cext neurons (cext), on a 2080 Ti:

from spikingjelly import cext
from spikingjelly.cext import neuron as cext_neuron
from spikingjelly.clock_driven import neuron, surrogate, layer
import torch

def cal_forward_t(multi_step_neuron, x, repeat_times):
    # forward-only timing (no autograd), returned in milliseconds
    with torch.no_grad():
        used_t = cext.cal_fun_t(repeat_times, x.device, multi_step_neuron, x)
        multi_step_neuron.reset()
        return used_t * 1000

def forward_backward(multi_step_neuron, x):
    multi_step_neuron(x).sum().backward()
    multi_step_neuron.reset()
    x.grad.zero_()

def cal_forward_backward_t(multi_step_neuron, x, repeat_times):
    # forward + backward timing, returned in milliseconds
    x.requires_grad_(True)
    used_t = cext.cal_fun_t(repeat_times, x.device, forward_backward, multi_step_neuron, x)
    return used_t * 1000

device = 'cuda:0'
# the three implementations being compared: pure PyTorch, the CuPy backend, and the cext CUDA extension
lif = neuron.MultiStepLIFNode(surrogate_function=surrogate.ATan(alpha=2.0), backend='torch')
lif_cupy = neuron.MultiStepLIFNode(surrogate_function=surrogate.ATan(alpha=2.0), backend='cupy')
lif_cuda_tt = cext_neuron.MultiStepLIFNode(surrogate_function='ATan', alpha=2.0)
lif.to(device)
lif_cupy.to(device)
lif_cuda_tt.to(device)
N = 2 ** 20
print('forward')
lif.eval()
lif_cupy.eval()
lif_cuda_tt.eval()
for T in [8, 16, 32, 64, 128]:
    x = torch.rand(T, N, device=device)
    print(T, cal_forward_t(lif, x, 1024), cal_forward_t(lif_cupy, x, 1024), cal_forward_t(lif_cuda_tt, x, 1024))

print('forward and backward')
lif.train()
lif_cupy.train()
lif_cuda_tt.train()
for T in [8, 16, 32, 64, 128]:
    x = torch.rand(T, N, device=device)
    print(T, cal_forward_backward_t(lif, x, 1024), cal_forward_backward_t(lif_cupy, x, 1024),
          cal_forward_backward_t(lif_cuda_tt, x, 1024))

Results (time per call in ms; for each T the columns are the torch, cupy, and cext implementations):

forward

| T | torch | cupy | cext |
| --- | --- | --- | --- |
| 8 | 1.9291880407763529 | 0.819957700514351 | 0.23939031871123007 |
| 16 | 3.8124477860037587 | 1.6043593723225058 | 0.4244983774697175 |
| 32 | 7.606722991113202 | 3.268783582825563 | 0.7993447834451217 |
| 64 | 15.16240044566075 | 6.8287351678009145 | 1.5565728017463698 |
| 128 | 30.319368333039165 | 14.061802628930309 | 3.170469198266801 |

forward and backward

| T | torch | cupy | cext |
| --- | --- | --- | --- |
| 8 | 8.103932273115788 | 1.6449875647595036 | 1.4937707446733839 |
| 16 | 22.193177647750417 | 3.2178012561416836 | 2.853808532563562 |
| 32 | 66.77314530952572 | 6.422930538064975 | 5.673481301528227 |
| 64 | 227.0976557265385 | 13.077360782517644 | 11.367543987034878 |
| 128 | 830.4208710751482 | 26.729976756541873 | 22.951312136683555 |

For the forward pass without gradients, cext had a dedicated inference kernel, whereas the CuPy version makes no such distinction, so it is a bit slower; but inference time is negligible compared with training time. For forward plus backward, the cext kernel does not return the membrane potential v at intermediate time steps, while the CuPy version does, so CuPy additionally uses the gradients backpropagated to those intermediate v values in its computation, which makes it slightly slower than cext, though not by much. We could consider also writing a CuPy kernel that does not return the intermediate v.

fangwei123456 commented 3 years ago

Merged with the main branch: https://github.com/fangwei123456/spikingjelly/pull/85