Le-Xiaohuai-speech / DPCRN_DNS3

Implementation of the paper "DPCRN: Dual-Path Convolution Recurrent Network for Single Channel Speech Enhancement"

Hello! Sorry to bother you: when reproducing your network in PyTorch, I get a large gap between simulated real-time and non-real-time processing. #15

Closed shinobuwz closed 2 years ago

shinobuwz commented 2 years ago

The non-real-time results are clearly better than the simulated real-time (single-frame-in, single-frame-out) ones. I can't tell whether the problem is that the encoder's convolution layers depend on adjacent frames during training, or that the LSTM layers don't see enough frames to learn context. Attachment: pytorch_dpcrn.zip

Le-Xiaohuai-speech commented 2 years ago

Non-real-time processing will certainly score higher than real-time, but I'm not sure how large the gap you call "clearly better" actually is. A non-real-time implementation can differ in several ways: the LSTM may be bidirectional, the normalization may be global, and the convolution layers may have some look-ahead. Let me take a look at your implementation.
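
For context, a minimal sketch of my own (not from this repo) of where convolutional look-ahead comes from: with a kernel spanning two frames, padding only on the past side of the time axis keeps each output frame causal, while moving the padding to the future side gives one frame of look-ahead.

import torch
import torch.nn as nn

conv = nn.Conv2d(1, 1, kernel_size=(2, 3))  # (time, freq) kernel
x = torch.randn(1, 1, 10, 6)                # [batch, ch, T, F]

# F.pad takes [F_left, F_right, T_left, T_right] for a 4-D tensor
y_causal = conv(nn.functional.pad(x, [1, 1, 1, 0]))     # frame t sees t-1, t
y_lookahead = conv(nn.functional.pad(x, [1, 1, 0, 1]))  # frame t sees t, t+1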

shinobuwz commented 2 years ago

Here are the comparison files. hx was produced with your real-time model; the other two are from my reproduction, one real-time and one non-real-time. 对比文件.zip

Le-Xiaohuai-speech commented 2 years ago

The gap is indeed quite noticeable. Is the training data the same?

shinobuwz commented 2 years ago

The noise set is the noise from RIRS, and the speech set is Chinese speech from other open-source datasets. A gap relative to your model was expected, but even between my own real-time and non-real-time reproductions the spectrograms differ a lot orz

Le-Xiaohuai-speech commented 2 years ago

For your RT implementation, did you use RT2 or RealTime.py? And when doing single-frame-in, single-frame-out, did you account for the convolution cache?


shinobuwz commented 2 years ago

RealTime.py. I didn't account for the convolution cache, only the carrying of the LSTM hidden state...


Le-Xiaohuai-speech commented 2 years ago

That should be the problem, then. Convolutional networks are a real hassle to run in streaming mode: each layer's inputs from the previous time steps have to be kept as a cache for the next computation. I'd recommend setting the time dimension of the kernel to 1; then the problem disappears.
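
A minimal sketch of my own illustrating the suggestion: with a time-kernel of 1 the convolution never looks across frames, so single-frame-in, single-frame-out processing matches the offline result with no cache at all.

import torch
import torch.nn as nn

conv = nn.Conv2d(4, 4, kernel_size=(1, 3), padding=(0, 1)).eval()
x = torch.randn(1, 4, 10, 6)  # [batch, ch, T, F]

with torch.no_grad():
    offline = conv(x)
    # one frame at a time, no state carried between frames
    frames = [conv(x[:, :, t:t+1]) for t in range(10)]
    assert torch.allclose(offline, torch.cat(frames, dim=2), atol=1e-6)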

shinobuwz commented 2 years ago

Got it, understood! I'll try keeping the cache first.

andyye1999 commented 1 year ago

Hello, did you solve the PyTorch convolution cache problem?

Le-Xiaohuai-speech commented 1 year ago

Please replace the original Conv2D with the convolution implementation below; when converted to ONNX, it exports the cache as an output. The code includes examples of aligning the streaming and non-streaming versions.


andyye1999 commented 1 year ago


Sorry, the convolution implementation you mentioned doesn't show up on my end; it seems to be garbled. [screenshot]

Le-Xiaohuai-speech commented 1 year ago

Is there no attachment?


Le-Xiaohuai-speech commented 1 year ago

Streaming convolution:

from typing import Tuple, Union

import numpy as np
import torch
import torch.nn as nn


class StreamConv(nn.Module):
    """
    Streaming convolution.
    kernel_size is interpreted as [T_size, F_size].
    """
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Tuple[int, int]],
                 stride: Union[int, Tuple[int, int]] = 1,
                 padding: Union[str, int, Tuple[int, int]] = 0,
                 dilation: Union[int, Tuple[int, int]] = 1,
                 groups: int = 1,
                 bias: bool = True):
        super(StreamConv, self).__init__()
        self.Conv2d = nn.Conv2d(in_channels=in_channels,
                                out_channels=out_channels,
                                kernel_size=kernel_size,
                                stride=stride,
                                padding=padding,
                                dilation=dilation,
                                groups=groups,
                                bias=bias)

        self.in_channels = in_channels
        self.out_channels = out_channels
        if type(kernel_size) is int:
            self.T_size = kernel_size
            self.F_size = kernel_size
        elif type(kernel_size) in [list, tuple]:
            self.T_size, self.F_size = kernel_size
        else:
            raise ValueError('Invalid kernel size.')

    def forward(self, x, cache):
        """
        x: [bs, C, 1, F]        the current frame
        cache: [bs, C, T-1, F]  the previous T_size - 1 input frames
        """
        # prepend the cached frames so the kernel sees a full time window
        inp = torch.cat([cache, x], dim=2)
        outp = self.Conv2d(inp)
        # keep the most recent T_size - 1 frames for the next step
        # (alternatively, return x and update the cache outside this module)
        out_cache = inp[:, :, 1:]
        return outp, out_cache

Le-Xiaohuai-speech commented 1 year ago

Streaming transposed convolution. It can be implemented with an ordinary convolution plus upsampling; the weights just have to be reversed.


class StreamConvTranspose(nn.Module):
    """
    Streaming transposed convolution.
    kernel_size is interpreted as [T_size, F_size];
    stride is interpreted as [T_stride, F_stride], and T_stride must be 1.
    """
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Tuple[int, int]],
                 stride: Union[int, Tuple[int, int]] = 1,
                 padding: Union[str, int, Tuple[int, int]] = 0,
                 dilation: Union[int, Tuple[int, int]] = 1,
                 groups: int = 1,
                 bias: bool = True):
        super(StreamConvTranspose, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        if type(kernel_size) is int:
            self.T_size = kernel_size
            self.F_size = kernel_size
        elif type(kernel_size) in [list, tuple]:
            self.T_size, self.F_size = kernel_size
        else:
            raise ValueError('Invalid kernel size.')

        if type(stride) is int:
            self.T_stride = stride
            self.F_stride = stride
        elif type(stride) in [list, tuple]:
            self.T_stride, self.F_stride = stride
        else:
            raise ValueError('Invalid stride size.')

        assert self.T_stride == 1

        # we implement the transposed convolution as a Conv2d whose weights
        # are reversed in time and frequency
        self.Conv2d = nn.Conv2d(in_channels=in_channels,
                                out_channels=out_channels,
                                kernel_size=kernel_size,
                                stride=(self.T_stride, 1),  # the F stride is realized by the explicit upsampling in forward
                                padding=padding,
                                dilation=dilation,
                                groups=groups,
                                bias=bias)

    @staticmethod
    def get_indices(inp, F_stride):
        """
        Compute the upsampled indices from the input shape and the F-dimension stride.
        (Not used by forward below; kept as an alternative way to build the upsampling.)
        inp: [bs,C,T,F]
        return:
            indices: [bs,C,T,F]
        Only F is upsampled, so the output has shape [bs,C,T,F_out] with
        F_out = (F - 1) * (F_stride - 1) + F, i.e. F_stride - 1 zeros are
        inserted between every pair of elements.
        """
        bs, C, T, F = inp.shape
        F_out = (F - 1) * (F_stride - 1) + F
        indices = np.zeros([bs * 1 * T * F])
        index = 0
        for i in range(bs * 1 * T * F):
            indices[i] = index
            if (i + 1) % F == 0:
                index += 1
            else:
                index += F_stride
        indices = torch.from_numpy(np.repeat(indices.reshape([bs, 1, T, F]).astype('int64'), C, axis=1))
        return indices, F_out

    def forward(self, x, cache):
        """
        x: [bs, C, 1, F]        the current frame
        cache: [bs, C, T-1, F]  the previous T_size - 1 input frames
        """
        # [bs,C,T,F]
        inp = torch.cat([cache, x], dim=2)
        out_cache = inp[:, :, 1:]
        bs, C, T, F = inp.shape
        # explicit upsampling: insert F_stride - 1 zeros after every F bin
        if self.F_stride > 1:
            # [bs,C,T,F] -> [bs,C,T,F,1] -> [bs,C,T,F,F_stride] -> [bs,C,T,F*F_stride]
            inp = torch.cat([inp[:, :, :, :, None],
                             torch.zeros([bs, C, T, F, self.F_stride - 1],
                                         device=inp.device, dtype=inp.dtype)],
                            dim=-1).reshape([bs, C, T, -1])
            left_pad = self.F_stride - 1
            if self.F_size > 1:
                if left_pad <= self.F_size - 1:
                    inp = torch.nn.functional.pad(inp, pad=[self.F_size - 1, self.F_size - 1 - left_pad, 0, 0])
                else:
                    inp = torch.nn.functional.pad(inp, pad=[self.F_size - 1, 0, 0, 0])[:, :, :, :-(left_pad - self.F_stride + 1)]
            else:
                inp = inp[:, :, :, :-left_pad]

        outp = self.Conv2d(inp)
        # alternatively, return x here and update the cache outside

        return outp, out_cache

Le-Xiaohuai-speech commented 1 year ago

An example:

def init_stream_conv_weights(Stream_Conv, Conv):
    Stream_dict = Stream_Conv.state_dict()
    Conv_dict = Conv.state_dict()

    for k in Conv_dict:
        Stream_dict['Conv2d.' + k] = Conv_dict[k]
    Stream_Conv.load_state_dict(Stream_dict)

def init_stream_deconv_weights(Stream_Conv, Conv):
    Stream_dict = Stream_Conv.state_dict()
    Conv_dict = Conv.state_dict()
    for k in Conv_dict:
        if 'weight' in k:
            # ConvTranspose2d weights are [in, out, kT, kF]; swap to
            # [out, in, kT, kF] and reverse both kernel dimensions
            Stream_dict['Conv2d.' + k] = torch.flip(Conv_dict[k].permute([1, 0, 2, 3]), dims=[-2, -1])
        else:
            Stream_dict['Conv2d.' + k] = Conv_dict[k]
    Stream_Conv.load_state_dict(Stream_dict)

if __name__ == "__main__":
    torch.random.seed()
    # test Conv2d Stream
    SC = StreamConv(1, 1, [2, 3], [1, 2]).eval()
    test_input = torch.randn([1, 1, 10, 6])
    conv = SC.Conv2d
    with torch.no_grad():
        # Non-Streaming: pad one zero frame on the left of the time axis
        test_out1 = conv(torch.nn.functional.pad(test_input, [0, 0, 1, 0]))
        print("Non streaming")
        print(test_out1)
        cache = torch.zeros([1, 1, 1, 6])
        # Streaming: feed one frame at a time, carrying the cache
        print("Streaming")
        for i in range(10):
            out, cache = SC(test_input[:, :, i:i+1], cache)
            print(out)

    # test Conv2dTranspose Stream
    SC = StreamConvTranspose(2, 1, [2, 3], [1, 2]).eval()
    deconv = torch.nn.ConvTranspose2d(2, 1, [2, 3], [1, 2]).eval()
    # flip the weights
    init_stream_deconv_weights(SC, deconv)
    #SC.Conv2d.weight.data = torch.flip(deconv.weight.data.permute([1,0,2,3]), dims = [-2,-1])
    #SC.Conv2d.bias.data = deconv.bias.data

    test_input = torch.randn([1, 2, 10, 3])
    with torch.no_grad():
        print("\nNon streaming Deconv")
        test_out1 = deconv(test_input)[:, :, :10]
        print(test_out1)
        print("Streaming Deconv")
        cache = torch.zeros([1, 2, 1, 3])
        for i in range(10):
            out, cache = SC(test_input[:, :, i:i+1], cache)
            print(out)

andyye1999 commented 1 year ago

An example (code quoted above)

Thank you very much for the explanation, there's a lot for me to learn here. I'm still not too clear on streaming processing, though. Do you think this simple idea of mine is feasible? [attached sketch: IMG_5131]
I also came across the streaming framework Google released: https://github.com/google-research/google-research/tree/master/kws_streaming

Le-Xiaohuai-speech commented 1 year ago

Also, I'm not very familiar with the current frameworks. To run streaming, the cache has to be kept alive until the next inference, which at least ONNX and MNN don't support: the cache has to be output from one session run and fed back as an input to the next, which is quite inefficient. Some vendors' in-house inference libraries presumably support it.
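
To make that round trip concrete, a rough sketch of my own (the file name, shapes, and use of onnxruntime are assumptions, not from this repo): the StreamConv above is exported with the cache as an explicit input and output, and every run copies the cache back in by hand.

import numpy as np
import onnxruntime as ort
import torch

sc = StreamConv(1, 1, [2, 3], [1, 2]).eval()
torch.onnx.export(sc,
                  (torch.randn(1, 1, 1, 6), torch.zeros(1, 1, 1, 6)),
                  "stream_conv.onnx",  # hypothetical file name
                  input_names=["x", "cache_in"],
                  output_names=["y", "cache_out"])

sess = ort.InferenceSession("stream_conv.onnx")
cache = np.zeros([1, 1, 1, 6], dtype=np.float32)
for frame in np.random.randn(10, 1, 1, 1, 6).astype(np.float32):
    # the cache leaves one session run and re-enters the next
    y, cache = sess.run(None, {"x": frame, "cache_in": cache})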


shinobuwz commented 1 year ago

Hello, did you solve the PyTorch convolution cache problem?

For on-device deployment of a streaming model with MNN or ONNX, you have to export all the intermediate cache nodes and copy them back in at the next inference.

andyye1999 commented 1 year ago

Hello, did you solve the PyTorch convolution cache problem?

For on-device deployment of a streaming model with MNN or ONNX, you have to export all the intermediate cache nodes and copy them back in at the next inference.

Could you share your modified code? I'd like to learn from it.

andyye1999 commented 1 year ago

An example (code quoted above)

Hello, I set the transposed convolution's F_stride to 1:

SC = StreamConvTranspose(2, 1, [2, 3], [1, 1]).eval()
deconv = torch.nn.ConvTranspose2d(2, 1, [2, 3], [1, 1]).eval()

and the results don't match.

andyye1999 commented 1 year ago

An example (code quoted above)

Hello, I set the transposed convolution's F_stride to 1 and the results don't match. (quoted above)

I changed "if self.F_stride > 1:" to "if self.F_stride >= 1:" and now the results match (with F_stride == 1 the branch was skipped entirely, so the F-dimension padding that the underlying Conv2d still needs was never applied).
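
For reference, a quick check of my own mirroring the alignment test above; it assumes the ">= 1" change has been applied to StreamConvTranspose as described.

import torch

SC = StreamConvTranspose(2, 1, [2, 3], [1, 1]).eval()
deconv = torch.nn.ConvTranspose2d(2, 1, [2, 3], [1, 1]).eval()
init_stream_deconv_weights(SC, deconv)

x = torch.randn(1, 2, 10, 3)
with torch.no_grad():
    ref = deconv(x)[:, :, :10]  # drop the extra trailing time frame
    cache = torch.zeros(1, 2, 1, 3)
    outs = []
    for i in range(10):
        y, cache = SC(x[:, :, i:i+1], cache)
        outs.append(y)
    assert torch.allclose(ref, torch.cat(outs, dim=2), atol=1e-5)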