fangwei123456 / spikingjelly

SpikingJelly is an open-source deep learning framework for Spiking Neural Network (SNN) based on PyTorch.
https://spikingjelly.readthedocs.io

Tensor dimension disappears during multi-step training #411

Closed miaodd98 closed 1 year ago

miaodd98 commented 1 year ago

Issue type

SpikingJelly version

latest

Description

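Hello! I am trying to convert the YOLOX code to an SNN and train it for object detection, and I have run into a problem during training.

The SNN is trained in multi-step mode. The network reads images directly as [N,C,H,W], but when the input reaches the network's first convolutional layer, the N dimension disappears and the input becomes [C,H,W]. After adding the extra dimension required for multi-step training, the image shape is [T,1,C,H,W]. Later parts of the network use torch.cat(dim=1), so the concatenation here is indeed performed along dim=1, and the following error is raised: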

    (yolo) PS D:\codes\yolox-snn-train> python .\train.py
    initialize network with normal type
    Expected 4-dimensional input for 4-dimensional weight [32, 12, 3, 3], but got 5-dimensional input of size [4, 4, 3, 320, 320] instead
    Error occurs, No graph saved
    Start Train
    Epoch 1/300: 0%| | 0/2068 [00:00<?, ?it/s<class 'dict'>]
    Traceback (most recent call last):
      File ".\train.py", line 486, in <module>
        fit_one_epoch(model_train, model, ema, yolo_loss, loss_history, eval_callback, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, UnFreeze_Epoch, Cuda, fp16, scaler, save_period, save_dir, local_rank)
      File "D:\codes\yolox-snn-train\utils\utils_fit.py", line 37, in fit_one_epoch
        outputs = model_train(images)
      File "C:\Users\admin\anaconda3\envs\yolo\lib\site-packages\torch\nn\modules\module.py", line 1051, in _call_impl
        return forward_call(*input, **kwargs)
      File "C:\Users\admin\anaconda3\envs\yolo\lib\site-packages\torch\nn\parallel\data_parallel.py", line 166, in forward
        return self.module(*inputs[0], **kwargs[0])
      File "C:\Users\admin\anaconda3\envs\yolo\lib\site-packages\torch\nn\modules\module.py", line 1051, in _call_impl
        return forward_call(*input, **kwargs)
      File "C:\Users\admin\anaconda3\envs\yolo\lib\site-packages\spikingjelly-0.0.0.0.14-py3.8.egg\spikingjelly\activation_based\layer.py", line 31, in forward
      File "C:\Users\admin\anaconda3\envs\yolo\lib\site-packages\spikingjelly-0.0.0.0.14-py3.8.egg\spikingjelly\activation_based\functional.py", line 563, in multi_step_forward
      File "C:\Users\admin\anaconda3\envs\yolo\lib\site-packages\torch\nn\modules\container.py", line 139, in forward
        input = module(input)
      File "C:\Users\admin\anaconda3\envs\yolo\lib\site-packages\torch\nn\modules\module.py", line 1051, in _call_impl
        return forward_call(*input, **kwargs)
      File "D:\codes\yolox-snn-train\nets\yolo.py", line 259, in forward
        fpn_outs = self.backbone.forward(x)
      File "D:\codes\yolox-snn-train\nets\yolo.py", line 170, in forward
        out_features = self.backbone.forward(input)
      File "D:\codes\yolox-snn-train\nets\darknet.py", line 270, in forward
        x = self.stem(x)
      File "C:\Users\admin\anaconda3\envs\yolo\lib\site-packages\torch\nn\modules\module.py", line 1051, in _call_impl
        return forward_call(*input, **kwargs)
      File "D:\codes\yolox-snn-train\nets\darknet.py", line 53, in forward
        return self.conv(x)
      File "C:\Users\admin\anaconda3\envs\yolo\lib\site-packages\torch\nn\modules\module.py", line 1051, in _call_impl
        return forward_call(*input, **kwargs)
      File "D:\codes\yolox-snn-train\nets\darknet.py", line 79, in forward
        return self.act(self.bn(self.conv(x)))
      File "C:\Users\admin\anaconda3\envs\yolo\lib\site-packages\torch\nn\modules\module.py", line 1051, in _call_impl
        return forward_call(*input, **kwargs)
      File "C:\Users\admin\anaconda3\envs\yolo\lib\site-packages\spikingjelly-0.0.0.0.14-py3.8.egg\spikingjelly\activation_based\layer.py", line 171, in forward
      File "C:\Users\admin\anaconda3\envs\yolo\lib\site-packages\spikingjelly-0.0.0.0.14-py3.8.egg\spikingjelly\activation_based\functional.py", line 686, in seq_to_ann_forward
      File "C:\Users\admin\anaconda3\envs\yolo\lib\site-packages\torch\nn\modules\conv.py", line 443, in forward
        return self._conv_forward(input, self.weight, self.bias)
      File "C:\Users\admin\anaconda3\envs\yolo\lib\site-packages\torch\nn\modules\conv.py", line 439, in _conv_forward
        return F.conv2d(input, weight, bias, self.stride,
    RuntimeError: Given groups=1, weight of size [32, 12, 3, 3], expected input[16, 3, 320, 320] to have 12 channels, but got 3 channels instead

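Here the concatenation is not along the channel dimension C that multi-step training needs (once T is added, C should be dim=2). When torch.cat is changed to dim=2, the error becomes: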

    (yolo) PS D:\codes\yolox-snn-train> python .\train.py
    initialize network with normal type
    Expected 4-dimensional input for 4-dimensional weight [32, 12, 3, 3], but got 5-dimensional input of size [4, 1, 12, 320, 320] instead
    Error occurs, No graph saved
    Start Train
    Epoch 1/300: 0%| | 0/2068 [00:00<?, ?it/s<class 'dict'>]
    Traceback (most recent call last):
      File ".\train.py", line 486, in <module>
        fit_one_epoch(model_train, model, ema, yolo_loss, loss_history, eval_callback, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, UnFreeze_Epoch, Cuda, fp16, scaler, save_period, save_dir, local_rank)
      File "D:\codes\yolox-snn-train\utils\utils_fit.py", line 37, in fit_one_epoch
        outputs = model_train(images)
      File "C:\Users\admin\anaconda3\envs\yolo\lib\site-packages\torch\nn\modules\module.py", line 1051, in _call_impl
        return forward_call(*input, **kwargs)
      File "C:\Users\admin\anaconda3\envs\yolo\lib\site-packages\torch\nn\parallel\data_parallel.py", line 166, in forward
        return self.module(*inputs[0], **kwargs[0])
      File "C:\Users\admin\anaconda3\envs\yolo\lib\site-packages\torch\nn\modules\module.py", line 1051, in _call_impl
        return forward_call(*input, **kwargs)
      File "C:\Users\admin\anaconda3\envs\yolo\lib\site-packages\spikingjelly-0.0.0.0.14-py3.8.egg\spikingjelly\activation_based\layer.py", line 31, in forward
      File "C:\Users\admin\anaconda3\envs\yolo\lib\site-packages\spikingjelly-0.0.0.0.14-py3.8.egg\spikingjelly\activation_based\functional.py", line 563, in multi_step_forward
      File "C:\Users\admin\anaconda3\envs\yolo\lib\site-packages\torch\nn\modules\container.py", line 139, in forward
        input = module(input)
      File "C:\Users\admin\anaconda3\envs\yolo\lib\site-packages\torch\nn\modules\module.py", line 1051, in _call_impl
        return forward_call(*input, **kwargs)
      File "D:\codes\yolox-snn-train\nets\yolo.py", line 259, in forward
        fpn_outs = self.backbone.forward(x)
      File "D:\codes\yolox-snn-train\nets\yolo.py", line 170, in forward
        out_features = self.backbone.forward(input)
      File "D:\codes\yolox-snn-train\nets\darknet.py", line 270, in forward
        x = self.stem(x)
      File "C:\Users\admin\anaconda3\envs\yolo\lib\site-packages\torch\nn\modules\module.py", line 1051, in _call_impl
        return forward_call(*input, **kwargs)
      File "D:\codes\yolox-snn-train\nets\darknet.py", line 53, in forward
        return self.conv(x)
      File "C:\Users\admin\anaconda3\envs\yolo\lib\site-packages\torch\nn\modules\module.py", line 1051, in _call_impl
        return forward_call(*input, **kwargs)
      File "D:\codes\yolox-snn-train\nets\darknet.py", line 79, in forward
        return self.act(self.bn(self.conv(x)))
      File "C:\Users\admin\anaconda3\envs\yolo\lib\site-packages\torch\nn\modules\module.py", line 1051, in _call_impl
        return forward_call(*input, **kwargs)
      File "C:\Users\admin\anaconda3\envs\yolo\lib\site-packages\spikingjelly-0.0.0.0.14-py3.8.egg\spikingjelly\activation_based\base.py", line 270, in forward
      File "C:\Users\admin\anaconda3\envs\yolo\lib\site-packages\spikingjelly-0.0.0.0.14-py3.8.egg\spikingjelly\activation_based\neuron.py", line 533, in multi_step_forward
      File "C:\Users\admin\anaconda3\envs\yolo\lib\site-packages\spikingjelly-0.0.0.0.14-py3.8.egg\spikingjelly\activation_based\neuron.py", line 251, in multi_step_forward
      File "C:\Users\admin\anaconda3\envs\yolo\lib\site-packages\spikingjelly-0.0.0.0.14-py3.8.egg\spikingjelly\activation_based\neuron.py", line 597, in single_step_forward
      File "C:\Users\admin\anaconda3\envs\yolo\lib\site-packages\spikingjelly-0.0.0.0.14-py3.8.egg\spikingjelly\activation_based\neuron.py", line 242, in single_step_forward
      File "C:\Users\admin\anaconda3\envs\yolo\lib\site-packages\spikingjelly-0.0.0.0.14-py3.8.egg\spikingjelly\activation_based\neuron.py", line 206, in neuronal_reset
    RuntimeError: The following operation failed in the TorchScript interpreter.
    Traceback of TorchScript (most recent call last):
    RuntimeError: nvrtc: error: invalid value for --gpu-architecture (-arch)

    nvrtc compilation failed:

    #define NAN __int_as_float(0x7fffffff)
    #define POS_INFINITY __int_as_float(0x7f800000)
    #define NEG_INFINITY __int_as_float(0xff800000)

    template<typename T>
    __device__ T maximum(T a, T b) {
      return isnan(a) ? a : (a > b ? a : b);
    }

    template<typename T>
    __device__ T minimum(T a, T b) {
      return isnan(a) ? a : (a < b ? a : b);
    }

    extern "C" __global__
    void fused_neg_add_mul_mul_add(float* tspike_1, double vv_reset_2, float* tv_1, float* aten_add, float* aten_add_1) {
    {
      float tspike_1_1 = __ldg(tspike_1 + (512 * blockIdx.x + threadIdx.x) % 3276800);
      aten_add_1[512 * blockIdx.x + threadIdx.x] = (float)((double)(0.f - tspike_1_1) + 1.0);
      float v = __ldg(tv_1 + (512 * blockIdx.x + threadIdx.x) % 3276800);
      aten_add[512 * blockIdx.x + threadIdx.x] = (float)((double)(0.f - tspike_1_1) + 1.0) * v + (float)((double)(tspike_1_1) * vv_reset_2);
    }
    }

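These two errors show that when the network is trained in SNN multi-step mode, the input does need an added time dimension T so that it has the form [T,N,C,H,W], and that operations such as torch.cat see inputs of this shape while the code runs. However, when the network executes, the data shape reported in the error is [N,C,H,W], with no T dimension.

So my question is: what causes the N dimension (batch size) to disappear? The Dataset and DataLoader used to load the data are the built-in torch ones, the same as in the original code. I believe this missing dimension is the root cause of all the problems, and I would appreciate your help!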

Minimal code to reproduce the error/bug

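The YOLOX framework is https://github.com/bubbliiiing/yolox-pytorch; that code runs as-is.

So far the only changes to the network are replacing every activation function with IFNode, and replacing nn.Conv2d and nn.MaxPool2d with SpikingJelly's layer.Conv2d and layer.MaxPool2d.

The settings are T=4 and batch size (N)=8. The dataset is loaded with the built-in torch Dataset and DataLoader.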

miaodd98 commented 1 year ago

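Addendum: when the model is set to single-step mode, the input starts out as normal NCHW, but by the time it reaches the network it has already become CHW.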

fangwei123456 commented 1 year ago

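Multi-step mode requires input with shape=[T, *]. Dimension 0 is treated as the time dimension and cannot be the batch dimension. A PyTorch network is equivalent to single-step mode in the SJ framework. Running such a network directly in multi-step mode will fail, so you need to change the behavior of the layers manually.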
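The shape rule above can be illustrated with plain PyTorch. This is only a sketch under stated assumptions, not SpikingJelly's implementation: it assumes a multi-step wrapper walks over dim 0 one slice at a time, and that a stateless layer such as a convolution is handled by merging T and N into one batch axis (the `multi_step_forward` and `seq_to_ann_forward` frames in the traceback suggest something along these lines). The sizes and variable names are made up for illustration.

```python
import torch
import torch.nn as nn

T, N, C, H, W = 4, 8, 3, 32, 32

# A static image batch [N, C, H, W], repeated along a new leading time axis.
images = torch.rand(N, C, H, W)
x_seq = images.unsqueeze(0).repeat(T, 1, 1, 1, 1)  # [T, N, C, H, W]

conv = nn.Conv2d(C, 16, kernel_size=3, padding=1)  # plain, stateless PyTorch layer

# Single-step style: call the layer once per time step on [N, C, H, W].
y_single = torch.stack([conv(x_seq[t]) for t in range(T)])

# Multi-step style for a stateless layer (assumed pattern):
# merge T and N, run the layer once, then split the axes again.
y_flat = conv(x_seq.flatten(0, 1))              # [T*N, 16, H, W]
y_multi = y_flat.view(T, N, *y_flat.shape[1:])  # [T, N, 16, H, W]

print(y_multi.shape)  # torch.Size([4, 8, 16, 32, 32])
print(torch.allclose(y_single, y_multi, atol=1e-5))

# If a batch WITHOUT a time axis is passed where [T, ...] is expected,
# slicing dim 0 consumes the batch axis instead of a time axis.
print(images[0].shape)  # torch.Size([3, 32, 32])
```

The last line is one plausible reading of the "vanished" N: a [N,C,H,W] tensor handed to code that indexes dim 0 as time comes out as [C,H,W]. Likewise, the `input[16, 3, 320, 320]` in the first traceback is consistent with a [4, 4, 3, 320, 320] tensor having its first two axes merged.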

fangwei123456 commented 1 year ago

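> So far the only changes to the network are replacing every activation function with IFNode, and replacing nn.Conv2d and nn.MaxPool2d with SpikingJelly's layer.Conv2d and layer.MaxPool2d.

Check whether any modules in the network transform tensor dimensions; those operations may be what causes the dimension errors.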

miaodd98 commented 1 year ago

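> > So far the only changes to the network are replacing every activation function with IFNode, and replacing nn.Conv2d and nn.MaxPool2d with SpikingJelly's layer.Conv2d and layer.MaxPool2d.
>
> Check whether any modules in the network transform tensor dimensions; those operations may be what causes the dimension errors.

The network does contain a tensor operation that acts on a specific dimension: torch.cat. Stepping through the code in the debugger revealed the problem with the time dimension. I will go through the network parameter settings for the single-step and multi-step models one by one.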
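The effect of the torch.cat axis can be sketched as follows. The helper `space_to_depth` is hypothetical and is not taken from the YOLOX code; it assumes the stem stacks four 2x2 pixel phases on the channel axis, which is what the 12-channel weight [32, 12, 3, 3] hints at. Indexing the channel axis from the end (dim=-3) is one way to make the same code work with and without a leading T axis.

```python
import torch

def space_to_depth(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: stack the four 2x2 pixel phases on the channel axis."""
    patches = (
        x[..., ::2, ::2],
        x[..., 1::2, ::2],
        x[..., ::2, 1::2],
        x[..., 1::2, 1::2],
    )
    # Channels are always third from the end:
    # dim=1 for [N, C, H, W], dim=2 for [T, N, C, H, W].
    return torch.cat(patches, dim=-3)

x4 = torch.rand(2, 3, 64, 64)     # [N, C, H, W]
x5 = torch.rand(4, 1, 3, 64, 64)  # [T, N, C, H, W]

print(space_to_depth(x4).shape)  # torch.Size([2, 12, 32, 32])
print(space_to_depth(x5).shape)  # torch.Size([4, 1, 12, 32, 32])

# A hard-coded dim=1 on the 5-D tensor grows the batch axis instead,
# leaving only 3 channels for a convolution that expects 12.
wrong = torch.cat([x5[..., ::2, ::2]] * 4, dim=1)
print(wrong.shape)  # torch.Size([4, 4, 3, 32, 32])
```

The two 5-D shapes mirror the [4, 1, 12, 320, 320] and [4, 4, 3, 320, 320] sizes reported at the top of the two error logs.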

miaodd98 commented 1 year ago

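After switching the model to single-step mode, it runs normally. Along the way I hit a problem caused by an incompatibility between the PyTorch version and the RTX 4070 GPU, which was resolved by switching to the latest PyTorch nightly build.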
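For anyone hitting a similar `nvrtc: error: invalid value for --gpu-architecture` message, a rough first check is to compare the GPU's compute capability with the architectures the installed PyTorch build was compiled for. The helper name `cuda_arch_report` is made up for this sketch. A mismatch is only a hint, since the error above comes from runtime compilation and an architecture missing from the list does not always mean a failure.

```python
import torch

def cuda_arch_report() -> dict:
    """Hypothetical helper: compare the GPU's compute capability with the build."""
    report = {
        "torch": torch.__version__,
        "cuda_available": torch.cuda.is_available(),
    }
    if report["cuda_available"]:
        major, minor = torch.cuda.get_device_capability(0)
        report["device"] = torch.cuda.get_device_name(0)
        report["device_arch"] = f"sm_{major}{minor}"
        # Architectures this PyTorch binary was compiled for, e.g. "sm_86".
        report["build_archs"] = torch.cuda.get_arch_list()
        report["arch_in_build"] = report["device_arch"] in report["build_archs"]
    return report

print(cuda_arch_report())
```

If the device architecture is newer than anything in the build list, installing a PyTorch build made for a newer CUDA toolkit is the usual next step, which matches the fix reported in this thread.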