YangangCao / TRUNet

unofficial PyTorch implementation of 《REAL-TIME DENOISING AND DEREVERBERATION WITH TINY RECURRENT U-NET》

Input Feature to TRUNET #5

Open yugeshav opened 2 years ago

yugeshav commented 2 years ago

Hi

As per the paper, 4 features must be concatenated as input to TRUNet:

  1. log spectrum
  2. PCEN
  3. real part of demodulated phase
  4. imaginary part of demodulated phase

so the input becomes (batch_size, 4 features, no. of STFT frames, no. of STFT bins), i.e. a 4-dimensional tensor.

But in the sample code you show the input as 3-dimensional (1, 4, 257), since the first layer is a Conv1d.

I'm confused whether the input to TRUNet is 3-dimensional or 4-dimensional.

Regards, Yugesh

AmosCch commented 2 years ago

I think the input tensor in the sample code is a one-frame feature. If you want to feed a wav into the model, the input dimension might be (B, 4, frames, 257), but I'm not sure. Please email me (cch_amos@qq.com) if you have any insight.

amirpashamobinitehrani commented 1 year ago

@yugeshav Hey! Any progress on this? I am also confused with the input shape

atabakp commented 1 year ago

@amirpashamobinitehrani The input shape for the 1D conv is (T, C, F): (time frames, channels (the 4 features), frequency bins).

amirpashamobinitehrani commented 1 year ago

Thanks for your reply. Interesting! Yes, I had some presumptions. What still remains a mystery to me is how the batch dimension comes into play:

(Batch, Time frames, Channels (4 features), Frequency bins)

which I assume we should refrain from. Right? We are simply processing 4 different features of 1 audio file in time-frame steps, so the time-frame dimension fulfills the batch dimension's role.

atabakp commented 1 year ago

Correct! Each frame is a data sample here. If you want to use (Batch, Time, Features, Frequency) you should use 2D convolution and set the filters' dimension to (n, 1).
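
To make the shape discussion concrete, here is a minimal sketch of assembling such a frame-wise input; the feature tensors are random placeholders standing in for the real log-magnitude/PCEN/phase features:

    import torch

    time_frames, n_bins = 100, 257  # n_bins = n_fft // 2 + 1 for n_fft = 512

    # Placeholder features, each of shape (time_frames, n_bins)
    log_mag = torch.randn(time_frames, n_bins)
    pcen = torch.randn(time_frames, n_bins)
    demod_real = torch.randn(time_frames, n_bins)
    demod_imag = torch.randn(time_frames, n_bins)

    # Stack along dim=1 -> (time_frames, 4, n_bins). The time axis doubles as
    # the Conv1d batch axis, so every frame is processed independently.
    x = torch.stack([log_mag, pcen, demod_real, demod_imag], dim=1)
    print(x.shape)  # torch.Size([100, 4, 257])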

eagomez2 commented 1 year ago

Hi,

I had the same question. Has anyone been able to successfully train this network? I think that, as @atabakp mentioned, the input has to have shape (time_frames, features, fft_size // 2 + 1), so when a batch is used, the time_frames axis will grow. Since this is assumed to be the N input of an nn.Conv1d, the processing will still be frame-independent, so bigger batch sizes simply mean a bigger stack of frames. Could someone confirm this?

Thanks, Esteban

atabakp commented 1 year ago

Hi Esteban, I am able to train this model. Yes, you are right.

eagomez2 commented 1 year ago

Thanks @atabakp !

As a follow-up question: How are you obtaining the "demodulated phase"?

atabakp commented 1 year ago

There are a few methods to do this, but I don't know exactly what the authors mean. For example: https://arxiv.org/pdf/1608.01953.pdf

But for my training, I used the log magnitude and the normalized real/imag as inputs.

amirpashamobinitehrani commented 1 year ago

I managed to implement the demodulated phase, using (log_magnitude, demod_real, demod_imag) as inputs to train the model. For some reason, I am not seeing the model do anything useful. It would be nice to get some insights regarding the implementation if anyone has made promising progress on this!

eagomez2 commented 1 year ago

Thanks once again @atabakp! I was thinking something similar:

  1. Use log magnitude (as in the paper)
  2. Use PCEN output (as in the paper)

For 3 and 4, "real/imaginary of the demodulated phase" didn't make much sense to me as a term initially, since the phase would be real, so I was thinking of using the normalized real/imag STFT as well, since it would somehow put emphasis on the phase information.

One last question: how are you using the outputs, @atabakp? I think it has 5 channels initially, but there is no explicit mention of what they exactly are. I was assuming two of them are magnitude masks (target and residual), two others are phase terms, and the last one is used to estimate the phase's sign, but I was not sure.

atabakp commented 1 year ago

Yes, you are right. The output is 10 channels: 2 sets of 5 channels. One set is for predicting the direct signal, the 2nd set is for predicting the noise, and you can derive the reverberation from the direct and noise estimates. The 5 channels are:

  1. z(k)t,f
  2. z(¬k)t,f
  3. φ

For the next 2 channels, refer to eq. (3) of https://arxiv.org/pdf/2006.00687.pdf:

  4. γ(0)(qt,f)
  5. γ(1)(qt,f)

If my assumption about the channels is correct, then we don't need 2 separate channels for magnitude (channels 1, 2); one is the complement of the other.
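
In code, that split might look like the following sketch (assuming the final decoder layer has been widened to 10 output channels, which this repo's code does not do by default):

    out = model(x)                 # assumed shape: (frames, 10, 257)
    direct_params = out[:, :5, :]  # z, its complement, phi, q0, q1 for the direct signal
    noise_params = out[:, 5:, :]   # the same five parameters for the noise
    # the reverberation estimate can then be derived from direct and noise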

eagomez2 commented 1 year ago

Thanks a lot once again, @atabakp! I'll report back my progress as I manage to allocate time to work on it.

atabakp commented 1 year ago

Section 3 of this paper also has some information about phase demodulation: https://www.isca-speech.org/archive_v0/Interspeech_2018/pdfs/1773.pdf
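
One plausible reading of "demodulated phase" (an interpretation, not a confirmed reproduction of either paper) is to subtract each bin's expected carrier advance of 2*pi*k*H/N per hop before taking the cosine/sine:

    import torch

    def demodulated_phase_features(spec: torch.Tensor, hop_length: int, n_fft: int):
        # spec: complex STFT of shape (freq_bins, time_frames)
        freq_bins, time_frames = spec.shape
        phase = torch.angle(spec)
        k = torch.arange(freq_bins).unsqueeze(1)    # bin index
        t = torch.arange(time_frames).unsqueeze(0)  # frame index
        # Expected carrier phase of bin k after t hops
        carrier = 2.0 * torch.pi * k * hop_length * t / n_fft
        # Wrap the difference back into (-pi, pi]
        demod = torch.remainder(phase - carrier + torch.pi, 2.0 * torch.pi) - torch.pi
        # The "real" and "imaginary" parts of the demodulated phase
        return torch.cos(demod), torch.sin(demod)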

eagomez2 commented 1 year ago

Hi again @atabakp ,

How are you getting the 10 channels? I looked again into the model's code and I'm getting only 5 channels. Here I'm attaching the I/O of each layer:

module type input_shape output_shape
root TRUNet (1, 4, 257) (1, 5, 257)
down1 StandardConv1d (1, 4, 257) (1, 64, 128)
down1.StandardConv1d Sequential (1, 4, 257) (1, 64, 128)
down1.StandardConv1d.0 Conv1d (1, 4, 257) (1, 64, 128)
down1.StandardConv1d.1 ReLU (1, 64, 128) (1, 64, 128)
down2 DepthwiseSeparableConv1d (1, 64, 128) (1, 128, 128)
down2.DepthwiseSeparableConv1d Sequential (1, 64, 128) (1, 128, 128)
down2.DepthwiseSeparableConv1d.0 Conv1d (1, 64, 128) (1, 128, 128)
down2.DepthwiseSeparableConv1d.1 BatchNorm1d (1, 128, 128) (1, 128, 128)
down2.DepthwiseSeparableConv1d.2 ReLU (1, 128, 128) (1, 128, 128)
down2.DepthwiseSeparableConv1d.3 Conv1d (1, 128, 128) (1, 128, 128)
down2.DepthwiseSeparableConv1d.4 BatchNorm1d (1, 128, 128) (1, 128, 128)
down2.DepthwiseSeparableConv1d.5 ReLU (1, 128, 128) (1, 128, 128)
down3 DepthwiseSeparableConv1d (1, 128, 128) (1, 128, 64)
down3.DepthwiseSeparableConv1d Sequential (1, 128, 128) (1, 128, 64)
down3.DepthwiseSeparableConv1d.0 Conv1d (1, 128, 128) (1, 128, 128)
down3.DepthwiseSeparableConv1d.1 BatchNorm1d (1, 128, 128) (1, 128, 128)
down3.DepthwiseSeparableConv1d.2 ReLU (1, 128, 128) (1, 128, 128)
down3.DepthwiseSeparableConv1d.3 Conv1d (1, 128, 128) (1, 128, 64)
down3.DepthwiseSeparableConv1d.4 BatchNorm1d (1, 128, 64) (1, 128, 64)
down3.DepthwiseSeparableConv1d.5 ReLU (1, 128, 64) (1, 128, 64)
down4 DepthwiseSeparableConv1d (1, 128, 64) (1, 128, 64)
down4.DepthwiseSeparableConv1d Sequential (1, 128, 64) (1, 128, 64)
down4.DepthwiseSeparableConv1d.0 Conv1d (1, 128, 64) (1, 128, 64)
down4.DepthwiseSeparableConv1d.1 BatchNorm1d (1, 128, 64) (1, 128, 64)
down4.DepthwiseSeparableConv1d.2 ReLU (1, 128, 64) (1, 128, 64)
down4.DepthwiseSeparableConv1d.3 Conv1d (1, 128, 64) (1, 128, 64)
down4.DepthwiseSeparableConv1d.4 BatchNorm1d (1, 128, 64) (1, 128, 64)
down4.DepthwiseSeparableConv1d.5 ReLU (1, 128, 64) (1, 128, 64)
down5 DepthwiseSeparableConv1d (1, 128, 64) (1, 128, 32)
down5.DepthwiseSeparableConv1d Sequential (1, 128, 64) (1, 128, 32)
down5.DepthwiseSeparableConv1d.0 Conv1d (1, 128, 64) (1, 128, 64)
down5.DepthwiseSeparableConv1d.1 BatchNorm1d (1, 128, 64) (1, 128, 64)
down5.DepthwiseSeparableConv1d.2 ReLU (1, 128, 64) (1, 128, 64)
down5.DepthwiseSeparableConv1d.3 Conv1d (1, 128, 64) (1, 128, 32)
down5.DepthwiseSeparableConv1d.4 BatchNorm1d (1, 128, 32) (1, 128, 32)
down5.DepthwiseSeparableConv1d.5 ReLU (1, 128, 32) (1, 128, 32)
down6 DepthwiseSeparableConv1d (1, 128, 32) (1, 128, 16)
down6.DepthwiseSeparableConv1d Sequential (1, 128, 32) (1, 128, 16)
down6.DepthwiseSeparableConv1d.0 Conv1d (1, 128, 32) (1, 128, 32)
down6.DepthwiseSeparableConv1d.1 BatchNorm1d (1, 128, 32) (1, 128, 32)
down6.DepthwiseSeparableConv1d.2 ReLU (1, 128, 32) (1, 128, 32)
down6.DepthwiseSeparableConv1d.3 Conv1d (1, 128, 32) (1, 128, 16)
down6.DepthwiseSeparableConv1d.4 BatchNorm1d (1, 128, 16) (1, 128, 16)
down6.DepthwiseSeparableConv1d.5 ReLU (1, 128, 16) (1, 128, 16)
FGRU GRUBlock (1, 16, 128) (1, 64, 16)
FGRU.GRU GRU (1, 16, 128) ((1, 16, 128), (2, 1, 64))
FGRU.conv Sequential (1, 128, 16) (1, 64, 16)
FGRU.conv.0 Conv1d (1, 128, 16) (1, 64, 16)
FGRU.conv.1 BatchNorm1d (1, 64, 16) (1, 64, 16)
FGRU.conv.2 ReLU (1, 64, 16) (1, 64, 16)
TGRU GRUBlock (1, 16, 64) (1, 64, 16)
TGRU.GRU GRU (1, 16, 64) ((1, 16, 128), (1, 1, 128))
TGRU.conv Sequential (1, 128, 16) (1, 64, 16)
TGRU.conv.0 Conv1d (1, 128, 16) (1, 64, 16)
TGRU.conv.1 BatchNorm1d (1, 64, 16) (1, 64, 16)
TGRU.conv.2 ReLU (1, 64, 16) (1, 64, 16)
up1 FirstTrCNN (1, 64, 16) (1, 64, 31)
up1.FirstTrCNN Sequential (1, 64, 16) (1, 64, 31)
up1.FirstTrCNN.0 Conv1d (1, 64, 16) (1, 64, 16)
up1.FirstTrCNN.1 BatchNorm1d (1, 64, 16) (1, 64, 16)
up1.FirstTrCNN.2 ReLU (1, 64, 16) (1, 64, 16)
up1.FirstTrCNN.3 ConvTranspose1d (1, 64, 16) (1, 64, 31)
up1.FirstTrCNN.4 BatchNorm1d (1, 64, 31) (1, 64, 31)
up1.FirstTrCNN.5 ReLU (1, 64, 31) (1, 64, 31)
up2 TrCNN (1, 64, 31) (1, 64, 65)
up2.TrCNN Sequential (1, 192, 32) (1, 64, 65)
up2.TrCNN.0 Conv1d (1, 192, 32) (1, 64, 32)
up2.TrCNN.1 BatchNorm1d (1, 64, 32) (1, 64, 32)
up2.TrCNN.2 ReLU (1, 64, 32) (1, 64, 32)
up2.TrCNN.3 ConvTranspose1d (1, 64, 32) (1, 64, 65)
up2.TrCNN.4 BatchNorm1d (1, 64, 65) (1, 64, 65)
up2.TrCNN.5 ReLU (1, 64, 65) (1, 64, 65)
up3 TrCNN (1, 64, 65) (1, 64, 66)
up3.TrCNN Sequential (1, 192, 64) (1, 64, 66)
up3.TrCNN.0 Conv1d (1, 192, 64) (1, 64, 64)
up3.TrCNN.1 BatchNorm1d (1, 64, 64) (1, 64, 64)
up3.TrCNN.2 ReLU (1, 64, 64) (1, 64, 64)
up3.TrCNN.3 ConvTranspose1d (1, 64, 64) (1, 64, 66)
up3.TrCNN.4 BatchNorm1d (1, 64, 66) (1, 64, 66)
up3.TrCNN.5 ReLU (1, 64, 66) (1, 64, 66)
up4 TrCNN (1, 64, 66) (1, 64, 129)
up4.TrCNN Sequential (1, 192, 64) (1, 64, 129)
up4.TrCNN.0 Conv1d (1, 192, 64) (1, 64, 64)
up4.TrCNN.1 BatchNorm1d (1, 64, 64) (1, 64, 64)
up4.TrCNN.2 ReLU (1, 64, 64) (1, 64, 64)
up4.TrCNN.3 ConvTranspose1d (1, 64, 64) (1, 64, 129)
up4.TrCNN.4 BatchNorm1d (1, 64, 129) (1, 64, 129)
up4.TrCNN.5 ReLU (1, 64, 129) (1, 64, 129)
up5 TrCNN (1, 64, 129) (1, 64, 130)
up5.TrCNN Sequential (1, 192, 128) (1, 64, 130)
up5.TrCNN.0 Conv1d (1, 192, 128) (1, 64, 128)
up5.TrCNN.1 BatchNorm1d (1, 64, 128) (1, 64, 128)
up5.TrCNN.2 ReLU (1, 64, 128) (1, 64, 128)
up5.TrCNN.3 ConvTranspose1d (1, 64, 128) (1, 64, 130)
up5.TrCNN.4 BatchNorm1d (1, 64, 130) (1, 64, 130)
up5.TrCNN.5 ReLU (1, 64, 130) (1, 64, 130)
up6 LastTrCNN (1, 64, 130) (1, 5, 257)
up6.LastTrCNN Sequential (1, 128, 128) (1, 5, 257)
up6.LastTrCNN.0 Conv1d (1, 128, 128) (1, 5, 128)
up6.LastTrCNN.1 BatchNorm1d (1, 5, 128) (1, 5, 128)
up6.LastTrCNN.2 ReLU (1, 5, 128) (1, 5, 128)
up6.LastTrCNN.3 ConvTranspose1d (1, 5, 128) (1, 5, 257)

eagomez2 commented 1 year ago

I also have a question about the TGRU along the same lines. According to the paper:

> The decoder is composed of a Time-axis Gated Recurrent Unit (TGRU) block and 1D Transposed Convolutional Neural Network (1D-TrCNN) blocks. The output of the encoder is passed into a unidirectional GRU layer to aggregate the information along the time axis

But then, the input to this layer is (1, 16, 64), and according to PyTorch's GRU documentation, when batch_first=True the 2nd dimension is the sequence length, which is the case here because batch_first is set to True when the TGRU layer is defined: https://github.com/YangangCao/TRUNet/blob/main/TRUNet.py#LL131C26-L131C26

To my understanding (please correct me if I'm wrong), the TGRU layer will not really aggregate information along the time axis, but will instead play a similar role to the FGRU, only with a unidirectional layer. I initially assumed that batch_first should be set to False in order to apply the nn.GRU along the first dimension, which is the original time dimension.
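
A self-contained illustration of this point, using toy shapes rather than the repo's exact modules:

    import torch
    import torch.nn as nn

    frames = 10
    x = torch.randn(frames, 16, 64)  # (time frames, downsampled freq axis, channels)

    gru = nn.GRU(input_size=64, hidden_size=64, batch_first=True)
    y, _ = gru(x)  # recurrence runs along dim 1 (length 16); frames stay independent

    gru_time = nn.GRU(input_size=64, hidden_size=64, batch_first=False)
    y_time, _ = gru_time(x)  # recurrence now runs along dim 0, the actual time axis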

eagomez2 commented 1 year ago

I'll answer this one myself. The paper's config listing for the decoder says:

DecoderConfig = {1-th: (3,2,64), 2-th: (5,2,64), 3-th: (3,1,64), 4-th: (5,2,64), 5-th: (3,1,64), 6-th: (5,2,10)}

where the last number is the number of channels; therefore you're right, it should be 10 instead.

atabakp commented 1 year ago

Regarding the TGRU/batch_first question, see: https://github.com/YangangCao/TRUNet/issues/4#issuecomment-1182544756

eagomez2 commented 1 year ago

Hi @atabakp ,

I'm not sure if my interpretation of the outputs is correct, but I'm trying to follow the paper, and even when the model trains, it may become unstable after some epochs. I believe the cos_phase is causing this, because due to the law of cosines I sometimes get values marginally outside the expected range. How are you dealing with this, and how are you obtaining the corresponding sin_phase? I believe I'm missing something somewhere. I already tried clamping values that could potentially explode, with no luck.

    # Minimal imports for this snippet
    import torch
    import torch.nn.functional as F

    # Control random seed
    rand_seed = torch.manual_seed(0)

    # Let's assume the output has shape (1, 5, 257) (the expected output for a
    # single source). Since the activation function is ReLU, values can be
    # equal to or greater than 0
    x_features = torch.rand((1, 5, 257), dtype=torch.float32)

    # Extract z_tf for target and residual
    z_tf = x_features[:, 0:1, :]
    z_tf_residual = x_features[:, 1:2, :]

    # Extract phi
    phi = x_features[:, 2:3, :]

    # Estimate beta (due to softplus it will be one or greater)
    beta = 1.0 + F.softplus(phi)

    # Estimate sigmoid of target and residual
    sigmoid_tf = torch.sigmoid(z_tf - z_tf_residual)
    sigmoid_tf_residual = torch.sigmoid(z_tf_residual - z_tf)

    # Estimate upper bound for beta
    beta_upper_bound = 1.0 / torch.abs(sigmoid_tf - sigmoid_tf_residual)

    # Because of the absolute value in the denominator, the same upper bound
    # can be applied to both betas
    beta = torch.clip(beta, max=beta_upper_bound)

    # Compute both target and residual masks using eq. (1)
    mask_tf = beta * sigmoid_tf
    mask_tf_residual = beta * sigmoid_tf_residual

    # Now that we have both masks, compute the law-of-cosines term
    cos_phase = (
        (1.0 + mask_tf.square() - mask_tf_residual.square())
        / (2.0 * mask_tf))

    # Use the trigonometric identity to obtain the sine
    sin_phase = torch.sqrt(1.0 - cos_phase.square())

    # Now estimate the sign
    q0 = x_features[:, 3:4, :]
    q1 = x_features[:, 4:5, :]
    gamma0 = F.gumbel_softmax(q0, tau=1.0)
    gamma1 = F.gumbel_softmax(q1, tau=1.0)
    sign = torch.where(gamma0 > gamma1, -1.0, 1.0)

    # Finally, estimate the complex mask
    complex_mask = mask_tf * (cos_phase + sign * 1j * sin_phase)

    # It should then be applied to the STFT and inverted using the ISTFT
    ...
atabakp commented 1 year ago

1. I didn't apply ReLU to the last layer (x_features).

2. sigmoid_tf_residual = 1 - sigmoid_tf is always true, so I think we don't need two outputs for this, but based on the paper, your code is correct.

3. You should always guard divisions, logs, square roots, etc. in your code; for example, in cos_phase there is a chance that mask_tf becomes zero.

4. For obtaining the sine, I am using acos:

    torch.sin(torch.acos(torch.clamp(cos_phase, min=-1 + eps, max=1 - eps)))

5. Estimating the sign is a bit confusing, and I believe there is a typo in the formula in the paper. (I believe the sign does not matter much for performance.) This is how I implemented it:

    gamma = torch.nn.functional.gumbel_softmax(
        torch.stack([q0, q1], dim=-1),
        tau=0.5,
        hard=False,
    )
    gamma_0 = gamma[..., 0]
    gamma_1 = gamma[..., 1]
    sign = torch.where(gamma_0 > gamma_1, -1.0, 1.0)
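
Putting points 3 and 4 together with the earlier snippet's names, one way to guard the division and the acos domain (the eps value and its placement are assumptions):

    eps = 1e-8

    # Guard the law-of-cosines division (point 3)
    cos_phase = (
        (1.0 + mask_tf.square() - mask_tf_residual.square())
        / (2.0 * mask_tf + eps))

    # Clamp into acos's domain before recovering the sine (point 4)
    sin_phase = torch.sin(
        torch.acos(torch.clamp(cos_phase, min=-1.0 + eps, max=1.0 - eps)))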

eagomez2 commented 1 year ago

Thank you very much, @atabakp!

eagomez2 commented 1 year ago

Hi again @atabakp ,

When training the model, are you using 2 s of audio as the paper states, or are you using gradient accumulation or something like that to pass more data between steps?

I'm currently trying to train the model for dereverberation only, but at 2 s per audio clip in all cases it is very slow to train. So far I haven't reached the point of evaluating how successful the model is at the task, but it doesn't seem to be learning quickly.

atabakp commented 1 year ago

I am using random-length sequences, a single sequence per iteration (batch size = 1).

JBloodless commented 7 months ago

Sorry for necroposting here, but I'm trying to train this model, with no luck yet. I managed to add trainable PCEN (as described in the paper) and to train on spectrograms. I construct the input feature from the PCEN (output of the trainable layer), the log magnitude, and the real and imaginary parts of the STFT, and feed it to the rest of the model described here. I also implemented 2D convs since I wanted to train on batches. The losses are the same as in the paper: multi-resolution cosine similarity + multi-resolution spectrum MSE. The model trains very weirdly (the loss decreases for the first couple hundred steps, then increases, then decreases again). @atabakp @eagomez2 I assume you managed to train this model; can you say what you changed compared to the paper? Or maybe share your pipeline. Thanks in advance.
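
For reference, one common parameterization of trainable PCEN (following Wang et al., 2017); the initial values and the simple per-frame IIR smoother here are assumptions, not necessarily what was used above:

    import torch
    import torch.nn as nn

    class TrainablePCEN(nn.Module):
        def __init__(self, n_bins: int, s: float = 0.025, eps: float = 1e-6):
            super().__init__()
            self.s = s
            self.eps = eps
            # Learn alpha, delta, r per frequency bin, kept positive via exp
            self.log_alpha = nn.Parameter(torch.zeros(n_bins))
            self.log_delta = nn.Parameter(torch.zeros(n_bins))
            self.log_r = nn.Parameter(torch.zeros(n_bins))

        def forward(self, E: torch.Tensor) -> torch.Tensor:
            # E: magnitude spectrogram of shape (time_frames, n_bins)
            alpha = self.log_alpha.exp()
            delta = self.log_delta.exp()
            r = self.log_r.exp()
            # First-order IIR smoothing of E along the time axis
            m, M = E[0], []
            for e in E:
                m = (1.0 - self.s) * m + self.s * e
                M.append(m)
            M = torch.stack(M)
            return (E / (self.eps + M) ** alpha + delta) ** r - delta ** r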

eagomez2 commented 7 months ago

Hi @JBloodless ,

In my case I am using Conv1d, but I decided to swap the original losses for a GAN setup, since they weren't quite working for me (the model was converging, but not with the expected quality, and it sometimes exploded after that).

JBloodless commented 7 months ago

Can you elaborate a bit on what you mean by a GAN loss?

eagomez2 commented 7 months ago

Check section 2 of this paper: https://arxiv.org/pdf/2010.10677.pdf

JBloodless commented 6 months ago

@eagomez2 I'm assuming that you use your code above to calculate the complex mask and then multiply it with the complex input to produce the output waveform; is this correct?

eagomez2 commented 6 months ago

Yes. You will need to repeat the process to obtain both the direct speech waveform and the residual waveform.

JBloodless commented 6 months ago

What do you mean by repeating? I thought that the network (in this implementation) returns one set of features for the PHM (time, 5, bins), and the corresponding PHM will be the mask for the direct source. Since I need to obtain only the direct source (clean speech), I just multiply this PHM with the input spectrum and get the clean output. What did I assume wrong?

JBloodless commented 6 months ago

Reading the conversation above, I assume that you changed the output layer to 10 channels (as in the paper). How should I apply the PHM function then? Channels 1-5 will be the mask for direct speech, and 6-10 the residual?

eagomez2 commented 6 months ago

Yes, that's correct

JBloodless commented 6 months ago

I think I got it wrong. It seems that your function technically calculates a pair of masks:

    # Compute both target and residual masks using eq. (1)
    mask_tf = beta * sigmoid_tf
    mask_tf_residual = beta * sigmoid_tf_residual

and features 6-10 are for reverberant and noise separation. This means that from the first 5 features we can calculate only the direct source mask (which is exactly what your function is doing) to perform dereverberation. Which leads to even more frustration for me, since I'm doing everything "right", but the model doesn't train at all.

eagomez2 commented 6 months ago

Are you using the paper's losses or are you trying different ones? Is the model training but with the loss values all over the place, or is it exploding?

If I remember correctly (this was months ago), I double- and triple-checked that every function that could potentially explode, through things like dividing by zero, had an epsilon (eps) value or similar to prevent such issues before I managed to get meaningful results.

JBloodless commented 6 months ago

It's not exploding; the loss is just stable around some value. Yes, I'm trying to use the paper's loss (since this loss looks right to me and I don't see why it shouldn't work).

[Screenshot 2024-02-13 at 13:11:58: loss curves]

The purple one is the latest try (with the 10-channel output).

I already fixed a bunch of NaNs, so I think that zero handling is not the problem here. Maybe I'm interpreting the outputs wrong; I'll try to log what's going on.

JBloodless commented 6 months ago

@eagomez2 maybe you can help me double-check the output interpretation.

x16 = self.up6(x15, x1) # x16 is the output of this implementation, except self.up6 has 10 channels instead of 5
mask_direct = calculate_PHM(x16[:, :5, :])    # calculate_PHM is your function from above comments (with fixes from @atabakp), which returns complex_mask 
result = x * mask_direct # x is the input -  complex spectrogram
out_wave = torch.istft(result,
                               n_fft=self.nfft,
                               hop_length=self.hop,
                               onesided=True,
                               window=self.window.to(wave.device),
                               center=True)

Is this the same-ish as yours? My main concern for now is mask_direct = calculate_PHM(x16[:, :5, :])

atabakp commented 6 months ago

To verify the correct operation of the code, skip the calculate_PHM function and multiply one channel of x16 with x.

JBloodless commented 6 months ago

By "correct" I meant that it produces the expected output (clean speech). This function is "correct" in terms of data and shape, if that's what you mean. If I'm not mistaken, one channel of x16 is just one feature needed to calculate the mask, and multiplying the input with this feature won't produce a clean spectrum.

JBloodless commented 6 months ago

@atabakp @eagomez2 I found out that my losses were completely wrong, so I'd like to ask you about the outputs of this model. From the paper it's not at all obvious in which order the masks come in the second set (channels 6-10). In fig. 2 of the paper it's non-noise then noise, but at the end of section 3.2 it's noise then non-noise. Did you figure it out?

atabakp commented 6 months ago

A single channel can also produce the clean signal; you only need to multiply the output mask (bounded to (0, 1) with a sigmoid) with the magnitude of the noisy signal (x) and use the noisy phase (x.angle()) to construct out_wave.
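
As a sketch of that sanity check, reusing the names from the snippet above (the choice of channel 0 and the exact shapes are assumptions):

    mask = torch.sigmoid(x16[:, 0, :])               # one channel, bounded to (0, 1)
    mag_est = torch.view_as_complex(x).abs() * mask  # mask the noisy magnitude
    result = mag_est * torch.exp(1j * torch.view_as_complex(x).angle())  # noisy phase
    # then invert with torch.istft as before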

atabakp commented 6 months ago

The order doesn't matter; just pick one. The network will adapt to correctly assign the corresponding output, irrespective of the initial order. My suggestion is to skip the PHM for now and make sure the rest of the code is OK.

JBloodless commented 6 months ago

[Screenshot 2024-02-15 at 16:39:08]

Seems fine to me. Maybe the problem really is the PHM computation.

For now I have settled on:

mask_direct = calculate_PHM(x16[:, :5, :])
result_direct = torch.view_as_complex(x) * mask_direct.squeeze(1)

mask_nonnoise = calculate_PHM(x16[:, 5:, :])  
result_nonnoise = torch.view_as_complex(x) * mask_nonnoise.squeeze(1)
result_noise = torch.view_as_complex(x) - result_nonnoise

mask_revpath = mask_nonnoise - mask_direct
result_revpath = torch.view_as_complex(x) * mask_revpath.squeeze(1)

Since you mentioned that the order doesn't matter, I assumed that in the second pair the non-noise comes first, so I'm directly calculating the masks for the direct path and the non-noise signal, and then obtaining the reverberation mask as in fig. 2 of the paper.

eagomez2 commented 6 months ago

So is it working now, @JBloodless?

I double-checked the code I used and it is very similar to my initial post. The only changes I see are that I got rid of the randomness of the gumbel softmax and simply replaced it with a softmax with temperature, and that I added some eps to stabilize the cos_phase term, and it worked. Also, sigmoid_tf_residual can be simplified to 1.0 - sigmoid_tf_target.

All in all, I'm inclined to think that even though the sign-prediction math in the paper makes sense, in practice it is not that crucial for the network's performance.
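
For what it's worth, the deterministic replacement looks roughly like this sketch, a temperature softmax over the two sign logits (the tau value is an assumption):

    tau = 0.5
    gamma = torch.softmax(torch.stack([q0, q1], dim=-1) / tau, dim=-1)
    sign = torch.where(gamma[..., 0] > gamma[..., 1], -1.0, 1.0)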

atabakp commented 6 months ago

I totally agree. Even the PHM is not very crucial; the network can directly output the clean speech mask.

eagomez2 commented 6 months ago

@atabakp thanks for bringing this up. Have you tried training without the PHM? I was also curious about doing this, but I haven't found the time so far.

JBloodless commented 6 months ago

> softmax with temperature

Nope, still doesn't work. The only thing that "worked" is skipping the PHM and multiplying one channel of the last output with the input, but I haven't waited for it to converge yet. I'll try these fixes, thanks.

One more question for @atabakp: I mentioned that my losses were wrong. By that I meant that in the paper the loss is the sum of the losses for the direct source, the noise, and the reverberant path (last equation of section 3.3). How did you calculate them; did you take this sum of 3 or something else? I don't see a good way to compute the target reverberant path, so I just subtract the clean signal from the reverbed signal, and I use a tensor of 1e-6 as the noise target (since I only train for dereverberation). Also for @eagomez2: I gather that you used a different loss; did you calculate it with only the direct source?
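
For reference, one reading of "multi-resolution cosine similarity + multi-resolution spectrum MSE" as a single sketch (the FFT sizes, hops, and equal weighting are assumptions, not the paper's verified setup):

    import torch
    import torch.nn.functional as F

    def multires_loss(est: torch.Tensor, ref: torch.Tensor,
                      fft_sizes=(256, 512, 1024)) -> torch.Tensor:
        # est, ref: time-domain waveforms of shape (samples,)
        loss = 1.0 - F.cosine_similarity(est, ref, dim=-1)  # waveform term
        for n_fft in fft_sizes:
            win = torch.hann_window(n_fft)
            E = torch.stft(est, n_fft, hop_length=n_fft // 4,
                           window=win, return_complex=True)
            R = torch.stft(ref, n_fft, hop_length=n_fft // 4,
                           window=win, return_complex=True)
            loss = loss + F.mse_loss(E.abs(), R.abs())      # spectral MSE term
        return loss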

atabakp commented 6 months ago

I tried different variations, but I found that using the loss only on the direct signal is good enough.

atabakp commented 6 months ago

Yes, I tried; I ended up using a single channel for masking the magnitude.

JBloodless commented 6 months ago

I've managed to train the model without the PHM and with a single loss on the direct signal (same as the paper: multi-resolution cosine similarity + multi-resolution spectrum MSE). The model converges:

[Screenshot 2024-02-19 at 17:18:27: loss curve]

but the result is strange:

[Screenshot 2024-02-19 at 17:19:06: output spectrogram]

Maybe it's because of the PCEN feature (my implementation may not be ideal), but the voice harmonics in the spectrum seem to be "dereverbed", so I'll try to locate the source of the noisiness.

JBloodless commented 6 months ago

@atabakp did you extract the magnitude of the input spectrum as torch.abs() and not as torch.real?
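
For reference, the distinction being asked about (spec is assumed to be a complex STFT tensor):

    mag = spec.abs()       # sqrt(real**2 + imag**2): the actual magnitude
    real_part = spec.real  # just the real component; can be negative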