Xiaobin-Rong / deepvqe

An unofficial implementation of DeepVQE proposed by Microsoft Corp.

Concern about Dimension Mismatch in AlignBlock: Input 128, Output 256, Subsequent Layer Design for 128? #4

Open lin-sh opened 10 months ago

lin-sh commented 10 months ago

I have a question I'd like to discuss with the author and the community. For the AlignBlock, the paper states that the input and output dimensions should be the same, e.g., 128 in and 128 out. However, the concatenation that follows brings the channel dimension to 256. If the subsequent layer is designed for 128 channels, how should this be handled? I'm not sure whether my understanding is correct.

Xiaobin-Rong commented 10 months ago

I'm not sure if what you mentioned refers to this sentence in the paper: "The first microphone encoding block has 64 filters and all the following microphone encoding blocks have 128 filters." My understanding is that the output channel number is 128 for every block after the first, regardless of the input channel number. That said, a long time has passed, so I can't guarantee my current understanding is correct.
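
In other words, concatenating the 128-channel mic encoding with the 128-channel aligned far-end encoding yields 256 channels, so the encoder block that follows simply needs 256 input channels. A minimal demonstration (the (B, C, T, F) layout and the T=10, F=20 sizes here are arbitrary, just for illustration):

import torch

mic_feat = torch.randn(1, 128, 10, 20)  # mic encoding: 128 channels
far_feat = torch.randn(1, 128, 10, 20)  # aligned far-end encoding: 128 channels
fused = torch.cat([mic_feat, far_feat], dim=1)
print(fused.shape)  # torch.Size([1, 256, 10, 20]) -> next block takes 256 in-channels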

I previously reproduced DeepVQE with an Align Block, but I haven't trained it. Note that the parameter count of my reproduced model is 8 M, which does not match the 7.5 M reported in the paper. Here is the code I wrote; I hope it helps.

import torch
import torch.nn as nn

# FE, EncoderBlock, AlignBlock, Bottleneck, DecoderBlock, and CCM are the
# building blocks defined elsewhere in this repository.

class DeepVQE_Align(nn.Module):
    def __init__(self):
        super().__init__()
        self.fe = FE()

        # Far-end (reference) branch
        self.enblock1_far = EncoderBlock(2, 32)
        self.enblock2_far = EncoderBlock(32, 128)
        self.alignblock = AlignBlock(128, 128, delay=100)

        # Microphone branch; enblock3 takes 256 input channels because the
        # 128-channel mic features are concatenated with the 128-channel
        # aligned far-end features.
        self.enblock1 = EncoderBlock(2, 64)
        self.enblock2 = EncoderBlock(64, 128)
        self.enblock3 = EncoderBlock(256, 128)
        self.enblock4 = EncoderBlock(128, 128)
        self.enblock5 = EncoderBlock(128, 128)

        self.bottle = Bottleneck(128*9, 64*9)

        self.deblock5 = DecoderBlock(128, 128)
        self.deblock4 = DecoderBlock(128, 128)
        self.deblock3 = DecoderBlock(128, 128)
        self.deblock2 = DecoderBlock(128, 64)
        self.deblock1 = DecoderBlock(64, 27)
        self.ccm = CCM()

    def forward(self, x):
        """x: (B,2,F,T,2) -- mic and far-end signals stacked along dim 1."""
        x_mic, x_far = x[:, 0], x[:, 1]

        # Encode the far-end signal
        en_x0_far = self.fe(x_far)
        en_x1_far = self.enblock1_far(en_x0_far)
        en_x2_far = self.enblock2_far(en_x1_far)

        # Encode the microphone signal
        en_x0 = self.fe(x_mic)
        en_x1 = self.enblock1(en_x0)
        en_x2 = self.enblock2(en_x1)

        # Align the far-end features to the mic features, then concatenate
        # along the channel axis: 128 + 128 = 256 channels into enblock3.
        align_x2_far = self.alignblock(en_x2, en_x2_far)
        align_x2 = torch.cat([en_x2, align_x2_far], dim=1)

        en_x3 = self.enblock3(align_x2)
        en_x4 = self.enblock4(en_x3)
        en_x5 = self.enblock5(en_x4)

        en_xr = self.bottle(en_x5)

        # Decoder with skip connections; each output is cropped along the
        # last axis to match the corresponding encoder output.
        de_x5 = self.deblock5(en_xr, en_x5)[..., :en_x4.shape[-1]]
        de_x4 = self.deblock4(de_x5, en_x4)[..., :en_x3.shape[-1]]
        de_x3 = self.deblock3(de_x4, en_x3)[..., :en_x2.shape[-1]]
        de_x2 = self.deblock2(de_x3, en_x2)[..., :en_x1.shape[-1]]
        de_x1 = self.deblock1(de_x2, en_x1)[..., :en_x0.shape[-1]]

        x_enh = self.ccm(de_x1, x_mic)  # (B,F,T,2)
        return x_enh
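
To check the 8 M figure against the paper's 7.5 M, a quick parameter count (assuming DeepVQE_Align and its building blocks are importable from this repo):

model = DeepVQE_Align()
n_params = sum(p.numel() for p in model.parameters())
print(f"{n_params / 1e6:.2f} M parameters")  # ~8 M for this reproduction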
lin-sh commented 10 months ago

Thank you for your code; it has cleared up my confusion. I am currently looking at your other repository on speech enhancement training. Would you be willing to open-source the training code for DeepVQE? Either way, I am very grateful: your other repository has also been a source of inspiration, and the code quality there is excellent.

Xiaobin-Rong commented 10 months ago

Thank you for your attention and kind words about my code. As I mentioned in the README, my focus in reproducing DeepVQE was solely on its speech enhancement performance, so the training code is the same as in my SEtrain repository.

lin-sh commented 10 months ago

Thank you for your repository. The training template is indeed concise and efficient. Looking forward to more of your open-source projects!

Xiaobin-Rong commented 10 months ago

You're welcome! I'm glad you found the training template useful. I really appreciate your kind words. Thanks again for your support!

Ryanzlay commented 9 months ago

Hi, I'd like to ask why the align block is able to compensate for the time delay. I don't quite understand the underlying principle.
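
For readers with the same question: the rough idea is attention over delayed copies of the far-end features. For each mic frame, the block scores a set of time-shifted far-end frames and takes a softmax-weighted sum, so the network softly selects, per frame, the delay that best matches the mic signal. A minimal sketch of that idea (the class name, projections, and all internals below are assumptions for illustration, not this repo's exact AlignBlock):

import torch
import torch.nn as nn
import torch.nn.functional as F

class AlignBlockSketch(nn.Module):
    """Soft delay alignment: attend over time-shifted far-end features."""

    def __init__(self, mic_ch, far_ch, delay=100):
        super().__init__()
        self.delay = delay
        self.q_proj = nn.Conv2d(mic_ch, far_ch, 1)  # query from mic encoding
        self.k_proj = nn.Conv2d(far_ch, far_ch, 1)  # key from far-end encoding

    def forward(self, mic, far):
        """mic: (B, Cm, T, F), far: (B, Cf, T, F) -> aligned far-end (B, Cf, T, F)"""
        B, C, T, Fr = far.shape
        q = self.q_proj(mic)  # (B, C, T, F)
        k = self.k_proj(far)  # (B, C, T, F)

        # Stack `delay` causally shifted copies: entry d holds the far-end
        # features delayed by d frames.
        k_d = torch.stack(
            [F.pad(k, (0, 0, d, 0))[:, :, :T] for d in range(self.delay)], dim=1)
        v_d = torch.stack(
            [F.pad(far, (0, 0, d, 0))[:, :, :T] for d in range(self.delay)], dim=1)

        # One similarity score per (frame, delay), pooled over channel and
        # frequency, then a softmax across the candidate delays.
        scores = (q.unsqueeze(1) * k_d).sum(dim=(2, 4))  # (B, D, T)
        w = torch.softmax(scores, dim=1)  # attention weights over delays

        # Softly aligned far-end features: weighted sum over the delays.
        return (w.unsqueeze(2).unsqueeze(-1) * v_d).sum(dim=1)  # (B, C, T, F)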

jackyyigao commented 1 month ago

https://ristea.github.io/deep-vqe/ — one of the authors posted the "align feature map" on this blog, and it looks perfect. But the paper does not mention how the model is trained. Do you think they train all the layers as a whole, or first train the alignment-related parts (mic encoder1 + encoder2, far encoder1 + encoder2, and the align block), then freeze those parts and train the following layers? I'm not sure I could get the perfect "align feature map" the author posted.