BICLab / Spike-Driven-Transformer-V2

Official implementation of "Spike-driven Transformer V2: Meta Spiking Neural Network Architecture Inspiring the Design of Next-generation Neuromorphic Chips" (ICLR2024)
https://openreview.net/forum?id=1SIBN5Xyw7
103 stars, 16 forks

Re-Parameterization at Inference Time #4

Closed: alexandercantrell closed this issue 3 months ago

alexandercantrell commented 3 months ago

I was hoping you might be able to explain a little more about how you re-parameterize your RepConv block at inference time. In your paper, you reference Ding et al. 2021 (RepVGG: Making VGG-style ConvNets Great Again), but their block structure is significantly different from the one in your published code, shown here:

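import torch.nn as nn

# BNAndPadLayer is defined elsewhere in the repository and is not shown in this issue.


class RepConv(nn.Module):
    def __init__(
        self,
        in_channel,
        out_channel,
        bias=False,  # accepted but not used by any layer below
    ):
        super().__init__()
        # hidden_channel = in_channel
        # Pointwise 1x1 convolution that mixes channels.
        conv1x1 = nn.Conv2d(in_channel, in_channel, 1, 1, 0, bias=False, groups=1)
        # Judging by its name and arguments: batch norm, then 1 pixel of padding,
        # so the unpadded 3x3 convolution below keeps the spatial size.
        bn = BNAndPadLayer(pad_pixels=1, num_features=in_channel)
        conv3x3 = nn.Sequential(
            # Depthwise 3x3 convolution (groups=in_channel), no padding.
            nn.Conv2d(in_channel, in_channel, 3, 1, 0, groups=in_channel, bias=False),
            # Pointwise 1x1 convolution mapping in_channel to out_channel.
            nn.Conv2d(in_channel, out_channel, 1, 1, 0, groups=1, bias=False),
            nn.BatchNorm2d(out_channel),
        )

        # The layers run one after another, not in parallel branches.
        self.body = nn.Sequential(conv1x1, bn, conv3x3)

    def forward(self, x):
        return self.body(x)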

The RepVGG block has its convolutions in parallel, whereas this block seems to be sequential, which makes it harder to re-parameterize. So far I've been able to figure out how to partially re-parameterize the block by combining one 1x1 convolution with the 3x3 convolution, but I'm struggling to figure out how you combined that with the final 1x1 convolution. Thanks in advance for your help!

jkhu29 commented 3 months ago

Thanks for your interest. A 1x1 convolution is a linear operation, so you can fuse its weight $W_{1\times1}$ into the weight of the 3x3 convolution $W_{3\times3}$ by $W_{new}=W_{1\times1} \times W_{3\times3}$. I suggest you check out more of the literature on re-parameterization design, such as VanillaNet and DBB.
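For readers following along, here is a minimal sketch of how the fusion rule in the reply above might be applied to the sequential 1x1 -> depthwise 3x3 -> 1x1 stack in RepConv. It is not the authors' code. It assumes all three convolutions are bias-free and it leaves out both BatchNorm layers and the padding, which would have to be folded in separately. The helper name fuse_sequential_convs and the channel sizes are made up for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def fuse_sequential_convs(conv1x1, dw3x3, pw1x1):
    """Collapse 1x1 -> depthwise 3x3 -> 1x1 (all bias-free) into one dense 3x3 kernel.

    Hypothetical helper: BatchNorm and padding are deliberately ignored here.
    """
    w1 = conv1x1.weight[:, :, 0, 0]  # (C, C_in): first channel-mixing matrix
    wd = dw3x3.weight[:, 0]          # (C, 3, 3): one spatial kernel per channel
    w2 = pw1x1.weight[:, :, 0, 0]    # (C_out, C): second channel-mixing matrix
    # K[o, i, h, w] = sum_c w2[o, c] * wd[c, h, w] * w1[c, i]
    return torch.einsum("oc,chw,ci->oihw", w2, wd, w1)


torch.manual_seed(0)
C, O = 4, 6  # illustrative channel counts
conv1x1 = nn.Conv2d(C, C, 1, bias=False)
dw3x3 = nn.Conv2d(C, C, 3, groups=C, bias=False)
pw1x1 = nn.Conv2d(C, O, 1, bias=False)

x = torch.randn(2, C, 8, 8)
with torch.no_grad():
    # Reference: run the three convolutions one after another.
    reference = pw1x1(dw3x3(conv1x1(x)))
    # Fused: a single dense 3x3 convolution with the merged kernel.
    fused_weight = fuse_sequential_convs(conv1x1, dw3x3, pw1x1)
    fused = F.conv2d(x, fused_weight)

# Largest element-wise difference; expected to be at floating-point noise level.
max_abs_diff = (reference - fused).abs().max().item()
print(f"max abs difference: {max_abs_diff:.2e}")
```

The einsum is the matrix product from the reply carried out once per spatial tap of the 3x3 kernel. Because every layer is linear and there is no padding between them, the merged kernel should reproduce the sequential output up to rounding.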

alexandercantrell commented 3 months ago

Thanks for the quick response and the information! After taking another shot at the math I think I finally got it.