Xiaobin-Rong / gtcrn

The official implementation of GTCRN, an ultra-lite speech enhancement model.
MIT License
219 stars 37 forks source link

请问,如果想扩展一下,用计算量换模型的性能,应该朝哪个方向调参? #7

Closed zuowanbushiwo closed 9 months ago

zuowanbushiwo commented 9 months ago

您好,恭喜您完成了一个非常棒的工作,同时非常感谢您无私的分享。 就像您说的这个模型参数和运算量都是非常低的,不知道有没有一个稍微大一点 参数和运算量的模型,同时能达到更好的效果,应该朝哪个方向调参? 谢谢!

Xiaobin-Rong commented 9 months ago

感谢你的关注和支持! 如果想用参数量/计算量换性能,最好的方法就是简单地把Encoder/Decoder的channels调大就行,同时注意DPGRNN的input_size/hidden_size也相应调整,举个例子把channels由16改为32:

class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.en_convs = nn.ModuleList([
            ConvBlock(3*3, 16, (1,5), stride=(1,2), padding=(0,2), use_deconv=False, is_last=False),
            ConvBlock(16, 32, (1,5), stride=(1,2), padding=(0,2), groups=2, use_deconv=False, is_last=False),
            GTConvBlock(32, 32, (3,3), stride=(1,1), padding=(0,1), dilation=(1,1), use_deconv=False),
            GTConvBlock(32, 32, (3,3), stride=(1,1), padding=(0,1), dilation=(2,1), use_deconv=False),
            GTConvBlock(32, 32, (3,3), stride=(1,1), padding=(0,1), dilation=(5,1), use_deconv=False)
        ])

class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.de_convs = nn.ModuleList([
            GTConvBlock(32, 32, (3,3), stride=(1,1), padding=(2*5,1), dilation=(5,1), use_deconv=True),
            GTConvBlock(32, 32, (3,3), stride=(1,1), padding=(2*2,1), dilation=(2,1), use_deconv=True),
            GTConvBlock(32, 32, (3,3), stride=(1,1), padding=(2*1,1), dilation=(1,1), use_deconv=True),
            ConvBlock(32, 16, (1,5), stride=(1,2), padding=(0,2), groups=2, use_deconv=True, is_last=False),
            ConvBlock(16, 2, (1,5), stride=(1,2), padding=(0,2), use_deconv=True, is_last=True)
        ])

class GTCRN(nn.Module):
    def __init__(self):
        super().__init__()
        self.erb = ERB(65, 64)
        self.sfe = SFE(3, 1)

        self.encoder = Encoder()

        self.dpgrnn1 = DPGRNN(32, 33, 32)
        self.dpgrnn2 = DPGRNN(32, 33, 32)

        self.decoder = Decoder()

        self.mask = Mask()

这样子模型的复杂度大约是75K参数, 92 MMACs。

zuowanbushiwo commented 9 months ago

非常感谢回答,真的是太热心了。 你文章有做这样的对比实验吗?这样做提升大吗? 还有一个比较想问的问题:这个模型的泛化性能好不好?不知道你文章中有没有这方面的对比实验,做一些没见过噪声的测试。或者用VCTK-DEMAND训练的模型,在DNS的blind_test数据集中表现好不好? 总之非常感谢您的开源和上面的回答,这个对我帮助很大。谢谢!

Xiaobin-Rong commented 9 months ago
  1. 我文章没有做将模型复杂度放大的对比实验,因为文章focus的就是极小规模的模型。但是根据我平时做过的一些实验,这么做提升挺大的(将channel从16->32,在一些测试集上PESQ可以提高0.12),因为毕竟目前模型size太小了;
  2. 模型的泛化性能目前没有展开做研究,但是据我的经验,泛化性能更多地与 训练数据的规模 & 模型规模 相关。我的模型,用VCTK-DEMAND训练,在DNS的blind_test数据集中测试表现肯定会劣化,但这个问题也是其它端到端模型的通病吧。我后续的工作会聚焦在提高小模型泛化性这方面来。
zuowanbushiwo commented 9 months ago

好的,非常感谢您专业的回答 祝您科研顺利