miladmozafari / SpykeTorch

High-speed simulator of convolutional spiking neural networks with at most one spike per neuron.
GNU General Public License v3.0

No learning with STDP #12

Closed. ggoupy closed this issue 2 years ago.

ggoupy commented 2 years ago

Hello. I am new to SpykeTorch. I am working on an anomaly detection project, and I would like to learn features from spectrograms using a convolutional SNN. I have been struggling with STDP for the past week. My model achieves decent performance (~70% AUC) but learns little or nothing (a gain of 2-5% AUC). I have tried many things: parameter tuning (number of winners, firing threshold, inhibition radius), but also an adaptive learning rate, an adaptive firing threshold, etc. I don't understand where the problem comes from, because I know I can obtain > 90% AUC with a regular CNN.

My pipeline is: signal -> spectrogram (MFSC) -> CSNN -> ML outlier detection classifier
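
For reference, the encoding step looks roughly like this (a minimal sketch, not my exact code; the bin count and spectrogram shape are placeholders, and it relies on SpykeTorch's Intensity2Latency transform):

import torch
import SpykeTorch.utils as utils

# Hypothetical MFSC spectrogram: (channels, frames, frequency_bands).
mfsc = torch.rand(1, 40, 40)

# Intensity-to-latency coding: stronger intensities spike in earlier bins.
# The result is an accumulative spike-wave tensor of shape
# (spike_bins, channels, frames, frequency_bands).
encoder = utils.Intensity2Latency(number_of_spike_bins=15, to_spike=True)
spike_wave = encoder(mfsc)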

So I have several questions:

Do you have any intuition about something I am doing wrong?

Here is my model; I can also send you my whole code if you are willing to check it out.

import torch
import torch.nn as nn
from torch.nn.parameter import Parameter

import SpykeTorch.snn as snn
import SpykeTorch.functional as sf

class CSNN(nn.Module):
    def __init__(self, input_shape):
        super(CSNN, self).__init__()

        # Context saved during the forward pass for the STDP update.
        self.ctx = {}

        out_channels = 50
        kernel_height = 7
        in_nb_spike_bins, in_channels, in_frames, in_freqs = input_shape
        output_height = in_frames - kernel_height + 1  # valid convolution, stride 1

        self.conv = snn.Convolution(in_channels=in_channels, out_channels=out_channels, kernel_size=(kernel_height, in_freqs), weight_mean=0.8, weight_std=0.05)
        self.stdp = snn.STDP(self.conv, learning_rate=(0.004, -0.003))
        self.pool = snn.Pooling(kernel_size=(4, 1), stride=(4, 1), padding=0)
        self.firing_thr = 26
        self.nb_winners = 1
        self.inhib_rad = 0

        # Upper bound for the adaptive positive learning rate.
        self.max_ap = Parameter(torch.Tensor([0.15]))

        # Running statistics used to derive an adaptive firing threshold.
        self.mean_pot = 0
        self.counter = 0

    def get_thr(self):
        return self.mean_pot / self.counter

    def forward(self, input):
        input = input.float()
        if self.training:
            pot = self.conv(input)
            spk, pot = sf.fire(pot, self.firing_thr, return_thresholded_potentials=True)
            #pot = sf.pointwise_inhibition(pot)
            #spk = pot.sign() # remove spk where pot is now null
            winners = sf.get_k_winners(pot, kwta=self.nb_winners, inhibition_radius=self.inhib_rad, spikes=spk)
            self.save_stdp_data(input, pot, spk, winners)
            return spk, pot
        else:
            pot = self.conv(input)
            pot = self.pool(pot)
            self.mean_pot += pot.mean()
            self.counter += 1
            spk = sf.fire(pot, self.firing_thr)
            return spk, pot

    def save_stdp_data(self, input_spikes, potentials, output_spikes, winners):
        self.ctx['input_spikes'] = input_spikes
        self.ctx['potentials'] = potentials
        self.ctx['output_spikes'] = output_spikes
        self.ctx['winners'] = winners

    def update_stdp(self):
        self.stdp(self.ctx['input_spikes'], self.ctx['potentials'], self.ctx['output_spikes'], self.ctx['winners'])

    def update_learning_rate(self):
        # Double the positive learning rate, cap it at max_ap, and keep an = -0.75 * ap.
        ap = torch.min(self.stdp.learning_rate[0][0] * 2, self.max_ap)
        an = ap * -0.75
        self.stdp.update_all_learning_rate(ap.item(), an.item())
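
And here is a minimal sketch of the training loop (train_loader is a placeholder that yields spike-wave tensors; the input shape and epoch count are arbitrary):

model = CSNN(input_shape=(15, 1, 40, 40))  # placeholder shape
model.train()
for epoch in range(5):
    for spike_wave in train_loader:
        model(spike_wave)      # the forward pass saves the STDP context
        model.update_stdp()    # apply STDP for the saved winners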

Notes:

miladmozafari commented 2 years ago

Hello. Thank you for your interest in our work. I didn't clearly understand the situation: if the model is not learning, how could it reach ~70% AUC?

miladmozafari commented 2 years ago

About your questions:

ggoupy commented 2 years ago

Hello, thank you for your answer!

I think the model reaches ~72% AUC because the convolutional SNN might act as an encoder that reduces the input dimensionality.

However, I obtain this AUC only when I use a potential-based vector for readout, that is, a vector of size (out_c * out_h * out_w) where each value is the mean over all timesteps. With the spike-based vector, I get an AUC of 50% without training and around 68% with training. So there is some learning going on, even if it leads to worse performance.
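
Concretely, the two readout vectors are built roughly like this (a sketch; model is the CSNN above in eval mode and x is one spike-wave input):

model.eval()
spk, pot = model(x)
# Potential-based readout: mean over the time (spike-bin) dimension.
feat_pot = pot.mean(dim=0).flatten()   # length out_c * out_h * out_w
# Spike-based readout: spike counts over time.
feat_spk = spk.sum(dim=0).flatten()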

About STDP: if I choose one output neuron as the winner, it means that the weights of the winner's feature map are updated using the input values (the convolution patch) that were used to compute the winner's potential, right?

miladmozafari commented 2 years ago

To solve the STDP learning problem, I think I need to play with the code to find its source. About the winner, you are right: the weights are shared, and they will be updated based on the winner neuron's inputs (and outputs).
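
To illustrate the mechanics (a toy sketch using SpykeTorch's public API; the shapes and the threshold are arbitrary):

import torch
import SpykeTorch.snn as snn
import SpykeTorch.functional as sf

conv = snn.Convolution(in_channels=1, out_channels=4, kernel_size=(3, 3), weight_mean=0.8, weight_std=0.05)
stdp = snn.STDP(conv, learning_rate=(0.004, -0.003))

# Toy accumulative spike wave: (timesteps, channels, height, width).
x = (torch.rand(10, 1, 8, 8) < 0.2).float().cumsum(dim=0).sign()

pot = conv(x)
spk, pot = sf.fire(pot, 2.0, return_thresholded_potentials=True)
winners = sf.get_k_winners(pot, kwta=1, inhibition_radius=0, spikes=spk)

before = conv.weight.clone()
stdp(x, pot, spk, winners)  # each winner is a (feature, row, column) index
# Only the shared kernel of the winning feature map changes:
changed = (conv.weight != before).flatten(1).any(dim=1)
print(changed)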

ggoupy commented 2 years ago

Alright, I have moved on to something different now, but thanks for your answers!