RookieJunChen / FullSubNet-plus

The official PyTorch implementation of "FullSubNet+: Channel Attention FullSubNet with Complex Spectrograms for Speech Enhancement".
Apache License 2.0

About the efficiency of the MulCA module #12

Open 747929791 opened 2 years ago

747929791 commented 2 years ago

I am trying to reproduce FullSubNet+ on some speech enhancement datasets. The results are amazing: the noise suppression ability of this method is excellent and very impressive! :star_struck::star_struck::star_struck:

I stumbled across an implementation detail in the paper and code that piqued my curiosity. It concerns the MulCA module in the paper (if I understand correctly, its code is implemented in ChannelTimeSenseSELayer). Three nn.Sequential branches are applied in parallel, and each branch consists of a depthwise Conv1d, an AdaptiveAvgPool1d, and a ReLU. Their outputs are then merged by fully connected layers.

One question that puzzles me: if Conv1d is followed by AdaptiveAvgPool1d, then by the distributive law of multiplication the whole operation seems to reduce to the following simplified form (stride=1):

AvgPool1d(Conv1d(A, weight, bias)) ≈ A.sum(-1) * weight.sum(-1) / (A.shape[-1] - kernel_size + 1) + bias + small_edge_error

(the formula is probably similar for subband_num > 1).
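Written out for a single channel with subband_num = 1 (stride 1, no padding), the identity becomes exact once the edge term is made explicit. The symbols are introduced only for this derivation: x is one channel of A with T frames, w is the kernel of size K, b is the bias, and L is the number of valid convolution outputs.

```latex
L = T - K + 1, \qquad y_t = \sum_{k=0}^{K-1} w_k\, x_{t+k} + b, \quad t = 0, \dots, L-1

\bar{y} = \frac{1}{L}\sum_{t=0}^{L-1} y_t
        = \frac{1}{L}\sum_{k=0}^{K-1} w_k \sum_{j=k}^{k+L-1} x_j + b
        = \frac{\sum_{k} w_k}{L}\sum_{j=0}^{T-1} x_j + b - \varepsilon

\varepsilon = \frac{1}{L}\sum_{k=0}^{K-1} w_k \Big( \sum_{j=0}^{k-1} x_j + \sum_{j=T-K+k+1}^{T-1} x_j \Big)
```

The first term is the simplified form; the correction ε only involves the first and last K - 1 frames, which is why it is small for long inputs.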

We could then treat weight.sum(-1) / (A.shape[-1] - kernel_size + 1) and bias as two learnable float32 parameters per channel, and use a Maxout activation to combine the three convolution branches into a simple per-channel summation.

Were there any special considerations behind implementing MulCA with Conv1d? The simplified implementation looks very similar to ChannelSELayer.

RookieJunChen commented 2 years ago

Thank you for your positive review of FullSubNet+! As I described in my paper, this work was indeed inspired by ChannelSELayer. But we wanted to make more use of the temporal structure of the input, so we used multi-scale 1D convolutions. Here is how I described MulCA during the rebuttal:

MulCA is a channel attention mechanism that treats sub-bands as channels. The main contribution of MulCA is to assign different weights to different sub-bands in the spectrogram, so that the model focuses on the sub-bands that are more discriminative for the noise reduction task. Unlike other channel attention mechanisms, the MulCA module internally uses multi-scale depthwise convolutions to extract features of each frequency bin at different time scales, which serve as the basis for the subsequent computation of channel attention. The convolution is performed independently over each input channel; that is, each convolution kernel extracts temporal features from only one frequency bin. This makes the MulCA module better suited to processing speech signals in the time-frequency domain, with fewer parameters and lower computational cost than other channel attention mechanisms. To the best of our knowledge, we are the first to propose such a channel attention mechanism (i.e. MulCA) and apply it to noise reduction tasks. We believe that this lightweight MulCA module can also be combined with other existing state-of-the-art models to further enhance their speech enhancement capabilities.
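The parameter saving described above comes from the depthwise grouping (groups equal to the number of channels, i.e. subband_num = 1). A minimal sketch comparing a depthwise Conv1d with a standard one; num_channels = 257 is a hypothetical value (the number of frequency bins of a 512-point STFT), not necessarily what FullSubNet+ uses.

```python
import torch.nn as nn


def count_params(module: nn.Module) -> int:
    """Total number of learnable scalars in a module."""
    return sum(p.numel() for p in module.parameters())


num_channels = 257  # hypothetical: frequency bins of a 512-point STFT
kernel_size = 3

# Depthwise: groups == channels, so each output channel sees only its own input channel
depthwise = nn.Conv1d(num_channels, num_channels, kernel_size, groups=num_channels)
# Standard: every output channel mixes all input channels
standard = nn.Conv1d(num_channels, num_channels, kernel_size)

print(count_params(depthwise))  # 257 * 1 * 3 weights + 257 biases = 1028
print(count_params(standard))   # 257 * 257 * 3 weights + 257 biases = 198404
```

The multiply-add count per output frame scales the same way: K per channel for the depthwise version versus C * K for the standard one.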

747929791 commented 2 years ago

Thanks for your reply! I think there might be some confusion here. What confuses me is that placing AvgPool1d right after Conv1d seems to prevent the Conv1d from extracting any temporal features. Perhaps the multi-channel attention mechanism should be implemented in a different way here.
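One way to see this concern: because the pooled output depends on the interior frames only through their sum, shuffling those frames leaves a Conv1d -> AdaptiveAvgPool1d -> ReLU branch unchanged. A small sketch with hypothetical sizes, assuming subband_num = 1 (fully depthwise):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, C, T, K = 2, 4, 50, 5  # hypothetical sizes: batch, channels, frames, kernel size

# One MulCA-style branch (subband_num = 1): depthwise Conv1d -> global average -> ReLU
branch = nn.Sequential(
    nn.Conv1d(C, C, kernel_size=K, groups=C),
    nn.AdaptiveAvgPool1d(1),
    nn.ReLU(),
)

x = torch.randn(B, C, T)

# Shuffle only the interior frames K-1 .. T-K; the first and last K-1 frames stay in place
interior = torch.arange(K - 1, T - K + 1)
shuffled = interior[torch.randperm(len(interior))]
x_shuffled = x.clone()
x_shuffled[:, :, interior] = x[:, :, shuffled]

with torch.no_grad():
    out = branch(x)                    # [B, C, 1]
    out_shuffled = branch(x_shuffled)  # [B, C, 1]

# Same output (up to float rounding): the temporal order of interior frames is ignored
print(torch.allclose(out, out_shuffled, atol=1e-5))  # True
```

Only the K - 1 frames at each edge are weighted differently, which matches the edge correction in the equivalent form.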

Here's a mathematically equivalent version I've implemented:

import torch
import torch.nn as nn


@torch.no_grad()
def Equivalent_Form_MulCA(x: torch.Tensor, module: nn.Module):
    """
    x: (batch, channel, time)
    module: one MulCA branch (smallConv1d / middleConv1d / largeConv1d), e.g.
        self.smallConv1d = nn.Sequential(
            nn.Conv1d(num_channels, num_channels, kernel_size=kersize[0],
                      groups=num_channels // subband_num),
            nn.AdaptiveAvgPool1d(1),  # [B, num_channels, 1]
            nn.ReLU(),
        )
    Assumes subband_num == 1 (fully depthwise), so module[0].weight is [C, 1, kernel_size].
    """
    # The output of the original implementation
    B, C, T = x.size()                              # [batch, channel, time]
    original_result = module(x)                     # [B, C, 1]

    # Use the more simplified equivalent form
    kernel_size = module[0].kernel_size[0]          # 3/5/10
    weight = module[0].weight / (T - kernel_size + 1)  # [C, 1, kernel_size]
    bias = module[0].bias.reshape(1, -1, 1)         # [1, C, 1]
    weight_sum = weight.sum(-1).unsqueeze(0)        # [1, C, 1]
    approximate_result = x.sum(-1, keepdim=True) * weight_sum + bias  # [B, C, 1]

    # Correct the small error between approximate_result and original_result.
    # This step seems to be lacking in motivation, as it seems that the model
    # should not focus on the few frames on either side of the spectrogram.
    approximation_error = torch.zeros(B, C, dtype=x.dtype, device=x.device)  # [B, C]
    for k in range(kernel_size):
        # Frames never reached by tap k: the first k and the last (kernel_size - 1 - k)
        approximation_error += weight[:, 0, k].unsqueeze(0) * (
            x[:, :, :k].sum(-1) + x[:, :, T - kernel_size + k + 1:].sum(-1)
        )
    simplified_result = approximate_result - approximation_error.unsqueeze(-1)  # [B, C, 1]
    simplified_result = simplified_result.clamp(min=0)  # ReLU

    print("error ratio:", (approximation_error**2).sum() / (simplified_result**2).sum())  # ~ 0.001

    # Check that original_result and simplified_result are the same
    assert torch.allclose(simplified_result, original_result, atol=1e-3, rtol=1e-3)

It seems that ChannelTimeSenseSELayer (the MulCA module in the paper) can be implemented in a simplified, nearly mathematically equivalent way (ignoring the differences in effective learning rates between the original and the reparameterized parameters):

# Original version of MulCA
class ChannelTimeSenseSELayer(nn.Module):
    def __init__(self, num_channels, reduction_ratio=2, kersize=[3, 5, 10], subband_num=1):
        super().__init__()
        self.smallConv1d = nn.Sequential(
            nn.Conv1d(num_channels, num_channels, kernel_size=kersize[0], groups=num_channels // subband_num),
            nn.AdaptiveAvgPool1d(1),  # [B, num_channels, 1]
            nn.ReLU(),
        )
        self.middleConv1d = nn.Sequential(
            nn.Conv1d(num_channels, num_channels, kernel_size=kersize[1], groups=num_channels // subband_num),
            nn.AdaptiveAvgPool1d(1),  # [B, num_channels, 1]
            nn.ReLU(),
        )
        self.largeConv1d = nn.Sequential(
            nn.Conv1d(num_channels, num_channels, kernel_size=kersize[2], groups=num_channels // subband_num),
            nn.AdaptiveAvgPool1d(1),  # [B, num_channels, 1]
            nn.ReLU(),
        )
        # ... (remaining layers omitted)

    def forward(self, input_tensor):
        # Extracting multi-scale information in the time dimension
        small_feature = self.smallConv1d(input_tensor)
        middle_feature = self.middleConv1d(input_tensor)
        large_feature = self.largeConv1d(input_tensor)

        feature = torch.cat([small_feature, middle_feature, large_feature], dim=2)  # [B, num_channels, 3]
        # ... (rest of forward omitted)


# Simplified version of MulCA
class Simplified_ChannelTimeSenseSELayer(nn.Module):
    def __init__(self, num_channels, reduction_ratio=2, kersize=[3, 5, 10], subband_num=1):
        super().__init__()
        # One scale and one bias per channel for each of the three time scales
        self.weight_sum = torch.nn.parameter.Parameter(torch.ones(1, num_channels, 3))
        self.bias = torch.nn.parameter.Parameter(torch.zeros(1, num_channels, 3))
        self.relu = nn.ReLU()
        # ... (remaining layers same as in ChannelTimeSenseSELayer, omitted)

    def forward(self, input_tensor):
        # Extracting multi-scale information in the time dimension
        feature = self.relu(input_tensor.mean(-1, keepdim=True) * self.weight_sum + self.bias)  # [B, num_channels, 3]
        # ... (rest of forward same as in ChannelTimeSenseSELayer, omitted)

That's why I said the simplified implementation is very similar to ChannelSELayer: the MulCA module seems to add only one FC1x3 -> ReLU -> FC3x1 step on top of ChannelSELayer.
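To make that comparison concrete, here is a hypothetical end-to-end sketch of the simplified layer. The squeeze is ChannelSELayer's time average followed by the extra FC1x3 -> ReLU -> FC3x1 step; the excitation part (FC -> ReLU -> FC -> sigmoid, then per-channel rescaling) is an assumption following the usual squeeze-and-excitation pattern, and the class name SimplifiedMulCA is made up for illustration.

```python
import torch
import torch.nn as nn


class SimplifiedMulCA(nn.Module):
    """Hypothetical sketch: SE-style channel attention with the extra
    FC1x3 -> ReLU -> FC3x1 step in the squeeze (excitation assumed SE-style)."""

    def __init__(self, num_channels: int, reduction_ratio: int = 2):
        super().__init__()
        # FC1x3: one scale and one bias per channel for each of the 3 time scales
        self.weight_sum = nn.Parameter(torch.ones(1, num_channels, 3))
        self.bias = nn.Parameter(torch.zeros(1, num_channels, 3))
        # FC3x1: merge the three features of each channel into one value
        self.merge = nn.Linear(3, 1)
        # Assumed SE-style excitation
        self.fc1 = nn.Linear(num_channels, num_channels // reduction_ratio)
        self.fc2 = nn.Linear(num_channels // reduction_ratio, num_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [B, C, T]
        squeezed = x.mean(-1, keepdim=True)                           # [B, C, 1], ChannelSELayer's squeeze
        feature = torch.relu(squeezed * self.weight_sum + self.bias)  # [B, C, 3]
        z = self.merge(feature).squeeze(-1)                           # [B, C]
        attention = torch.sigmoid(self.fc2(torch.relu(self.fc1(z))))  # [B, C], values in (0, 1)
        return x * attention.unsqueeze(-1)                            # [B, C, T]


layer = SimplifiedMulCA(num_channels=8)
x = torch.randn(2, 8, 100)
y = layer(x)
print(y.shape)  # torch.Size([2, 8, 100])
```

Dropping weight_sum, bias, and merge (using z = squeezed.squeeze(-1) directly) gives back a plain SE-style layer, which is the sense in which the two are "very similar".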

RookieJunChen commented 2 years ago

Thanks for your detailed comments!

> Thanks for your reply! I think there might be some confusion here. What confuses me is that placing AvgPool1d right after Conv1d seems to prevent the Conv1d from extracting any temporal features. Perhaps the multi-channel attention mechanism should be implemented in a different way here.

On this point: I did compare the two in my experiments, and using the multi-scale Conv1d still improved performance over the version without it, so I think this part does play some role.

747929791 commented 2 years ago

Indeed, this change makes performance a bit better. I'm just curious about its underlying mechanism: whether it acts more like a regularization strategy or a data augmentation strategy, and whether there are effective alternatives or room for optimization in terms of storage footprint and computational load.

Thank you for your reply! :grinning:
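For the storage and compute question, a rough count of the multi-scale front end only (the later FC layers are the same in both versions). This sketch assumes subband_num = 1 and kersize = (3, 5, 10); the channel and frame counts are hypothetical, and the original count covers the convolution multiply-adds only, not the pooling.

```python
import torch.nn as nn


def mulca_branch_params(num_channels: int, kersize=(3, 5, 10)) -> int:
    """Parameters of the three depthwise Conv1d branches (subband_num = 1)."""
    convs = [nn.Conv1d(num_channels, num_channels, k, groups=num_channels) for k in kersize]
    return sum(p.numel() for conv in convs for p in conv.parameters())


def simplified_params(num_channels: int, num_branches: int = 3) -> int:
    """One scale (weight_sum) and one bias per channel per branch."""
    return 2 * num_branches * num_channels


def mulca_branch_macs(num_channels: int, num_frames: int, kersize=(3, 5, 10)) -> int:
    """Multiply-adds of the depthwise convolutions (stride 1, no padding)."""
    return sum(num_channels * k * (num_frames - k + 1) for k in kersize)


def simplified_ops(num_channels: int, num_frames: int, num_branches: int = 3) -> int:
    """One time average shared by all branches, then one multiply-add per channel per branch."""
    return num_channels * num_frames + num_branches * num_channels


C, T = 257, 100  # hypothetical: 257 frequency bins, 100 frames
print(mulca_branch_params(C), simplified_params(C))   # 5397 1542
print(mulca_branch_macs(C, T), simplified_ops(C, T))  # 432788 26471
```

The convolution cost grows with K * T per channel and branch, while the simplified form needs only one pass over T per channel, shared by all three branches.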

RookieJunChen commented 2 years ago

I didn't think about this as deeply as you have when I did this work, so I haven't tried what you mentioned either. Your idea is really valuable; maybe you could try this comparative experiment?