fangwei123456 / spikingjelly

SpikingJelly is an open-source deep learning framework for Spiking Neural Network (SNN) based on PyTorch.
https://spikingjelly.readthedocs.io

Problem with passing hidden states between consecutive time steps in SpikingLSTM #544

Open · Seazoned opened this issue 4 months ago

Seazoned commented 4 months ago

Read before creating a new issue

For faster response

You can @ the corresponding developers for your issue. Here is the division:

| Features | Developers |
| --- | --- |
| Neurons and Surrogate Functions | fangwei123456, Yanqi-Chen |
| CUDA Acceleration | fangwei123456, Yanqi-Chen |
| Reinforcement Learning | lucifer2859 |
| ANN to SNN Conversion | DingJianhao, Lyu6PosHao |
| Biological Learning (e.g., STDP) | AllenYolk |
| Others | Grasshlw, lucifer2859, AllenYolk, Lyu6PosHao, DingJianhao, Yanqi-Chen, fangwei123456 |

We are glad to add new developers who are volunteering to help solve issues to the above table.

Issue type

SpikingJelly version

0.0.0.0.14

Description

When processing sequential data, I need to use the hidden state of the SpikingLSTM at the previous time step as the initial hidden state of the SpikingLSTM at the next time step, but doing so raises an error. The code is below.

Minimal code to reproduce the error/bug

import torch
import torch.nn as nn
from spikingjelly.activation_based import rnn, neuron, layer, surrogate

T = 6
h_dim = 32
batch = 5

x = torch.randn([T, batch, h_dim])   # input: [T, batch_size, features]
lstm = rnn.SpikingLSTM(32, 32, 1)    # input_size=32, hidden_size=32, num_layers=1
states = None
for t in range(8):
    # feed the states returned by the previous call back in as the initial states
    out, states = lstm(x, states)

The error message is as follows:

Traceback (most recent call last):
  File "E:\SNN\SNN_LSTM\SNN_LSTM\encoder.py", line 80, in <module>
    out, states = lstm(x, states)
  File "E:\anaconda3\envs\social\lib\site-packages\torch\nn\modules\module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "E:\anaconda3\envs\social\lib\site-packages\torch\nn\modules\module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "E:\anaconda3\envs\social\lib\site-packages\spikingjelly\activation_based\rnn.py", line 473, in forward
    new_states_list[:, 0] = torch.stack(self.cells[0](x[t], states_list[:, 0]))
  File "E:\anaconda3\envs\social\lib\site-packages\torch\nn\modules\module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "E:\anaconda3\envs\social\lib\site-packages\torch\nn\modules\module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "E:\anaconda3\envs\social\lib\site-packages\spikingjelly\activation_based\rnn.py", line 685, in forward
    i, f, g, o = torch.split(self.surrogate_function1(self.linear_ih(x) + self.linear_hh(h)),
ValueError: not enough values to unpack (expected 4, got 1)
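As an aside, the shape of that final unpack failure can be reproduced with torch.split alone: split returns a single chunk whenever the tensor is no wider than the split size along the chosen dimension, so the pre-activation tensor reaching rnn.py line 685 apparently does not have 4 * hidden_size columns. A minimal sketch (the shapes below are my guess, not taken from the library):

```python
import torch

hidden_size = 32
# Hypothetical pre-activation tensor whose dim-1 size is smaller than 4 * hidden_size.
pre_activation = torch.randn(5, 32)

# torch.split yields ceil(32 / 32) = 1 chunk here, so a 4-way unpack fails with
# "ValueError: not enough values to unpack (expected 4, got 1)".
chunks = torch.split(pre_activation, hidden_size, dim=1)
print(len(chunks))  # 1
```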

Seazoned commented 4 months ago

Maybe I should describe it differently: PyTorch's LSTM takes an input tensor of size (sequence length, batch_size, feature size). The LSTM in SpikingJelly takes an input tensor of shape (T, batch_size, feature size), where T, as I understand it, is the number of time steps of the spike train. In other words, the spiking LSTM no longer has the original "sequence length" dimension, so when feeding a sequence of a given length I wanted to use a loop to carry a hidden state through the LSTM, as in the sketch below. Is my understanding wrong? Any guidance would be much appreciated.
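For comparison, this is the hidden-state carry-over pattern I am trying to reproduce, written with torch.nn.LSTM (a sketch only, with sizes mirroring the repro code above; it is not a claim about how rnn.SpikingLSTM handles states):

```python
import torch
import torch.nn as nn

seq_len, batch, feat = 6, 5, 32
lstm = nn.LSTM(input_size=feat, hidden_size=feat, num_layers=1)

x = torch.randn(seq_len, batch, feat)
states = None  # None lets nn.LSTM start from zero (h_0, c_0)
for _ in range(8):
    # out: [seq_len, batch, hidden]; states: (h_n, c_n), each [num_layers, batch, hidden]
    out, states = lstm(x, states)
```

With rnn.SpikingLSTM I expected the returned (h, c) tuple to be accepted back in the same way, which is what the repro code above attempts.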