miladmozafari / SpykeTorch

High-speed simulator of convolutional spiking neural networks with at most one spike per neuron.
GNU General Public License v3.0
381 stars, 100 forks

Adding and Training a R-STDP Layer in MozafariDeep Model #15

Open aidinattar opened 4 months ago

Hello,

First of all, I appreciate the work done on the SpykeTorch repository; it has been incredibly helpful for my research. I am extending the MozafariDeep model to study the behaviour of deeper networks. I am trying to add an additional R-STDP layer, but I am getting poor accuracy. I have tried two different training strategies for the R-STDP layers:

  1. Simultaneous Update: Updating the weights of both layers based on the last output.
  2. Separate Update: Updating one layer while fixing the other, and then vice versa.
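
To make the difference between the two strategies concrete, here is a minimal sketch of which layers would receive a weight update in each epoch. The function name `update_schedule` and the `block` parameter are hypothetical and not part of SpykeTorch; layers 3 and 4 stand for the two R-STDP layers. Note that this sketch switches layers per epoch, which is only one possible reading of "separate".

```python
def update_schedule(strategy, epochs, layers=(3, 4), block=1):
    """Return, for each epoch, the list of layers whose weights are updated.

    strategy: "simultaneous" updates every layer in every epoch;
              "separate" updates one layer for `block` epochs, then the next.
    """
    if strategy not in ("simultaneous", "separate"):
        raise ValueError(f"unknown strategy: {strategy}")
    schedule = []
    for epoch in range(epochs):
        if strategy == "simultaneous":
            schedule.append(list(layers))
        else:
            # Rotate through the layers, holding each one for `block` epochs.
            active = layers[(epoch // block) % len(layers)]
            schedule.append([active])
    return schedule


print(update_schedule("simultaneous", 2))        # [[3, 4], [3, 4]]
print(update_schedule("separate", 4, block=2))   # [[3], [3], [4], [4]]
```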

Both approaches resulted in poor accuracy. I also suspect that some parameters may be poorly chosen, such as the thresholds, the number of winners, and the inhibition radius. I understand that the paper reports no significant improvement from adding an R-STDP layer, but I assume there shouldn't be any catastrophic behavior either. Any guidance on these problems would be greatly appreciated.
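
As a first sanity check on the parameters, the spatial size of each layer's output can be traced by hand. The sketch below does this for the kernel sizes, padding, and pooling used in the snippets that follow. It assumes a square 28x28 input (an assumption on my side, e.g. MNIST), stride-1 convolutions, and unpadded pooling; `conv_out`, `pool_out`, and `trace_sizes` are hypothetical helper names, not SpykeTorch functions.

```python
def conv_out(size, kernel):
    """Output size of a stride-1 convolution without padding."""
    return size - kernel + 1


def pool_out(size, kernel, stride):
    """Output size of pooling without padding (floor mode)."""
    return (size - kernel) // stride + 1


def trace_sizes(input_size=28, pad=2):
    """Trace the spatial size after conv1..conv4.

    Assumes every convolution is preceded by zero padding of `pad`
    pixels on each side, as in the forward pass of this issue.
    """
    conv_kernels = [4, 3, 3, 4]          # kernel_size of conv1..conv4
    pools = [(2, 2), (2, 2), (3, 3)]     # (kernel, stride) after conv1..conv3
    sizes = []
    size = input_size
    for i, kernel in enumerate(conv_kernels):
        size = conv_out(size + 2 * pad, kernel)
        sizes.append(size)
        if i < len(pools):
            size = pool_out(size, *pools[i])
    return sizes


print(trace_sizes())  # [29, 16, 10, 4]
```

Under these assumptions the last layer produces a 4x4 map for each of its 100 features, so the single winner would be selected among 1600 neurons.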

Snippets

Here are some snippets of the code:

def __init__(
        self,
        num_classes = 10,
        learning_rate_multiplier = 1,  
    ):
        super(deepSNN, self).__init__()
        self.num_classes = num_classes

        #### LAYER 1 ####
        self.conv1 = snn.Convolution(
            in_channels=6,
            out_channels=30,
            kernel_size=4,
            weight_mean=0.8,
            weight_std=0.05
        )
        self.conv1_t = 15
        self.k1 = 5
        self.r1 = 3

        #### LAYER 2 ####
        self.conv2 = snn.Convolution(
            in_channels=30,
            out_channels=250,
            kernel_size=3,
            weight_mean=0.8,
            weight_std=0.05
        )
        self.conv2_t = 10
        self.k2 = 8
        self.r2 = 2

        #### LAYER 3 ####
        self.conv3 = snn.Convolution(
            in_channels=250,
            out_channels=200,
            kernel_size=3,
            weight_mean=0.8,
            weight_std=0.05
        )
        self.conv3_t = 400
        self.k3 = 15
        self.r3 = 0

        #### LAYER 4 ####
        self.conv4 = snn.Convolution(
            in_channels=200,
            out_channels=100,
            kernel_size=4,
            weight_mean=0.8,
            weight_std=0.05
        )

        # STDP
        self.stdp1 = snn.STDP(
            conv_layer = self.conv1,
            learning_rate = (
                learning_rate_multiplier * 0.004,
                learning_rate_multiplier * -0.003
            ),
        )
        self.stdp2 = snn.STDP(
            conv_layer = self.conv2,
            learning_rate = (
                learning_rate_multiplier * 0.004,
                learning_rate_multiplier * -0.003
            ),
        )
        self.stdp3 = snn.STDP(
            conv_layer = self.conv3,
            learning_rate = (
                learning_rate_multiplier * 0.004,
                learning_rate_multiplier * -0.003
            ),
            use_stabilizer = False,
            lower_bound = 0.2,
            upper_bound = 0.8,
        )
        self.stdp4 = snn.STDP(
            conv_layer = self.conv4,
            learning_rate = (
                learning_rate_multiplier * 0.004,
                learning_rate_multiplier * -0.003
            ),
            use_stabilizer = False,
            lower_bound = 0.2,
            upper_bound = 0.8,
        )

        # ANTI STDP
        self.anti_stdp3 = snn.STDP(
            conv_layer = self.conv3,
            learning_rate = (
                learning_rate_multiplier * -0.004,
                learning_rate_multiplier * 0.0005
            ),
            use_stabilizer = False,
            lower_bound = 0.2,
            upper_bound = 0.8,
        )
        self.anti_stdp4 = snn.STDP(
            conv_layer = self.conv4,
            learning_rate = (
                learning_rate_multiplier * -0.004,
                learning_rate_multiplier * 0.0005
            ),
            use_stabilizer = False,
            lower_bound = 0.2,
            upper_bound = 0.8,
        )

        # adaptive learning rate
        self.max_ap = Parameter(torch.tensor([0.15]))

        # Decision map
        self.decision_map = self.generate_decision_map()

        # context parameters
        self.ctx = {
            'input_spikes': None,
            'potentials': None,
            'output_spikes': None,
            'winners': None,
        }
        self.spk_cnt1 = 0
        self.spk_cnt2 = 0

        self.ctx3 = {
            'input_spikes': None,
            'potentials': None,
            'output_spikes': None,
            'winners': None,
        }
        # self.spk_cnt3 = 0

        self.ctx4 = {
            'input_spikes': None,
            'potentials': None,
            'output_spikes': None,
            'winners': None,
        }

def forward(
        self,
        input,
        layer_idx,
    ):
        """
        Forward pass of the network

        Parameters
        ----------
        input : torch.Tensor
            Input tensor
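        layer_idx : int
            Index of the layer being trained

        Returns
        -------
        tuple of torch.Tensor or int
            (spk, pot) when layer_idx is 1 or 2; otherwise the predicted
            class, or -1 if there is no winner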
        """
        # padding to avoid edge effects
        input = sf.pad(
            input = input.float(),
            pad = (2, 2, 2, 2),
            value = 0
        )

        if self.training:
            # Layer 1
            # potential and spikes
            pot = self.conv1(input)
            spk, pot = sf.fire(
                potentials = pot,
                threshold = self.conv1_t,
                return_thresholded_potentials = True,
            )
            if layer_idx == 1:
                self.spk_cnt1 += 1
                if self.spk_cnt1 >= 500:
                    self.spk_cnt1 = 0
                    ap = torch.tensor(
                        self.stdp1.learning_rate[0][0].item(),
                        device = self.stdp1.learning_rate[0][0].device
                    ) * 2
                    ap = torch.min(ap, self.max_ap)
                    an = ap * -.75
                    self.stdp1.update_all_learning_rate(
                        ap.item(),
                        an.item()
                    )

                # inhibition
                pot = sf.pointwise_inhibition(
                    thresholded_potentials = pot
                )
                spk = pot.sign()
                winners = sf.get_k_winners(
                    potentials = pot,
                    kwta = self.k1,
                    inhibition_radius = self.r1,
                    spikes = spk
                )
                self.ctx.update({
                    "input_spikes": input,
                    "potentials": pot,
                    "output_spikes": spk,
                    "winners": winners
                })
                return spk, pot

            # Layer 2
            # potential and spikes
            spk_in = sf.pad(
                sf.pooling(
                    input = spk,
                    kernel_size = 2,
                    stride = 2,
                ),
                pad = (2, 2, 2, 2),
            )
            pot = self.conv2(spk_in)
            spk, pot = sf.fire(
                potentials = pot,
                threshold = self.conv2_t,
                return_thresholded_potentials = True,
            )
            if layer_idx == 2:
                self.spk_cnt2 += 1
                if self.spk_cnt2 >= 500:
                    self.spk_cnt2 = 0
                    ap = torch.tensor(
                        self.stdp2.learning_rate[0][0].item(),
                        device = self.stdp2.learning_rate[0][0].device
                    ) * 2
                    ap = torch.min(ap, self.max_ap)
                    an = ap * -.75
                    self.stdp2.update_all_learning_rate(
                        ap.item(),
                        an.item()
                    )

                # inhibition
                pot = sf.pointwise_inhibition(
                    thresholded_potentials = pot
                )
                spk = pot.sign()
                winners = sf.get_k_winners(
                    potentials = pot,
                    kwta = self.k2,
                    inhibition_radius = self.r2,
                    spikes = spk
                )
                self.ctx.update({
                    "input_spikes": spk_in,
                    "potentials": pot,
                    "output_spikes": spk,
                    "winners": winners
                })
                return spk, pot

            # Layer 3
            # potential and spikes
            spk_in = sf.pad(
                sf.pooling(
                    input = spk,
                    kernel_size = 2,
                    stride = 2,
                ),
                pad = (2, 2, 2, 2),
            )
            pot = self.conv3(spk_in)
            spk, pot = sf.fire(
                potentials = pot,
                threshold = self.conv3_t,
                return_thresholded_potentials = True,
            )

            # self.spk_cnt3 += 1
            # if self.spk_cnt3 >= 500:
            #     self.spk_cnt3 = 0
            #     ap = torch.tensor(
            #         self.stdp3.learning_rate[0][0].item(),
            #         device = self.stdp3.learning_rate[0][0].device
            #     ) * 2
            #     ap = torch.min(ap, self.max_ap)
            #     an = ap * -.75
            #     self.stdp3.update_all_learning_rate(
            #         ap.item(),
            #         an.item()
            #     )

            if layer_idx == 3:
                winners = sf.get_k_winners(
                    potentials = pot,
                    kwta = self.k3,
                    inhibition_radius = self.r3,
                    spikes = spk
                )
                self.ctx3.update({
                    "input_spikes": spk_in,
                    "potentials": pot,
                    "output_spikes": spk,
                    "winners": winners
                })

            # Layer 4
            # potential and spikes
            spk_in = sf.pad(
                sf.pooling(
                    input = spk,
                    kernel_size = 3,
                    stride = 3,
                ),
                pad = (2, 2, 2, 2),
            )
            pot = self.conv4(spk_in)
            spk = sf.fire(
                potentials = pot,
            )
            winners = sf.get_k_winners(
                potentials = pot,
                kwta = 1,
                inhibition_radius = 0,
                spikes = spk
            )
            if layer_idx == 4:
                self.ctx4.update({
                    "input_spikes": spk_in,
                    "potentials": pot,
                    "output_spikes": spk,
                    "winners": winners
                })
            output = -1
            if len(winners) != 0:
                output = self.decision_map[winners[0][0]]
            return output

def train_rl(
    network,
    data,
    target,
    max_layer,
):
    """
    Train the network using reinforcement learning

    Parameters
    ----------
    network : deepSNN
        Network to be trained
    data : torch.Tensor
        Input data
    target : torch.Tensor
        Target class of each sample
    max_layer : int
        Maximum layer to be trained

    Returns
    -------
    numpy.ndarray
        Fractions of [correct, wrong, silence] decisions over the data
    """
    network.train()
    perf = np.array([0, 0, 0]) # [correct, wrong, silence]
    for i in range(len(data)):
        data_in = data[i]
        target_in = target[i]
        if use_cuda:
            data_in = data_in.cuda()
            target_in = target_in.cuda()

        for layer_idx in range(3, max_layer + 1):
            d = network(
                data_in,
                layer_idx,
            )

            if layer_idx == max_layer:
                if d != -1:
                    if d == target_in:
                        perf[0] += 1
                        for j in range(3, max_layer + 1):
                            network.reward(j)
                    else:
                        perf[1] += 1
                        for j in range(3, max_layer + 1):
                            network.punish(j)
                else:
                    perf[2] += 1

    return perf / len(data)

def train_rl_separate(
    network,
    data,
    target,
    max_layer,
):
    """
    Train the network using reinforcement learning,
    but train each layer separately, fixing the weights
    of the previous and next layers

    Parameters
    ----------
    network : deepSNN
        Network to be trained
    data : torch.Tensor
        Input data
    target : torch.Tensor
        Target class of each sample
    max_layer : int
        Maximum layer to be trained

    Returns
    -------
    numpy.ndarray
        Fractions of [correct, wrong, silence] decisions over the data
    """
    network.train()
    perf = np.array([0, 0, 0])  # [correct, wrong, silence]
    for i in range(len(data)):
        for layer_idx in range(3, max_layer + 1):
            data_in = data[i]
            target_in = target[i]
            if use_cuda:
                data_in = data_in.cuda()
                target_in = target_in.cuda()

            d = network(
                data_in,
                layer_idx,
            )

            if d != -1:
                if d == target_in:
                    if layer_idx == max_layer:
                        perf[0] += 1
                    network.reward(layer_idx)
                else:
                    if layer_idx == max_layer:
                        perf[1] += 1
                    network.punish(layer_idx)
            else:
                if layer_idx == max_layer:
                    perf[2] += 1
    return perf / len(data)
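
Since both training functions return the fractions [correct, wrong, silence], one thing I could try is scaling the reward and punishment rates by them after each pass, so that the rarer outcome gets the larger update. The following is only a minimal sketch of that idea: `adapt_rstdp_rates`, `adaptive_min`, and `adaptive_int` are hypothetical names, and the default rates are copied from `stdp3` and `anti_stdp3` above. How the result is written back into the `snn.STDP` objects (for example via `update_all_learning_rate`) is left out.

```python
def adapt_rstdp_rates(perf, base_reward=(0.004, -0.003),
                      base_punish=(-0.004, 0.0005),
                      adaptive_min=0.0, adaptive_int=1.0):
    """Scale R-STDP rates from perf = [correct, wrong, silence] fractions.

    Reward rates are scaled by the wrong fraction and punish rates by
    the correct fraction, so the rarer outcome gets the larger update.
    """
    correct, wrong, _silence = perf
    reward_scale = wrong * adaptive_int + adaptive_min
    punish_scale = correct * adaptive_int + adaptive_min
    reward = tuple(r * reward_scale for r in base_reward)
    punish = tuple(p * punish_scale for p in base_punish)
    return reward, punish


# Example: 50% correct, 25% wrong, 25% silent in the last pass.
reward, punish = adapt_rstdp_rates([0.5, 0.25, 0.25])
print(reward, punish)
```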

I apologize for any obvious mistakes or unclear parts in my code or explanation. I am ready to provide any additional information or clarification if needed.

Thank you again for the fantastic work on this repository and for your assistance with my issue.

Best regards, Aidin