Xiaobin-Rong / gtcrn

The official implementation of GTCRN, an ultra-lite speech enhancement model.
MIT License
219 stars 37 forks source link

导出onnx的stream模型时可以优化一点点的两个方法 #19

Open SherryYu33 opened 5 months ago

SherryYu33 commented 5 months ago
  1. SFE模块的unfold可以用如下模块代替,可以减少很多算子

    import torch
    import torch.nn as nn
    class Unfold(nn.Module):
    def __init__(self):
        super().__init__()
        kernel = torch.eye(3)
        kernel = kernel.view(3, 1, 1, 3)
        kernel = nn.Parameter(kernel.repeat(8, 1, 1, 1))
        self.conv = nn.Conv2d(8, 24, (1, 3), padding=(0, 1), groups=8, bias=False)
        self.conv.weight = kernel
    
    def forward(self, x):
        out = self.conv(x)
        return out
  2. onnxsim没办法把ConvTranspose和BN融合在一起,但是pnnx可以,可以节省算力,用如下方法导出若干文件
    import pnnx
    mod = torch.jit.trace(model_stream, [一堆变量])
    mod.save("gtcrn.pt")
    opt_net = pnnx.convert("gtcrn.pt", [一堆变量])

    然后会在当前文件夹生成一个gtcrn_pnnx.py的文件,里面有一个export_onnx()的函数,可以按喜好修改输出形式,最后当然也可以用onnxsim再跑一次

    export_onnx()
    import onnx
    from onnxsim import simplify
    onnx_model = onnx.load('gtcrn.onnx')
    onnx.checker.check_model(onnx_model)
    model_simp, check = simplify(onnx_model)
    onnx.save(model_simp, 'gtcrn_sim.onnx')
Xiaobin-Rong commented 5 months ago

@SherryYu33 非常感谢您的建议,受教了!

GuanHengcong commented 4 months ago
  1. SFE模块的unfold可以用如下模块代替,可以减少很多计算子
import torch
import torch.nn as nn
class Unfold(nn.Module):
    def __init__(self):
        super().__init__()
        kernel = torch.eye(3)
        kernel = kernel.view(3, 1, 1, 3)
        kernel = nn.Parameter(kernel.repeat(8, 1, 1, 1))
        self.conv = nn.Conv2d(8, 24, (1, 3), padding=(0, 1), groups=8, bias=False)
        self.conv.weight = kernel

    def forward(self, x):
        out = self.conv(x)
        return out
  1. onnxsim没办法把ConvTranspose和BN融合在一起,但是pnnx可以,可以节省算力,用如下方法导出若干文件
import pnnx
mod = torch.jit.trace(model_stream, [一堆变量])
mod.save("gtcrn.pt")
opt_net = pnnx.convert("gtcrn.pt", [一堆变量])

然后会在当前文件夹生成一个gtcrn_pnnx.py文件,里面有一个export_onnx()的函数,可以按喜好修改输出形式,最后当然也可以用onnxsim再运行一次

export_onnx()
import onnx
from onnxsim import simplify
onnx_model = onnx.load('gtcrn.onnx')
onnx.checker.check_model(onnx_model)
model_simp, check = simplify(onnx_model)
onnx.save(model_simp, 'gtcrn_sim.onnx')

大佬,我按照你说的改了gtcrn_stream.py文件中的, image 但是报了如下的错误: image

SherryYu33 commented 4 months ago

@GuanHengcong 因为重构以后的Unfold它的名称就和原来state_dict里面的对不上了,最简单的办法就是把convert_to_stream里面的

else:
    raise (....)

给注释了 以及SFE的后面也不用reshape

GuanHengcong commented 4 months ago

@GuanHengcong 因为重构以后的Unfold它的名称就和原来state_dict里面的对不上了,最简单的办法就是把convert_to_stream里面的

else:
    raise (....)

给注释了 以及SFE的后面也不用reshape

大佬,很抱歉没有及时回复您消息,我把reshape和else都去掉了,报错如下,您可以再帮忙看看吗 image image image 在“torch.onnx.export(”导出这行报错

SherryYu33 commented 4 months ago

@GuanHengcong 整体模型的里面的那个SFE的channel数是3,你把那个换成

import torch
import torch.nn as nn
class Unfold_in(nn.Module):
    def __init__(self):
        super().__init__()
        kernel = torch.eye(3)
        kernel = kernel.view(3, 1, 1, 3)
        kernel = nn.Parameter(kernel.repeat(3, 1, 1, 1))
        self.conv = nn.Conv2d(3, 9, (1, 3), padding=(0, 1), groups=3, bias=False)
        self.conv.weight = kernel

    def forward(self, x):
        out = self.conv(x)
        return out
GuanHengcong commented 4 months ago

@GuanHengcong 整体模型的里面的那个SFE的channel数是3,你把那个换成

import torch
import torch.nn as nn
class Unfold_in(nn.Module):
    def __init__(self):
        super().__init__()
        kernel = torch.eye(3)
        kernel = kernel.view(3, 1, 1, 3)
        kernel = nn.Parameter(kernel.repeat(3, 1, 1, 1))
        self.conv = nn.Conv2d(3, 9, (1, 3), padding=(0, 1), groups=3, bias=False)
        self.conv.weight = kernel

    def forward(self, x):
        out = self.conv(x)
        return out

佬,还是报类似的错误,形状对应不上,我可以用您的微信或者QQ联系您吗,实在是太打扰了,再拖下去我怕是要明年毕业了,/(ㄒoㄒ)/~~

SherryYu33 commented 4 months ago

@GuanHengcong 帮帮铺(bbpu)抠图

shenbuguanni commented 4 months ago

我测试了上面SFE的优化写法在板子上的实际实时率,基本无变化。

xczhusuda commented 2 months ago

我们发现GTCRN的功耗比相同算力的模型高很多 @Xiaobin-Rong @SherryYu33