jhunter533 closed this issue 4 weeks ago.
Thanks for your question!
"Forward Code 2" is the correct implementation. After you set `step_mode="m"`, neuron modules (i.e. subclasses of `spikingjelly.activation_based.neuron.BaseNode`) run the loop over `T` inside their own `forward()` method. The internal logic is as follows:
```python
import torch
from spikingjelly.activation_based import base


class BaseNode(base.MemoryModule):
    def single_step_forward(self, x: torch.Tensor):
        ...  # one time step: charge, fire, reset
        return spike

    def multi_step_forward(self, x_seq: torch.Tensor):
        T = x_seq.shape[0]  # time is the first dimension: [T, N, ...]
        y_seq = []
        ...
        for t in range(T):
            y = self.single_step_forward(x_seq[t])
            y_seq.append(y)
            ...
        ...
        return torch.stack(y_seq)

    def forward(self, *args, **kwargs):
        # dispatch on step_mode: 's' = one step per call, 'm' = whole sequence per call
        if self.step_mode == 's':
            return self.single_step_forward(*args, **kwargs)
        elif self.step_mode == 'm':
            return self.multi_step_forward(*args, **kwargs)
        else:
            raise ValueError(self.step_mode)
```
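To see why the loop belongs inside the neuron, here is a minimal, dependency-free sketch that mirrors the dispatch above. `ToyIFNode` is a hypothetical integrate-and-fire neuron written only for illustration (plain Python lists stand in for tensors, with a hard reset to 0 after a spike); it is not SpikingJelly's implementation. It checks that one multi-step call on the whole sequence produces the same spikes as a caller-side loop over `T` in single-step mode.

```python
class ToyIFNode:
    """Hypothetical integrate-and-fire neuron mimicking the step_mode dispatch (not SpikingJelly code)."""

    def __init__(self, v_threshold=1.0, step_mode="s"):
        self.v_threshold = v_threshold
        self.step_mode = step_mode
        self.v = None  # membrane potential, kept between calls (the neuron's "memory")

    def single_step_forward(self, x):
        # x: inputs for one time step, one entry per neuron
        if self.v is None:
            self.v = [0.0] * len(x)
        spike = []
        for i, xi in enumerate(x):
            self.v[i] += xi  # integrate
            if self.v[i] >= self.v_threshold:
                spike.append(1.0)
                self.v[i] = 0.0  # hard reset after firing
            else:
                spike.append(0.0)
        return spike

    def multi_step_forward(self, x_seq):
        # x_seq: list of T time steps; the loop over T lives here
        return [self.single_step_forward(x) for x in x_seq]

    def forward(self, x):
        if self.step_mode == "s":
            return self.single_step_forward(x)
        elif self.step_mode == "m":
            return self.multi_step_forward(x)
        raise ValueError(self.step_mode)


x_seq = [[0.6, 0.3], [0.6, 0.3], [0.6, 0.3], [0.6, 0.3]]  # T=4 steps, N=2 neurons

# Multi-step mode: pass the whole sequence once.
m_node = ToyIFNode(step_mode="m")
out_m = m_node.forward(x_seq)

# Single-step mode: the caller writes the loop over T.
s_node = ToyIFNode(step_mode="s")
out_s = [s_node.forward(x) for x in x_seq]

print(out_m == out_s)  # True
```

Because the membrane potential `v` persists between calls, a fresh node (or a reset) is needed before each new input sequence; in SpikingJelly this is the job of `functional.reset_net`.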
Hence "Forward Code 1" actually runs a nested loop of T * T single steps, which is why its execution time is significantly longer.
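The cost difference can be illustrated by counting single-step updates. The original Forward Code 1 and 2 are not reproduced in this thread, so the two call patterns below are assumptions about their shape: an outer loop over `T` that feeds the full sequence to a multi-step neuron on every iteration, versus a single call with the full sequence. `CountingNode` is a hypothetical stand-in that only counts calls.

```python
class CountingNode:
    """Hypothetical multi-step neuron stand-in that counts single-step updates."""

    def __init__(self):
        self.calls = 0

    def single_step_forward(self, x):
        self.calls += 1
        return x  # identity: spiking dynamics are omitted in this sketch

    def forward(self, x_seq):
        # multi-step behaviour: the loop over T happens inside forward()
        return [self.single_step_forward(x) for x in x_seq]


T = 8
x_seq = [[0.0] * 4 for _ in range(T)]  # T time steps, 4 neurons each

# Assumed "Forward Code 2" pattern: one call with the whole sequence.
once = CountingNode()
once.forward(x_seq)
print(once.calls)  # 8

# Assumed "Forward Code 1" pattern: an extra outer loop over T that passes
# the whole sequence on every iteration.
nested = CountingNode()
for _ in range(T):
    nested.forward(x_seq)
print(nested.calls)  # 64
```

With `T = 8` the nested pattern performs 64 single-step updates instead of 8, so the extra work grows quadratically with `T`.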
Thank you very much for your response.
Issue type
SpikingJelly version
0.0.0.0.15
Description
I have a question about the implementation of multi-step mode in SpikingJelly. When the step mode is set to multi-step (`step_mode='m'`), is the correct implementation of the forward call Forward Code 1 or Forward Code 2 below? Both run and save correctly; the main difference I observe is that Code 1 takes significantly longer.
Minimal code to reproduce the error/bug
Main Class:
Forward Code 1:
Forward Code 2: