yxlu-0102 / MP-SENet

MP-SENet: A Speech Enhancement Model with Parallel Denoising of Magnitude and Phase Spectra
MIT License

Incoherent dimensions in the self-attention module #21

Open Rodolphe2005 opened 4 months ago

Rodolphe2005 commented 4 months ago

Thank you for sharing your very interesting work.

I have a question about the self-attention used in the conformer block. Before applying the time_conformer, you reshape the tensor to shape $(b \times f, t, c)$:

https://github.com/yxlu-0102/MP-SENet/blob/faed99c08b1a1325f042613b36331bd96f046712/models/generator.py#L110

Then, on the next line, you apply the time conformer:

https://github.com/yxlu-0102/MP-SENet/blob/faed99c08b1a1325f042613b36331bd96f046712/models/generator.py#L111

In the conformer block, you use nn.MultiheadAttention to compute the self-attention. However, this PyTorch module is initialized with batch_first=False, because that is the default value:

https://github.com/yxlu-0102/MP-SENet/blob/faed99c08b1a1325f042613b36331bd96f046712/models/conformer.py#L49

So the self-attention module expects an input of shape $(L, N, E)$, where $L$ is the sequence length, $N$ is the batch size and $E$ is the embedding dimension (see the PyTorch documentation: https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html ).

Since the tensor x has shape $(b \times f, t, c)$, the self-attention treats $b \times f$ as the sequence length and $t$ as the batch size, instead of attending along $t$. It would make more sense to initialize the MultiheadAttention with batch_first=True. However, when I tried that, the results were not good.

Can you explain this?

Thank you very much
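One way to check the claim about which axis is attended over is to compare the default module against a `batch_first=True` copy with identical weights. The sketch below is illustrative only: the sizes `b`, `t`, `f`, `c` and `n_head` are arbitrary toy values, not the model's settings, and it assumes a recent PyTorch version. It shows that the default module's output on a $(b \times f, t, c)$ tensor matches attention along the $b \times f$ axis, not along $t$.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
b, t, f, c, n_head = 2, 5, 3, 8, 2  # toy sizes, chosen arbitrarily

# Default construction: batch_first=False, so inputs are read as (L, N, E).
mha_default = nn.MultiheadAttention(c, n_head).eval()
# Same weights, but inputs are read as (N, L, E).
mha_bf = nn.MultiheadAttention(c, n_head, batch_first=True).eval()
mha_bf.load_state_dict(mha_default.state_dict())

# Input laid out as in the time branch: (b*f, t, c).
x = torch.randn(b * f, t, c)
xt = x.transpose(0, 1)  # (t, b*f, c)

with torch.no_grad():
    out_default, _ = mha_default(x, x, x)
    # Intended behavior: attend along t, separately for each of the b*f rows.
    out_time, _ = mha_bf(x, x, x)
    # Attention along the b*f axis, separately for each of the t frames.
    out_bf_axis, _ = mha_bf(xt, xt, xt)
    out_bf_axis = out_bf_axis.transpose(0, 1)

print(torch.allclose(out_default, out_bf_axis, atol=1e-5))  # True
print(torch.allclose(out_default, out_time, atol=1e-5))     # False
```

Loading the same `state_dict` into both modules works because `batch_first` only changes how the input dimensions are interpreted, not the parameters.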

yxlu-0102 commented 4 months ago

You are right, it was an oversight on my part when writing the code. Similar concerns have been raised to me before, but this approach nonetheless seems to yield good results.

In this case, it is still self-attention along either the frequency or the time axis, except that the sequences of all samples in the batch are concatenated together.

It is hard to give a convincing explanation for why this approach works well. If you or anyone else has any thoughts on it, please share them here.
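A side effect of concatenating the sequences of all samples is that positions from different samples in the same batch attend to each other. The sketch below (toy sizes, random weights, not the repository's code) perturbs only the second sample of a $(b \times f, t, c)$ input and checks whether the first sample's output changes, once with the default module and once with a `batch_first=True` copy sharing its weights.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
b, t, f, c, n_head = 2, 5, 3, 8, 2  # toy sizes, chosen arbitrarily

mha_default = nn.MultiheadAttention(c, n_head).eval()           # batch_first=False
mha_bf = nn.MultiheadAttention(c, n_head, batch_first=True).eval()
mha_bf.load_state_dict(mha_default.state_dict())                # identical weights

x = torch.randn(b * f, t, c)  # rows 0..f-1: sample 0, rows f..2f-1: sample 1
x_pert = x.clone()
x_pert[f:] += 1.0             # modify sample 1 only

with torch.no_grad():
    y, _ = mha_default(x, x, x)
    y_pert, _ = mha_default(x_pert, x_pert, x_pert)
    z, _ = mha_bf(x, x, x)
    z_pert, _ = mha_bf(x_pert, x_pert, x_pert)

# batch_first=False: sample 0's output depends on sample 1's input.
print(torch.allclose(y[:f], y_pert[:f]))             # False
# batch_first=True: sample 0's output is unaffected.
print(torch.allclose(z[:f], z_pert[:f], atol=1e-6))  # True
```

This suggests that, with batch_first=False, a sample's output can depend on which other samples share its batch, which is one difference to keep in mind when comparing the two settings.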


yxlu-0102 commented 3 months ago

Hello, I followed your advice and set batch_first=True in the Attention module of our latest model (the red curve in the graph).

However, I found that the results are almost the same as with batch_first=False (the orange curve in the graph); there is no deterioration.

[Screenshot 2024-04-05 09:13:15: red curve = batch_first=True, orange curve = batch_first=False]

Here are the best experimental results for each setting:

  1. batch_first=False: pesq: 3.60, csig: 4.81, cbak: 3.99, covl: 4.34, stoi: 0.96
  2. batch_first=True: pesq: 3.60, csig: 4.80, cbak: 3.99, covl: 4.33, stoi: 0.96
RobertSJTU commented 3 months ago

I think that with batch_first=False, the T-conformer actually serves as a (concat-T) F-conformer, and the F-conformer serves as a (concat-F) T-conformer, compared with the batch_first=True configuration. So the main difference between the two configurations is the order of the T- and F-conformers. I think the batch information may not contribute much in the self-attention module, which means the key matrix $W_K$ controls the weights.
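The role swap described above can be made concrete by listing which (sample, frame, bin) positions share one attention sequence in each branch. This is a pure-Python sketch under stated assumptions: the time branch uses the $(b \times f, t, c)$ layout quoted in the issue, and the frequency branch is assumed to use the analogous $(b \times t, f, c)$ layout (that reshape is not quoted in this thread). `attention_groups` is a hypothetical helper; it tracks only the attention grouping and ignores the convolution and feed-forward modules of the conformer.

```python
from typing import List, Tuple

Position = Tuple[int, int, int]  # (sample, frame, bin)


def attention_groups(b: int, t: int, f: int, branch: str,
                     batch_first: bool) -> List[List[Position]]:
    """Hypothetical helper: groups of positions that attend to each other."""
    if branch == "time":
        # Assumed layout (b*f, t, c): dim 0 enumerates (sample, bin), dim 1 frames.
        rows = [(s, k) for s in range(b) for k in range(f)]
        n_cols = t
    elif branch == "freq":
        # Assumed layout (b*t, f, c): dim 0 enumerates (sample, frame), dim 1 bins.
        rows = [(s, n) for s in range(b) for n in range(t)]
        n_cols = f
    else:
        raise ValueError(f"unknown branch: {branch!r}")

    def position(row: Tuple[int, int], col: int) -> Position:
        s, other = row
        return (s, col, other) if branch == "time" else (s, other, col)

    if batch_first:
        # (N, L, E): each index of dim 0 is an independent sequence over dim 1.
        return [[position(r, col) for col in range(n_cols)] for r in rows]
    # (L, N, E): each index of dim 1 is an independent sequence over all of dim 0.
    return [[position(r, col) for r in rows] for col in range(n_cols)]


# Time branch, batch_first=True: the frames of one bin of one sample.
print(attention_groups(2, 2, 3, "time", batch_first=True)[0])
# [(0, 0, 0), (0, 1, 0)]
# Time branch, batch_first=False: the bins of both samples at frame 0.
print(attention_groups(2, 2, 3, "time", batch_first=False)[0])
# [(0, 0, 0), (0, 0, 1), (0, 0, 2), (1, 0, 0), (1, 0, 1), (1, 0, 2)]
```

Under these assumptions, with batch_first=False the time branch groups frequency bins (of every sample) and the frequency branch groups time frames (of every sample), which is consistent with the suggestion that the two branches swap roles.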