fangwei123456 / spikingjelly

SpikingJelly is an open-source deep learning framework for Spiking Neural Network (SNN) based on PyTorch.
https://spikingjelly.readthedocs.io
Other
1.39k stars 242 forks

Question on multi step mode implementation #586

Closed jhunter533 closed 4 weeks ago

jhunter533 commented 1 month ago

Issue type

SpikingJelly version

0.0.0.0.15

Description

I have a question about the implementation of multi-step mode in SpikingJelly. When the step mode is set to multi-step ('m'), which is the correct forward call: Forward Code 1 or Forward Code 2 below? Both run and save correctly; the main difference I observe is that Code 1 takes significantly longer.

Minimal code to reproduce the error/bug

Main Class:

import torch.nn as nn
from spikingjelly.activation_based import neuron, surrogate, functional
class Network(nn.Module):
    def __init__(self,num_states,num_actions,action_bound,hidden_dim,hidden_dim2):
        super(Network,self).__init__()
        self.num_states=num_states
        self.num_actions=num_actions
        self.action_bound=action_bound
        self.hidden_dim=hidden_dim
        self.hidden_dim2=hidden_dim2
        self.T=8

        self.L1=nn.Sequential(
            nn.Linear(self.num_states,self.hidden_dim),
            neuron.LIFNode(surrogate_function=surrogate.Sigmoid(),backend='cupy'),
            nn.Linear(self.hidden_dim,self.hidden_dim2),
            neuron.LIFNode(surrogate_function=surrogate.Sigmoid(),backend='cupy'),
        )
        self.L3=nn.Sequential(
            nn.Linear(self.hidden_dim2,self.num_actions),
            neuron.NonSpikingLIFNode()
        )
        self.L4=nn.Sequential(
            nn.Linear(self.hidden_dim2,self.num_actions),
            neuron.NonSpikingLIFNode() 
        )
        functional.set_step_mode(self.L1,step_mode='m')
        functional.set_step_mode(self.L3,step_mode='m')
        functional.set_step_mode(self.L4,step_mode='m')

Forward Code 1:

    def forward(self,state):
        state=state.unsqueeze(0).repeat(self.T,1,1)
        for t in range(self.T):
            x=self.L1(state)
            mu=self.L3(x)
            log_std=self.L4(x)         
        return mu,log_std

Forward Code 2:

    def forward(self,state):
        state=state.unsqueeze(0).repeat(self.T,1,1)
        x=self.L1(state)
        mu=self.L3(x)
        log_std=self.L4(x)         
        return mu,log_std

AllenYolk commented 4 weeks ago

Thanks for your question!

"Forward Code 2" is the correct implementation. After setting step_mode="m", neuron modules (i.e. subclasses of spikingjelly.activation_based.neuron.BaseNode) run the for loop over the T time steps inside their own forward() method. The internal logic is roughly as follows:

class BaseNode(base.MemoryModule):
    def single_step_forward(self, x: torch.Tensor):
        ...
        return spike

    def multi_step_forward(self, x_seq: torch.Tensor):
        T = x_seq.shape[0]
        y_seq = []
        ...
        for t in range(T):
            y = self.single_step_forward(x_seq[t])
            y_seq.append(y)
            ...
        ...
        return torch.stack(y_seq)

    def forward(self, *args, **kwargs):
        if self.step_mode == 's':
            return self.single_step_forward(*args, **kwargs)
        elif self.step_mode == 'm':
            return self.multi_step_forward(*args, **kwargs)
        else:
            raise ValueError(self.step_mode)

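Hence, "Forward Code 1" actually runs a double loop: the outer for loop calls each module T times, and every one of those calls already iterates over all T time steps internally, giving T * T neuron updates instead of T. That's why it has a significantly longer execution time.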
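To make the T * T count concrete, here is a minimal pure-Python sketch. ToyNode and count_updates are hypothetical names invented for this illustration; they only mimic the dispatch pattern shown above and are not SpikingJelly code. The sketch counts how often single_step_forward runs under each calling pattern.

```python
class ToyNode:
    """Hypothetical stand-in for a neuron module; not SpikingJelly code."""

    def __init__(self, step_mode="s"):
        self.step_mode = step_mode
        self.single_step_calls = 0  # number of per-time-step updates so far

    def single_step_forward(self, x):
        self.single_step_calls += 1
        return x  # a real neuron would update its membrane potential here

    def multi_step_forward(self, x_seq):
        # The loop over the time dimension lives inside the module.
        return [self.single_step_forward(x) for x in x_seq]

    def forward(self, x):
        if self.step_mode == "s":
            return self.single_step_forward(x)
        elif self.step_mode == "m":
            return self.multi_step_forward(x)
        raise ValueError(self.step_mode)


def count_updates(outer_loop, T=8):
    """Return how many single-step updates one forward pass triggers."""
    node = ToyNode(step_mode="m")
    x_seq = [0.0] * T  # stands in for an input of shape [T, ...]
    if outer_loop:
        for _ in range(T):  # calling pattern of "Forward Code 1"
            node.forward(x_seq)
    else:
        node.forward(x_seq)  # calling pattern of "Forward Code 2"
    return node.single_step_calls


updates_code_1 = count_updates(outer_loop=True)
updates_code_2 = count_updates(outer_loop=False)
print(updates_code_1, updates_code_2)  # 64 8
```

With T = 8 the outer-loop pattern performs 64 updates and the single call performs 8. The toy ignores neuron state; in the real library the neurons are stateful, so repeated calls without a reset (for example via functional.reset_net) would presumably also carry membrane potential from one call into the next.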

jhunter533 commented 4 weeks ago

Thank you very much for your response.