ioyy900205 / MFNet

This repo provides the processed samples for the manuscript "A Mask Free Neural Network for Monaural Speech Enhancement", which was accepted at INTERSPEECH 2023.
MIT License

Will the model be released in the future? #1

Open wl-junlin opened 1 year ago

ioyy900205 commented 1 year ago

Thank you for the question raised on GitHub. I would like to inform you that I will not be sharing my code publicly. However, if you have any questions or concerns regarding my paper or the reproducibility of my work, please feel free to bring them up.

yudashuixiao1 commented 1 year ago

Do you have the noisy inputs corresponding to enhanced_DNS? Could you post them?

ioyy900205 commented 1 year ago

Do you have the noisy inputs corresponding to enhanced_DNS? Could you post them?

Hello, the corresponding noisy input speech is here: https://github.com/microsoft/DNS-Challenge/tree/interspeech2020/master/datasets/test_set/synthetic/no_reverb/noisy

Janvia commented 1 year ago

Is the Gate a conv with a 1x1 kernel, and is the Pixelshuffle a 3x3 conv? The model I reimplemented comes out with half the MACs.

ioyy900205 commented 1 year ago

Is the Gate a conv with a 1x1 kernel, and is the Pixelshuffle a 3x3 conv? The model I reimplemented comes out with half the MACs.

1. Gate:

class SimpleGate(nn.Module):
    def forward(self, x):
        x1, x2 = x.chunk(2, dim=1)
        return x1 * x2

2. PixelShuffle:

nn.PixelShuffle(2)

ioyy900205 commented 1 year ago

The model I reimplemented comes out with half the MACs.

Hello, I need to know your input features and how you built the model before I can tell where the mismatch is.

Janvia commented 1 year ago

My model's input size is (1, 1, 800, 320), and when computing MACs I divide by 8 (the 800 frames correspond to about 8 s of audio), which gives 3.6 GMACs/s. The downsampling is a 2x2 conv, the proj is a 3x3 conv, the gate uses a 1x1 conv, and the pixelshuffle uses a 3x3 conv that also expands the channels 4x. In the GLFB, the point conv is a 1x1 conv, the dwconv is a 3x3 grouped conv, and the GLFB keeps its input and output channel counts equal. The overall channel progression of the network is 16, 32, 64, 128, 256, 128, 64, 32, 16. One more question: if the pixelshuffle has no convolution, the channel count drops 4x, so how do the channels drop only 2x during upsampling?

ioyy900205 commented 1 year ago

My model's input size is (1, 1, 800, 320), and when computing MACs I divide by 8, which gives 3.6 GMACs/s. The downsampling is a 2x2 conv, the proj is a 3x3 conv, the gate uses a 1x1 conv, and the pixelshuffle uses a 3x3 conv that also expands the channels 4x. In the GLFB, the point conv is a 1x1 conv, the dwconv is a 3x3 grouped conv, and the GLFB keeps its input and output channel counts equal. The overall channel progression of the network is 16, 32, 64, 128, 256, 128, 64, 32, 16. One more question: if the pixelshuffle has no convolution, the channel count drops 4x, so how do the channels drop only 2x during upsampling?

Sorry, the paper didn't make this clear. The actual upsampling operation is nn.Conv2d(channel, channel*2, 1) + nn.PixelShuffle(2), and the skip connections use a direct add.

I think the only difference from yours is that you expand the channels 4x with a 3x3 conv, whereas I expand them 2x with a 1x1 conv.

Note also that in my FFN, the first step is a conv1x1 that doubles the channel count, the second step is a simplegate, and the third step is a conv1x1 that restores the initial channel count.
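
A minimal sketch of those two pieces as just described (the helper names are illustrative, not from the released code):

import torch
import torch.nn as nn

class SimpleGate(nn.Module):
    def forward(self, x):
        x1, x2 = x.chunk(2, dim=1)
        return x1 * x2

# Upsampling as described: a 1x1 conv doubles the channels, then PixelShuffle(2)
# trades 4x channels for 2x height/width, so the channel count halves overall.
def make_upsample(channel):
    return nn.Sequential(
        nn.Conv2d(channel, channel * 2, 1),  # C -> 2C
        nn.PixelShuffle(2),                  # 2C -> C/2, H and W doubled
    )

# FFN as described: a 1x1 conv expands 2x, the simple gate halves the channels
# again, and a final 1x1 conv maps back to the initial channel count.
def make_ffn(channel):
    return nn.Sequential(
        nn.Conv2d(channel, channel * 2, 1),  # C -> 2C
        SimpleGate(),                        # 2C -> C
        nn.Conv2d(channel, channel, 1),      # C -> C
    )

x = torch.randn(1, 64, 100, 40)
print(make_upsample(64)(x).shape)  # torch.Size([1, 32, 200, 80])
print(make_ffn(64)(x).shape)       # torch.Size([1, 64, 100, 40])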

Janvia commented 1 year ago

After removing the Gate's 1x1 conv and the pixelshuffle's 3x3 conv, the MACs/s comes out to only 1.42, versus the 6.09 in your paper. Could you take a look at my code directly?

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

class hswish(nn.Module):
    def forward(self, x):
        out = x * F.relu6(x + 3, inplace=True) / 6
        return out

class hsigmoid(nn.Module):
    def forward(self, x):
        out = F.relu6(x + 3, inplace=True) / 6
        return out

class SeModule(nn.Module):
    def __init__(self, in_size, reduction=4):
        super(SeModule, self).__init__()
        expand_size = max(in_size // reduction, 8)
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_size, expand_size, kernel_size=1, bias=False),
            nn.BatchNorm2d(expand_size),
            nn.ReLU(inplace=True),
            nn.Conv2d(expand_size, in_size, kernel_size=1, bias=False),
            nn.Hardsigmoid()
        )

    def forward(self, x):
        return x * self.se(x)

class GateConv2d(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, kernel_size: tuple,
                 stride: tuple, pad=False, use_hsigmoid=False):
        super(GateConv2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        k_t = kernel_size[0]
        k_f = kernel_size[1]
        if pad:
            if k_f == 3:
                pad_f = 1
            elif k_f == 5:
                pad_f = 2
            else:
                pad_f = 0
        else:
            pad_f = 0
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels * 2,
                              kernel_size=kernel_size, stride=stride, padding=(pad_f, pad_f))
        self.sigmoid = hsigmoid() if use_hsigmoid else nn.Sigmoid()

    def forward(self, inputs: Tensor) -> Tensor:
        if inputs.dim() == 3:
            inputs = inputs.unsqueeze(dim=1)
        x = self.conv(inputs)
        outputs, gate = x.chunk(2, dim=1)
        return outputs * self.sigmoid(gate)

class Gate(nn.Module):
    def __init__(self, C, use_hsigmoid=False):
        super(Gate, self).__init__()
        self.sigmoid = hsigmoid() if use_hsigmoid else nn.Sigmoid()

    def forward(self, x: Tensor) -> Tensor:
        outputs, gate = x.chunk(2, dim=1)
        return outputs * self.sigmoid(gate)

class GLFB(nn.Module):
    '''expand + depthwise + pointwise'''

    def __init__(self, kernel_size, C, se, stride):
        super(GLFB, self).__init__()
        self.stride = stride

        self.norm = nn.InstanceNorm2d(C)
        self.pconv1 = nn.Conv2d(C, 2 * C, kernel_size=1, bias=False)
        self.dconv1 = nn.Conv2d(2 * C, 2 * C, kernel_size=kernel_size, stride=stride,
                                padding=kernel_size // 2, groups=2 * C, bias=False)
        self.gate = Gate(2 * C)
        self.se = SeModule(C) if se else nn.Identity()
        self.pconv2 = nn.Conv2d(C, C, kernel_size=1, bias=False)

        self.skip = nn.Sequential(
            nn.InstanceNorm2d(C),
            nn.Conv2d(in_channels=C, out_channels=2 * C, kernel_size=1, bias=False),
            Gate(2 * C),
            nn.Conv2d(C, C, kernel_size=1, bias=False),
        )

    def forward(self, x):
        skip = x
        out = self.pconv1(self.norm(x))
        out = self.gate(self.dconv1(out))
        out = self.se(out)
        out = self.pconv2(out)
        out = out + skip
        if self.skip is not None:
            skip = self.skip(skip)
        return out + skip

class PixelShuffleBlock(nn.Module):
    def __init__(self, in_channel, out_channel, upscale_factor=(2, 2), kernel=(1, 1),
                 stride=1, padding=1):
        super(PixelShuffleBlock, self).__init__()
        expand = (upscale_factor[0] * upscale_factor[1]
                  if isinstance(upscale_factor, tuple) else upscale_factor ** 2)
        self.conv = nn.Conv2d(in_channel, out_channel * expand, kernel, stride,
                              padding=padding, bias=False)
        self.ps = nn.PixelShuffle(upscale_factor[0])
        self.expand = expand
        self.upscale_factor = upscale_factor
        self.out_channel = out_channel

    def forward(self, x):
        b, c, t, f = x.shape
        out = self.conv(x)
        out = self.ps(out)
        return out

class DownBlock(nn.Module):
    def __init__(self, in_dim, C=64, depth=1):
        super(DownBlock, self).__init__()
        self.conv = nn.Conv2d(in_dim, C, (2, 2), stride=(2, 2), padding=0, bias=False, groups=1)
        self.glfb = nn.Sequential(*[GLFB(3, C, se=True, stride=1) for _ in range(depth)])

    def forward(self, x):
        out = self.conv(x)
        out = self.glfb(out)
        return out

class UpBlock(nn.Module):
    def __init__(self, in_dim, C=64, kenal_size=(1, 1), depth=1):
        super(UpBlock, self).__init__()
        self.conv = PixelShuffleBlock(in_channel=in_dim, out_channel=C, kernel=kenal_size,
                                      upscale_factor=(2, 2), stride=1, padding=0)
        self.glfb = nn.Sequential(*[GLFB(3, C, se=True, stride=1) for _ in range(depth)])

    def forward(self, x, skip):
        out = self.conv(x)
        out = out + skip
        out = self.glfb(out)
        return out

class Project(nn.Module):
    def __init__(self, in_dim, C=64, kenal_size=(3, 3)):
        super(Project, self).__init__()
        self.conv = nn.Conv2d(in_dim, C, kenal_size, padding=1)

    def forward(self, x):
        conv_out = self.conv(x)
        return conv_out

class MF_Net(nn.Module):

    def __init__(self, fft_len=320, C=16, causal=False, export=False):
        super(MF_Net, self).__init__()
        self.causal = causal

        self.first_block = nn.Sequential(Project(1, C),
                                         GLFB(3, C, se=True, stride=1))
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        self.export = export

        num_layers = 4
        depths = [1, 1, 8, 4]
        channels = [C * 2, C * 4, C * 8, C * 16]

        for i in range(num_layers):
            self.encoder.append(DownBlock(channels[i] // 2, channels[i], depth=depths[i]))

        m = 6
        self.middle_block = nn.Sequential(*[GLFB(3, channels[-1], se=True, stride=1) for i in range(m)])

        us = [1, 1, 1, 1]
        for i in range(num_layers, 0, -1):
            self.decoder.append(UpBlock(channels[i - 1], channels[i - 1] // 2, depth=us[i - 1]))

        self.last_block = nn.Sequential(Project(C, 1))

    def forward(self, inputs):
        """
        inputs: B x F x T
        """
        outs = self.first_block(inputs)
        encoder_out = [outs]
        for i in range(len(self.encoder)):
            outs = self.encoder[i](outs)
            print("enc:", outs.shape)
            encoder_out.append(outs)

        outs = self.middle_block(outs)

        for i in range(len(self.decoder)):
            outs = self.decoder[i](outs, encoder_out[-2 - i])
            print("dec", outs.shape)

        outs = self.last_block(outs)
        return outs

if __name__ == "__main__":
    model = MF_Net()
    x = torch.randn([1, 1, 100 * 8, 320])
    out = model(x)
    print(out.shape)

    import re

    from ptflops import get_model_complexity_info

    # Model that's already available
    macs, params = get_model_complexity_info(model, tuple(x.shape[1:]), as_strings=True,
                                             print_per_layer_stat=False, verbose=True)
    # Extract the numerical value
    flops = eval(re.findall(r'([\d.]+)', macs)[0]) * 2
    # Extract the unit
    flops_unit = re.findall(r'([A-Za-z]+)', macs)[0][0]

    print('Computational complexity: {:<8}'.format(macs))
    print('Computational complexity: {} {}Flops'.format(flops, flops_unit))
    print('Number of parameters: {:<8}'.format(params))
    print("MACs/s", float(str(macs).split(' ')[0]) / 8)
ioyy900205 commented 1 year ago

After removing the Gate's 1x1 conv and the pixelshuffle's 3x3 conv, the MACs/s comes out to only 1.42, versus the 6.09 in your paper. Could you take a look at my code directly?
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

class hswish(nn.Module):
    def forward(self, x):
        out = x * F.relu6(x + 3, inplace=True) / 6
        return out

class hsigmoid(nn.Module):
    def forward(self, x):
        out = F.relu6(x + 3, inplace=True) / 6
        return out

class SeModule(nn.Module):
    def __init__(self, in_size, reduction=4):
        super(SeModule, self).__init__()
        expand_size = max(in_size // reduction, 8)
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_size, in_size, kernel_size=1, bias=False),
        )

    def forward(self, x):
        return x * self.se(x)

class Gate(nn.Module):
    def __init__(self, C, use_hsigmoid=False):
        super(Gate, self).__init__()

        # self.conv = nn.Conv2d(in_channels=C, out_channels=C, kernel_size=1)
        self.sigmoid = hsigmoid() if use_hsigmoid else nn.Sigmoid()

    def forward(self, x: Tensor) -> Tensor:
        # x = self.conv(x)
        outputs, gate = x.chunk(2, dim=1)
        return outputs * self.sigmoid(gate)

class GLFB(nn.Module):

    """expand + depthwise + pointwise"""

    def __init__(self, kernel_size, C, se, stride):
        super(GLFB, self).__init__()
        self.stride = stride

        self.norm = nn.InstanceNorm2d(C)
        self.pconv1 = nn.Conv2d(C, 2 * C, kernel_size=1, bias=True)

        self.dconv1 = nn.Conv2d(
            2 * C,
            2 * C,
            kernel_size=kernel_size,
            stride=stride,
            padding=kernel_size // 2,
            groups=2 * C,
            bias=False,
        )
        self.gate = Gate(2 * C)
        self.se = SeModule(C) if se else nn.Identity()

        self.pconv2 = nn.Conv2d(C, C, kernel_size=1, bias=False)

        self.beta = nn.Parameter(torch.zeros((1, C, 1, 1)), requires_grad=True)
        self.gamma = nn.Parameter(torch.zeros((1, C, 1, 1)), requires_grad=True)

        self.skip = nn.Sequential(
            nn.InstanceNorm2d(C),
            nn.Conv2d(in_channels=C, out_channels=2 * C, kernel_size=1, bias=False),
            # GateConv2d(2*C,C,(3,3),(1,1),pad=True),
            Gate(2 * C),
            nn.Conv2d(C, C, kernel_size=1, bias=False),
        )

    def forward(self, x):
        skip = x
        # print(x.shape)
        out = self.pconv1(self.norm(x))
        out = self.gate(self.dconv1(out))
        out = self.se(out)
        out = self.pconv2(out)
        out = out * self.beta + skip
        if self.skip is not None:
            skip = self.skip(skip)
        return out + skip * self.gamma

class PixelShuffleBlock(nn.Module):
    def __init__(
        self,
        in_channel,
        out_channel,
        upscale_factor=(2, 2),
        kernel=(1, 1),
        stride=1,
        padding=1,
    ):
        super(PixelShuffleBlock, self).__init__()
        expand = (
            upscale_factor[0] * upscale_factor[1]
            if isinstance(upscale_factor, tuple)
            else upscale_factor**2
        )
        # print(kernel)
        self.conv = nn.Conv2d(
            in_channel,
            out_channel * expand,
            kernel,
            stride,
            padding=padding,
            bias=False,
        )
        self.ps = nn.PixelShuffle(upscale_factor[0])
        self.expand = expand
        self.upscale_factor = upscale_factor
        self.out_channel = out_channel

    def forward(self, x):
        # print("in",x.shape)
        b, c, t, f = x.shape
        out = self.conv(x)
        out = self.ps(out)
        return out

class DownBlock(nn.Module):
    def __init__(self, in_dim, C=64, depth=1):
        super(DownBlock, self).__init__()
        self.conv = nn.Conv2d(
            in_dim, C, (2, 2), stride=(2, 2), padding=0, bias=False, groups=1
        )
        self.glfb = nn.Sequential(
            *[GLFB(3, C, se=True, stride=1) for _ in range(depth)]
        )

    def forward(self, x):
        out = self.conv(x)
        out = self.glfb(out)
        return out

class UpBlock(nn.Module):
    def __init__(self, in_dim, C=64, kenal_size=(1, 1), depth=1):
        super(UpBlock, self).__init__()
        self.conv = PixelShuffleBlock(
            in_channel=in_dim,
            out_channel=C,
            kernel=kenal_size,
            upscale_factor=(2, 2),
            stride=1,
            padding=0,
        )

        self.glfb = nn.Sequential(
            *[GLFB(3, C, se=True, stride=1) for _ in range(depth)]
        )

    def forward(self, x, skip):
        out = self.conv(x)
        out = out + skip
        out = self.glfb(out)
        return out

class Projection(nn.Module):
    def __init__(self, in_dim=1, C=64, kenal_size=(3, 3)):
        super(Projection, self).__init__()
        self.conv = nn.Conv2d(in_dim, C, kenal_size, padding=1)

    def forward(self, x):
        conv_out = self.conv(x)
        return conv_out

class MF_Net(nn.Module):
    def __init__(self, fft_len=320, C=32, causal=False, export=False):
        super(MF_Net, self).__init__()
        self.causal = causal
        self.first_block = nn.Sequential(Projection(C=C), GLFB(3, C, se=True, stride=1))
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        self.export = export
        num_layers = 4
        depths = [1, 1, 8, 4]
        channels = [C * 2, C * 4, C * 8, C * 16]

        for i in range(num_layers):
            self.encoder.append(
                DownBlock(channels[i] // 2, channels[i], depth=depths[i])
            )

        m = 6
        self.middle_block = nn.Sequential(
            *[GLFB(3, channels[-1], se=True, stride=1) for i in range(m)]
        )

        us = [1, 1, 1, 1]
        for i in range(num_layers, 0, -1):
            self.decoder.append(
                UpBlock(channels[i - 1], channels[i - 1] // 2, depth=us[i - 1])
            )

        self.last_block = nn.Sequential(Projection(C, 1))

    def forward(self, inputs):
        """
        inputs: B x F x T
        """
        outs = self.first_block(inputs)
        # print(outs.shape)
        encoder_out = [outs]
        for i in range(len(self.encoder)):
            outs = self.encoder[i](outs)
            print("enc:", outs.shape)
            encoder_out.append(outs)

        outs = self.middle_block(outs)

        for i in range(len(self.decoder)):
            # print(outs.shape , encoder_out[-1 - i].shape)
            outs = self.decoder[i](outs, encoder_out[-2 - i])
            #
            print("dec", outs.shape)

        outs = self.last_block(outs)
        return outs

if __name__ == "__main__":
    model = MF_Net()
    x = torch.randn([1, 1, 100 * 8, 320])
    out = model(x)
    print(out.shape)

    import re

    from ptflops import get_model_complexity_info

    # Model thats already available
    macs, params = get_model_complexity_info(
        model,
        tuple(x.shape[1:]),
        as_strings=True,
        print_per_layer_stat=False,
        verbose=True,
    )
    # Extract the numerical value
    flops = eval(re.findall(r"([\d.]+)", macs)[0]) * 2
    # Extract the unit
    flops_unit = re.findall(r"([A-Za-z]+)", macs)[0][0]

    print("Computational complexity: {:<8}".format(macs))
    print("Computational complexity: {} {}Flops".format(flops, flops_unit))
    print("Number of parameters: {:<8}".format(params))
    print("MACs/s", float(str(macs).split(" ")[0] / 8.0))

I've made some changes on my side; what still needs to be modified: 1. the normalization differs from mine, and 2. the dw conv needs to use dilation, with dilation = num_layer**2.

With that, the computational cost is roughly right (ps: with thop I measure this model at about 6.08 GMACs/s, while ptflops reports a bit less, 5.58 vs. 6.08 GMACs/s). I think you can iterate on this basis and see how the results look.

Janvia commented 1 year ago

Got it, thank you for the guidance!

ioyy900205 commented 1 year ago

Got it, thank you for the guidance!

No problem, feel free to reach out any time if you have questions!

heeyounMoon commented 10 months ago

Hello. I am currently implementing your MFNet on the VCTK dataset. I think I implemented the same model structure as described in the paper, but the performance is not good. I recently saw this issue and added dilation to the dw conv, but the result was the same. Could you take a look at my model code and tell me what's wrong?

import torch
import torch.nn as nn
import torch.nn.functional as F

def down_sampling(ch_in):
    return (
        nn.Sequential(
            nn.Conv2d(ch_in, ch_in*2, kernel_size=2, stride=2, bias=False),
        )
    )

def up_sampling(ch_in):
    return (
        nn.Sequential(
            nn.Conv2d(ch_in, ch_in*2, 1, bias=False),  # [B, C*2, H, W]
            nn.PixelShuffle(2),     # [B, C/2, H*2, W*2]
        )
    )

class LayerNormChannel(nn.Module):
    """
    Metaformer Layer Norm : https://github.com/sail-sg/poolformer/blob/main/models/poolformer.py#L86
    LayerNorm only for Channel Dimension.
    Input: tensor in shape [B, C, H, W]
    """
    def __init__(self, num_channels, eps=1e-05):

        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))    # Learnable
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x):
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        x = self.weight.unsqueeze(-1).unsqueeze(-1) * x + self.bias.unsqueeze(-1).unsqueeze(-1)
        return x

class SimpleGate(nn.Module):
    def forward(self, x):
        x1, x2 = x.chunk(2, dim=1)
        return x1 * x2

class GLFB(nn.Module):
    def __init__(self, ch_in):
        super(GLFB, self).__init__()
        # [B, C, H, W]
        self.layernorm_1 = LayerNormChannel(ch_in)

        self.pw_1 = nn.Conv2d(ch_in, ch_in*2, 1, 1, 0, bias=False)
        self.dw = nn.Conv2d(ch_in*2, ch_in*2, 3, 1, padding=16, dilation=16, groups=ch_in*2, bias=False)
        self.gate_1 = SimpleGate()

        self.ch_attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(ch_in, ch_in, 1, 1, 0, bias=False),
        )

        self.pw_2 = nn.Conv2d(ch_in, ch_in, 1, 1, 0, bias=False)

        self.beta = nn.Parameter(torch.zeros((1, ch_in, 1, 1)), requires_grad=True)
        self.gamma = nn.Parameter(torch.zeros((1, ch_in, 1, 1)), requires_grad=True)

        self.layernorm_2 = LayerNormChannel(ch_in)
        self.pw_3 = nn.Conv2d(ch_in, ch_in*2, 1, 1, 0, bias=False)
        self.gate_2 = SimpleGate()
        self.pw_4 = nn.Conv2d(ch_in, ch_in, 1, 1, 0, bias=False)

    def forward(self, inp):
        # [B, C, H, W]
        # [B, 1, 256, T]
        x = self.layernorm_1(inp)                   # Layer Norm
        x = self.pw_1(x)                            # Point Conv
        x = self.dw(x)                              # DW Conv
        x = self.gate_1(x)                          # Gate
        x = x * self.ch_attention(x)                # Channel Attention
        x = self.pw_2(x)                            # Point Conv

        y = inp + x                                 # Add
        # y = inp + x * self.beta                     # Add

        x = self.layernorm_2(y)                     # Layer Norm
        x = self.pw_3(x)                            # Point Conv
        x = self.gate_2(x)                          # Gate
        x = self.pw_4(x)                            # Point Conv

        return x + y                                # Add
        # return y + x * self.gamma                  # Add

class NS(nn.Module):
    def __init__(self, model_ch=16, mid_glfb=6, down_glfb=[1, 8, 4], up_glfb=[1, 1, 1]):
        super(NS, self).__init__()
        # projection
        self.in_proj = nn.Conv2d(1, model_ch, 3, 1, 1)

        self.in_glfb = GLFB(model_ch)

        self.encoders = nn.ModuleList()
        self.decoders = nn.ModuleList()
        self.middle_blks = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()

        ch = model_ch
        for num in down_glfb:
            self.downs.append(down_sampling(ch))
            ch = ch * 2
            self.encoders.append(
                nn.Sequential(
                    *[GLFB(ch) for _ in range(num)]
                )
            )

        self.mid_down = down_sampling(ch)
        ch = ch * 2
        self.middle_blks = \
            nn.Sequential(
                *[GLFB(ch) for _ in range(mid_glfb)]
            )
        self.mid_up = up_sampling(ch)
        ch = ch // 2

        for num in up_glfb:
            self.decoders.append(
                nn.Sequential(
                    *[GLFB(ch) for _ in range(num)]
                )
            )
            self.ups.append(up_sampling(ch))
            ch = ch // 2

        assert ch == model_ch
        self.out_glfb = GLFB(model_ch)

        # projection
        self.out_proj = nn.Conv2d(model_ch, 1, 3, 1, 1)

    def forward(self, x):
        # x [B, C(1), H(320), W(T)]
        if x.dim() == 3:
            x = x.unsqueeze(1)
        # zero padding for down/up sampling
        ori_shape = x.shape[-1]     # T
        x = F.pad(x, (0, 16 - (x.shape[-1] % 16)))    # [B, 1, 320, T_pad]

        x = self.in_proj(x)     # [B, 16, H, W]
        x = self.in_glfb(x)
        skip = x

        encs = []
        for encoder, down in zip(self.encoders, self.downs):
            x = down(x)
            x = encoder(x)
            encs.append(x)  # [3]

        x = self.mid_down(x)
        x = self.middle_blks(x)
        x = self.mid_up(x)

        for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]):
            x = x + enc_skip
            x = decoder(x)
            x = up(x)

        x = x + skip
        x = self.out_glfb(x)
        x = self.out_proj(x)    # [B, 1, H*, W*]

        return x.squeeze(1)[:, :, :ori_shape]
ioyy900205 commented 10 months ago

Hello. I am currently implementing your MFNet on the VCTK dataset. I think I implemented the same model structure as described in the paper, but the performance is not good. Could you take a look at my model code and tell me what's wrong?

According to the code you provided, I made some modifications after reviewing it. The model parameters now match exactly: 15,989,857. In terms of computational load, my original code was 6.09053G, and after the modifications, it is now 6.03093G, which is a reduction of approximately 60M overall. However, I have checked the logic, and it is essentially consistent. You can test this result, keeping in mind that the input should be "BCTF". I am also interested in knowing the results of your retesting on VCTK.


import torch
import torch.nn as nn
import torch.nn.functional as F

def down_sampling(ch_in):
    return (
        nn.Sequential(
            nn.Conv2d(ch_in, ch_in*2, kernel_size=2, stride=2, bias=True),
        )
    )

def up_sampling(ch_in):
    return (
        nn.Sequential(
            nn.Conv2d(ch_in, ch_in*2, 1, bias=False),  # [B, C*2, H, W]
            nn.PixelShuffle(2),     # [B, C/2, H*2, W*2]
        )
    )

class LayerNormChannel(nn.Module):
    """
    Metaformer Layer Norm : https://github.com/sail-sg/poolformer/blob/main/models/poolformer.py#L86
    LayerNorm only for Channel Dimension.
    Input: tensor in shape [B, C, H, W]
    """
    def __init__(self, num_channels, eps=1e-05):

        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))    # Learnable
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x):
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        x = self.weight.unsqueeze(-1).unsqueeze(-1) * x + self.bias.unsqueeze(-1).unsqueeze(-1)
        return x

class SimpleGate(nn.Module):
    def forward(self, x):
        x1, x2 = x.chunk(2, dim=1)
        return x1 * x2

class GLFB(nn.Module):
    def __init__(self, ch_in, dilation=1):
        super(GLFB, self).__init__()
        # [B, C, H, W]
        self.layernorm_1 = LayerNormChannel(ch_in)

        self.pw_1 = nn.Conv2d(ch_in, ch_in*2, 1, 1, 0, bias=True)
        padding_num = (3 - 1) * dilation
        padding = (padding_num // 2, padding_num - padding_num // 2,
                   padding_num // 2, padding_num - padding_num // 2)
        self.pad = nn.ZeroPad2d(padding)
        self.dw = nn.Conv2d(ch_in*2, ch_in*2, kernel_size=3, stride=1, padding=0, dilation=dilation, groups=ch_in*2, bias=True)
        self.gate_1 = SimpleGate()

        self.ch_attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(ch_in, ch_in, 1, 1, 0, bias=True),
        )

        self.pw_2 = nn.Conv2d(ch_in, ch_in, 1, 1, 0, bias=True)

        self.beta = nn.Parameter(torch.zeros((1, ch_in, 1, 1)), requires_grad=True)
        self.gamma = nn.Parameter(torch.zeros((1, ch_in, 1, 1)), requires_grad=True)

        self.layernorm_2 = LayerNormChannel(ch_in)
        self.pw_3 = nn.Conv2d(ch_in, ch_in*2, 1, 1, 0, bias=True)
        self.gate_2 = SimpleGate()
        self.pw_4 = nn.Conv2d(ch_in, ch_in, 1, 1, 0, bias=True)

    def forward(self, inp):
        # [B, C, H, W]
        # [B, 1, 256, T]

        x = self.layernorm_1(inp)                   # Layer Norm
        x = self.pw_1(x)                            # Point Conv
        x = self.pad(x)                             # Zero-pad for the dilated DW conv
        x = self.dw(x)                              # DW Conv
        x = self.gate_1(x)                          # Gate
        x = x * self.ch_attention(x)                # Channel Attention
        x = self.pw_2(x)                            # Point Conv

        # y = inp + x                                 # Add
        y = inp + x * self.beta                     # Add

        x = self.layernorm_2(y)                     # Layer Norm
        x = self.pw_3(x)                            # Point Conv
        x = self.gate_2(x)                          # Gate
        x = self.pw_4(x)                            # Point Conv

        # return x + y                                # Add
        return y + x * self.gamma                  # Add

class NS(nn.Module):
    def __init__(self, model_ch=32, mid_glfb=6, down_glfb=[1, 8, 4], up_glfb=[1, 1, 1]):
        super(NS, self).__init__()
        # projection
        self.in_proj = nn.Conv2d(1, model_ch, 3, 1, 1)

        self.in_glfb = GLFB(model_ch, dilation=1)

        self.encoders = nn.ModuleList()
        self.decoders = nn.ModuleList()
        self.middle_blks = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()

        ch = model_ch
        for num in down_glfb:
            self.downs.append(down_sampling(ch))
            ch = ch * 2
            self.encoders.append(
                nn.Sequential(
                *[GLFB(ch, dilation=2 ** (i + 1)) for i in range(num)]
                )
            )

        self.mid_down = down_sampling(ch)
        ch = ch * 2
        self.middle_blks = \
            nn.Sequential(
                *[GLFB(ch) for _ in range(mid_glfb)]
            )
        self.mid_up = up_sampling(ch)
        ch = ch // 2

        for num in up_glfb:
            self.decoders.append(
                nn.Sequential(
                    *[GLFB(ch) for _ in range(num)]
                )
            )
            self.ups.append(up_sampling(ch))
            ch = ch // 2

        assert ch == model_ch
        self.out_glfb = GLFB(model_ch)

        # projection
        self.out_proj = nn.Conv2d(model_ch, 1, 3, 1, 1)
        self.padder = 2**(len(self.encoders)+1)

    def pad(self,x):
        _,_,t,f = x.shape
        T_pad = (self.padder - t % self.padder) % self.padder
        F_pad = (self.padder - f % self.padder) % self.padder
        x = F.pad(x,(0,F_pad,0,T_pad))
        return x

    def forward(self, x):
        # x [B, C(1), H(320), W(T)]
        if x.dim() == 3:
            x = x.unsqueeze(1)
        # zero padding for down/up sampling
        B,C,H,W = x.shape
        #pad x

        x = self.pad(x)
        inp = x
        x = self.in_proj(x)     
        encs = []
        x = self.in_glfb(x)
        encs.append(x)

        for encoder, down in zip(self.encoders, self.downs):
            x = down(x)
            x = encoder(x)
            encs.append(x)  

        x = self.mid_down(x)
        x = self.middle_blks(x)
        x = self.mid_up(x)

        for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]):
            x = x + enc_skip
            x = decoder(x)
            x = up(x)

        x = x + encs[0]
        x = self.out_glfb(x)
        x = self.out_proj(x)    # [B, 1, H*, W*]

        x = x + inp
        return x[..., :H,:W]

if __name__ == "__main__":
    net = NS()
    input = torch.randn(1,1,99,320) #BCTF
    output = net(input)
    print('done')

    def cout_param(model):
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

    print(cout_param(net))

    from thop import clever_format
    from thop import profile

    input = torch.randn(1, 1, 99, 320)

    flops, params = profile(net, inputs=(input, ))
    flops, params = clever_format([flops, params], "%.5f")

    print(flops,params)
heeyounMoon commented 10 months ago

Yes, thank you so much for your kind reply. I will train again with the code you modified and let you know the results. A different question: currently I process the audio signal segment by segment, where one segment is 80,000 samples (5 seconds) long. In other words, all the data is concatenated and then cut into 80,000-sample segments. I think this could also affect performance. Are there any special data-processing methods you use?

ioyy900205 commented 10 months ago

Yes, thank you so much for your kind reply. I will train again with the code you modified and let you know the results. A different question: currently I process the audio signal segment by segment, where one segment is 80,000 samples (5 seconds) long. In other words, all the data is concatenated and then cut into 80,000-sample segments. I think this could also affect performance. Are there any special data-processing methods you use?

Thank you for your attention. I believe that data is crucial, especially in real-world projects. Currently, I have identified two approaches: 1) offline mode and 2) online mode. Through my experiments, I have found that these two methods do not significantly impact the results, but it is essential to ensure that the network has fully converged.

I believe that handling data requires consideration of effective augmentation methods. For instance, adjusting the appropriate Signal-to-Noise Ratio (SNR) mixing range (I have come across different configurations in various papers, such as -3 to 15 dB in DNS and 0-20 dB), employing suitable augmentations (convolutional Room Impulse Responses, spectral masks, time masks, pitch transformations, speed transformations, etc.), and filtering clean data (I observed this method being used in the ByteDance DNS competition and Google's [L2L-audiovisual] related work).
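
As a hedged sketch of the SNR-controlled mixing mentioned above (an illustrative helper, not the author's pipeline):

import torch

def mix_at_snr(clean: torch.Tensor, noise: torch.Tensor, snr_db: float) -> torch.Tensor:
    # Scale the noise so the clean/noise power ratio matches snr_db, then mix.
    clean_pow = clean.pow(2).mean()
    noise_pow = noise.pow(2).mean().clamp_min(1e-12)
    scale = torch.sqrt(clean_pow / (noise_pow * 10 ** (snr_db / 10)))
    return clean + scale * noise

# Example: mix one 5 s clip at a random SNR in the 0-20 dB range cited above.
snr = float(torch.empty(1).uniform_(0, 20))
noisy = mix_at_snr(torch.randn(80000), torch.randn(80000), snr)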

If you concatenate all speech into a long audio and train on 5-second segments, I recommend using the roll method to increase the diversity of training data. My experiments also indicate that there is little difference between using 5-second and 30-second training segments. In fact, whether it's a causal or non-causal network, you can view the neural network as a filter with a 5-second buffer. Therefore, you can ignore concerns about the length of the 5-second training.
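
As a sketch of that roll idea (the function name and segment length are illustrative; 80,000 samples is 5 s at 16 kHz):

import torch

def roll_and_segment(long_wave: torch.Tensor, seg_len: int = 80000) -> torch.Tensor:
    # Shift the concatenated waveform by a random offset so the segment
    # boundaries differ between epochs, then cut fixed 5 s pieces.
    shift = int(torch.randint(0, seg_len, (1,)))
    rolled = torch.roll(long_wave, shifts=shift, dims=0)
    n_seg = rolled.numel() // seg_len
    return rolled[:n_seg * seg_len].view(n_seg, seg_len)

segments = roll_and_segment(torch.randn(16000 * 60))  # 60 s of audio at 16 kHz
print(segments.shape)  # torch.Size([12, 80000])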

JangyeonKim commented 4 months ago

According to the code you provided, I made some modifications after reviewing it. The model parameters now match exactly: 15,989,857.

Hello, thank you for your kind response. It has been very helpful.

I have a question. I am trying to replicate the performance on the DNS dataset using the model structure above. Since the dataloader provides inputs in waveform format, I added a step in the model's forward to convert them to STDCT, but the performance is not good.

    def to_stdct(self, x):
        B, L = x.shape

        transformed = []
        for i in range(B):
            transformed.append(stdct(x[i], 320, 160))

        return torch.stack(transformed)

    def forward(self, x):
        # x [B, L]
        x = self.to_stdct(x)

        # x [B, C(1), T, F(320)]
        if x.dim() == 3:
            x = x.unsqueeze(1)
        # zero padding for down/up sampling
        B, C, H, W = x.shape

        ...

I referred to the stdct() function from "https://github.com/taotaowang97479/MFNet-SpeechEnhancement." According to this GitHub page, the input to the model is compressed spectra (i.e., input = sign(stdct) * sqrt(stdct)). Is this correct?

In this case, should I expect the model's output to also be compressed spectra? Or is the model's output supposed to be the original spectra? It gets confusing when calculating the loss. I am currently proceeding with the original spectra as the input.

And my loss function is:

from model.MFNet.STDCT import stdct, istdct

class MFNetLoss(nn.Module):
    def __init__(self, gamma=0.5):
        super(MFNetLoss, self).__init__()
        self.gamma = gamma
        self.mse_loss = nn.MSELoss() 

    def forward(self, S_pred, S_true):
        S_true = stdct(S_true, 320, 160)

        # Mean-Square Error (MSE) loss for absolute values
        Loss_abs = self.mse_loss(torch.abs(S_true), torch.abs(S_pred))

        # MSE loss for polar values
        Loss_polar = self.mse_loss(S_true, S_pred)

        Loss_MFNet = self.gamma * Loss_abs + (1 - self.gamma) * Loss_polar

        return Loss_MFNet

I would appreciate any advice you can offer. Thank you.

ioyy900205 commented 4 months ago

@JangyeonKim Hello, thanks for your attention.

  1. This needs a small correction: input = sign(stdct) * sqrt(stdct) should be input = torch.sign(stdct) * torch.sqrt(stdct.abs()).
  2. The model output should also be the compressed STDCT. In pseudocode, the loss is loss_cal(compressed_pred_stdct, compressed_clean_stdct).

Please note that during the speech recovery process, the compressed DCT (Discrete Cosine Transform) needs to be restored to the regular DCT. To reiterate the overall process: wav -> split_frame -> add window -> dct -> compressed dct -> neural network -> normal dct -> idct (Inverse Discrete Cosine Transform) -> remove window -> overlap_and_add. Pay particular attention: if a root-Hann window is used, both the windowing and dewindowing steps multiply the frames by the window.

JangyeonKim commented 4 months ago

Thank you for your answer; it was very helpful.

I have one more question: according to the MFNet paper, the number of channels is listed as [n, 2n, 4n, 8n, 16n, 8n, 4n, 2n, n], where n = 16. However, in the code, model_ch is set to 32. Which configuration showed the best performance?


class NS(nn.Module):
    def __init__(self, model_ch=32, mid_glfb=6, down_glfb=[1, 8, 4], up_glfb=[1, 1, 1]):
        super(NS, self).__init__()
ioyy900205 commented 4 months ago

@JangyeonKim Thank you for your attention. The value of n needs to be set to 32 so that the computational cost and parameter count of the model are consistent with the paper.
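
For a quick check of that setting with the NS class above (the expected count is the 15,989,857 figure quoted earlier in this thread):

net = NS(model_ch=32)  # n = 32
print(sum(p.numel() for p in net.parameters() if p.requires_grad))  # expect 15989857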