yugeshav opened this issue 2 years ago
I think the input tensor in the sample code is a single-frame feature. If you want to feed a wav into the model, the input dimension might be (B, 4, frames, 257), but I'm not sure. Please email me (cch_amos@qq.com) if you have any insight.
@yugeshav Hey! Any progress on this? I am also confused by the input shape.
@amirpashamobinitehrani The input shape for the 1D conv is (T, C, F): (time frames, channels (4 features), frequency bins).
Thanks for your reply. Interesting! Yes, I had some presumptions. What still remains a mystery to me is how to bring a batch dimension into play:
(Batch, Time frames, Channels (4 features), Frequency bins)
Which I assume we should refrain from, right? We are simply processing 4 different features of 1 audio file in time-frame steps, so the time-frame dimension is fulfilling the batch dimension's role.
Correct! Each frame is a data sample here. If you want to use (Batch, Time, Features, Frequency), you should use a 2D convolution and set the filters' dimension to (n, 1).
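A minimal PyTorch sketch of the two layouts follows. The channel counts, kernel size and padding are illustrative assumptions. A `Conv1d` over frames folded into the batch axis, (B·T, C, F), computes the same thing as a `Conv2d` over (B, C, T, F) whose kernel spans only frequency. With this axis order the kernel is (1, n). It becomes (n, 1) if frequency is placed on the first spatial axis.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

B, T, C, F = 2, 10, 4, 257   # batch, frames, feature channels, frequency bins
n = 3                        # illustrative kernel size along frequency

conv1d = nn.Conv1d(C, 8, kernel_size=n, padding=1)
conv2d = nn.Conv2d(C, 8, kernel_size=(1, n), padding=(0, 1))

# Share weights so both layers compute the same function.
with torch.no_grad():
    conv2d.weight.copy_(conv1d.weight.unsqueeze(2))  # (8, C, n) -> (8, C, 1, n)
    conv2d.bias.copy_(conv1d.bias)

x = torch.randn(B, T, C, F)  # (Batch, Time, Features, Frequency)

# 1D route: fold batch and time together so every frame is one sample.
y1 = conv1d(x.reshape(B * T, C, F)).reshape(B, T, 8, F)

# 2D route: move features to the channel axis -> (B, C, T, F), kernel (1, n).
y2 = conv2d(x.permute(0, 2, 1, 3)).permute(0, 2, 1, 3)

print(y1.shape)                            # torch.Size([2, 10, 8, 257])
print(torch.allclose(y1, y2, atol=1e-4))   # True
```

Because the kernel never spans more than one frame, both routes stay frame-independent.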
Hi,
I had the same question. Has anyone been able to successfully train this network? I think that, as @atabakp mentioned, the input has to have shape `(time_frames, features, fft_size // 2 + 1)`, so when a batch is used, the `time_frames` axis will grow. Since this axis is assumed to be the `N` input of an `nn.Conv1d`, the processing will still be frame-independent, so bigger batch sizes would just mean a bigger stack of frames. Could someone confirm this?
Thanks, Esteban
Hi Esteban, I am able to train this model. Yes, you are right.
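Esteban's reading can be illustrated with a short sketch. Frames from several utterances of different lengths are concatenated along the first axis, passed through a frame-wise `Conv1d`, and then split back. The layer sizes are illustrative assumptions. This only holds for frame-independent layers. A recurrent layer running over time would need the utterances kept apart.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Illustrative frame-wise layer (sizes are assumptions, not the repository's).
conv = nn.Conv1d(4, 64, kernel_size=5, stride=2, padding=2)

# Three utterances with different numbers of frames, each shaped (T_i, 4, 257).
utterances = [torch.randn(t, 4, 257) for t in (120, 87, 200)]
lengths = [u.shape[0] for u in utterances]

stacked = torch.cat(utterances, dim=0)     # (407, 4, 257): a taller stack of frames
out = conv(stacked)                        # (407, 64, 129)

# Split back per utterance; each frame was processed independently.
per_utt = torch.split(out, lengths, dim=0)
print([tuple(o.shape) for o in per_utt])   # [(120, 64, 129), (87, 64, 129), (200, 64, 129)]
```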
Thanks @atabakp !
As a follow-up question: How are you obtaining the "demodulated phase"?
There are a few methods to do this, but I don't know exactly what the authors mean; see, for example, https://arxiv.org/pdf/1608.01953.pdf
But for my training, I used Log Magnitude and normalized real/imag as inputs.
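Here is a minimal sketch of the inputs atabakp describes. It assumes a 512-point STFT (257 bins) with a 128-sample hop and a Hann window, and these parameters and the `eps` guard are assumptions. Dividing the real and imaginary parts by the magnitude leaves the cosine and sine of the phase, so those two channels carry only phase information. A PCEN channel would be stacked alongside to reach 4 input features.

```python
import torch

def logmag_and_normalized_ri(wave, n_fft=512, hop=128, eps=1e-8):
    """Return (T, 3, F) features: log magnitude, real/|X| and imag/|X|."""
    spec = torch.stft(wave, n_fft=n_fft, hop_length=hop,
                      window=torch.hann_window(n_fft), return_complex=True)  # (F, T)
    mag = spec.abs()
    log_mag = torch.log(mag + eps)
    norm_real = spec.real / (mag + eps)    # approximately cos(phase)
    norm_imag = spec.imag / (mag + eps)    # approximately sin(phase)
    feats = torch.stack([log_mag, norm_real, norm_imag], dim=0)  # (3, F, T)
    return feats.permute(2, 0, 1)          # (T, 3, F): frames on the Conv1d batch axis

torch.manual_seed(0)
wave = torch.randn(16000)                  # 1 s of noise at an assumed 16 kHz
feats = logmag_and_normalized_ri(wave)
print(feats.shape)                         # torch.Size([126, 3, 257])
```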
I managed to implement the demodulated phase, using (log_magnitude, demod_real, demod_imag) as inputs to train the model. For some reason, I am not seeing the model do anything useful. It would be nice to get some insights into implementations if anyone has made promising progress on this!
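The thread does not pin down what "demodulated phase" means. One common interpretation is assumed here and is not confirmed by the paper. For each bin k, it removes the phase advance of 2π·k·hop/N per frame that a stationary sinusoid at the bin's centre frequency would accumulate, then takes the cosine and sine of what remains (`demod_real`, `demod_imag`). The STFT settings are assumptions.

```python
import math
import torch

def demodulated_phase_features(wave, n_fft=512, hop=128):
    """Return (T, 2, F): cos and sin of the phase after removing each bin's
    linear phase advance of 2*pi*k*hop/n_fft per frame (one assumed definition)."""
    spec = torch.stft(wave, n_fft=n_fft, hop_length=hop,
                      window=torch.hann_window(n_fft), return_complex=True)  # (F, T)
    n_bins, n_frames = spec.shape
    k = torch.arange(n_bins).unsqueeze(1)          # bin index, (F, 1)
    t = torch.arange(n_frames).unsqueeze(0)        # frame index, (1, T)
    # Reduce k*t*hop modulo n_fft in integers so the advance stays exact mod 2*pi.
    advance = 2 * math.pi * ((k * t * hop) % n_fft) / n_fft
    phase = torch.angle(spec) - advance
    feats = torch.stack([torch.cos(phase), torch.sin(phase)], dim=0)  # (2, F, T)
    return feats.permute(2, 0, 1)

# A sinusoid exactly at bin 33: its raw phase rotates by pi/2 per frame,
# but after demodulation it is constant (cos close to -1) on interior frames.
n = torch.arange(16000, dtype=torch.float64)
wave = torch.cos(2 * math.pi * 33 * n / 512).float()
feats = demodulated_phase_features(wave)
print(feats[4:-4, 0, 33].min(), feats[4:-4, 0, 33].max())
```

The first and last few frames are excluded from the check because `torch.stft` reflect-pads the signal edges.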
Thanks once again @atabakp! I was thinking something similar:
- Use log magnitude (as in the paper)
- Use PCEN output (as in the paper)

For features 3 and 4, "real/imaginary of the demodulated phase" initially didn't make much sense to me as a term, since the phase would be real. So I was thinking of also using the normalized real/imag STFT, since it would somehow put emphasis on the phase information.
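The PCEN feature in the list above can be sketched as a trainable layer. This is a minimal sketch assuming the standard PCEN definition, with per-bin parameters learned in log space so they stay positive. The initial values are commonly used PCEN defaults and are not confirmed by the TRUNet paper. `TrainablePCEN` is a hypothetical name.

```python
import torch
import torch.nn as nn

class TrainablePCEN(nn.Module):
    """Per-channel energy normalization for a non-negative spectrogram shaped (T, F).

    M[t] = (1 - s) * M[t-1] + s * E[t]
    PCEN[t] = (E[t] / (eps + M[t]) ** alpha + delta) ** r - delta ** r
    s, alpha, delta and r are learned per frequency bin, stored in log space.
    """

    def __init__(self, n_bins, s=0.025, alpha=0.98, delta=2.0, r=0.5, eps=1e-6):
        super().__init__()
        self.log_s = nn.Parameter(torch.full((n_bins,), s).log())
        self.log_alpha = nn.Parameter(torch.full((n_bins,), alpha).log())
        self.log_delta = nn.Parameter(torch.full((n_bins,), delta).log())
        self.log_r = nn.Parameter(torch.full((n_bins,), r).log())
        self.eps = eps

    def forward(self, energy):
        s = self.log_s.exp().clamp(max=1.0)
        alpha = self.log_alpha.exp()
        delta = self.log_delta.exp()
        r = self.log_r.exp()
        m = energy[0]
        smoothed = []
        for frame in energy:                    # causal first-order IIR smoother over time
            m = (1.0 - s) * m + s * frame
            smoothed.append(m)
        m = torch.stack(smoothed, dim=0)        # (T, F)
        return (energy / (self.eps + m) ** alpha + delta) ** r - delta ** r

torch.manual_seed(0)
pcen = TrainablePCEN(n_bins=257)
mag = torch.rand(100, 257)                      # stand-in magnitude spectrogram (T, F)
out = pcen(mag)
out.sum().backward()
print(out.shape)                                # torch.Size([100, 257])
```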
One last question: how are you using the outputs, @atabakp? I think there are 5 channels initially, but there is no explicit mention of what exactly they are. I assumed two of them are magnitude masks (target and residual), two others are phase terms, and the last one is used to estimate the phase's sign, but I was not sure.
Yes, you are right. The output has 10 channels, in 2 sets of 5. One set predicts the direct signal and the second predicts the noise. The reverberation can then be derived from the direct and noise estimates. The channels of each set are:
1- $z^{(k)}_{t,f}$
2- $z^{(\neg k)}_{t,f}$
3- $\varphi$
4- $\gamma^{(0)}(q_{t,f})$
5- $\gamma^{(1)}(q_{t,f})$
For the last 2 channels, refer to eq. (3): https://arxiv.org/pdf/2006.00687.pdf
If my assumption about the channels is correct, we don't need 2 separate channels for the magnitude (channels 1 and 2), since one is the complement of the other.
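The complement remark follows from a standard identity of the logistic sigmoid $\sigma$. This assumes, as in the code posted later in this thread, that the two magnitude channels enter through $\sigma(z^{(k)}_{t,f} - z^{(\neg k)}_{t,f})$ and $\sigma(z^{(\neg k)}_{t,f} - z^{(k)}_{t,f})$. With $a = z^{(k)}_{t,f} - z^{(\neg k)}_{t,f}$:

```latex
\sigma(a) + \sigma(-a)
  = \frac{1}{1 + e^{-a}} + \frac{1}{1 + e^{a}}
  = \frac{1}{1 + e^{-a}} + \frac{e^{-a}}{e^{-a} + 1}
  = 1
```

So only the difference of the two logits matters, and the second sigmoid is simply $1 - \sigma(a)$.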
Thanks a lot once again, @atabakp ! I'll report back my progress as I manage to allocate time for working on it
Section 3 of this paper also has some information about phase demodulation: https://www.isca-speech.org/archive_v0/Interspeech_2018/pdfs/1773.pdf
Hi again @atabakp ,
How are you getting the 10 channels? I looked into the model's code again and I only get 5 channels. Here is the I/O of each layer:
module | type | input_shape | output_shape |
---|---|---|---|
root | TRUNet | (1, 4, 257) | (1, 5, 257) |
down1 | StandardConv1d | (1, 4, 257) | (1, 64, 128) |
down1.StandardConv1d | Sequential | (1, 4, 257) | (1, 64, 128) |
down1.StandardConv1d.0 | Conv1d | (1, 4, 257) | (1, 64, 128) |
down1.StandardConv1d.1 | ReLU | (1, 64, 128) | (1, 64, 128) |
down2 | DepthwiseSeparableConv1d | (1, 64, 128) | (1, 128, 128) |
down2.DepthwiseSeparableConv1d | Sequential | (1, 64, 128) | (1, 128, 128) |
down2.DepthwiseSeparableConv1d.0 | Conv1d | (1, 64, 128) | (1, 128, 128) |
down2.DepthwiseSeparableConv1d.1 | BatchNorm1d | (1, 128, 128) | (1, 128, 128) |
down2.DepthwiseSeparableConv1d.2 | ReLU | (1, 128, 128) | (1, 128, 128) |
down2.DepthwiseSeparableConv1d.3 | Conv1d | (1, 128, 128) | (1, 128, 128) |
down2.DepthwiseSeparableConv1d.4 | BatchNorm1d | (1, 128, 128) | (1, 128, 128) |
down2.DepthwiseSeparableConv1d.5 | ReLU | (1, 128, 128) | (1, 128, 128) |
down3 | DepthwiseSeparableConv1d | (1, 128, 128) | (1, 128, 64) |
down3.DepthwiseSeparableConv1d | Sequential | (1, 128, 128) | (1, 128, 64) |
down3.DepthwiseSeparableConv1d.0 | Conv1d | (1, 128, 128) | (1, 128, 128) |
down3.DepthwiseSeparableConv1d.1 | BatchNorm1d | (1, 128, 128) | (1, 128, 128) |
down3.DepthwiseSeparableConv1d.2 | ReLU | (1, 128, 128) | (1, 128, 128) |
down3.DepthwiseSeparableConv1d.3 | Conv1d | (1, 128, 128) | (1, 128, 64) |
down3.DepthwiseSeparableConv1d.4 | BatchNorm1d | (1, 128, 64) | (1, 128, 64) |
down3.DepthwiseSeparableConv1d.5 | ReLU | (1, 128, 64) | (1, 128, 64) |
down4 | DepthwiseSeparableConv1d | (1, 128, 64) | (1, 128, 64) |
down4.DepthwiseSeparableConv1d | Sequential | (1, 128, 64) | (1, 128, 64) |
down4.DepthwiseSeparableConv1d.0 | Conv1d | (1, 128, 64) | (1, 128, 64) |
down4.DepthwiseSeparableConv1d.1 | BatchNorm1d | (1, 128, 64) | (1, 128, 64) |
down4.DepthwiseSeparableConv1d.2 | ReLU | (1, 128, 64) | (1, 128, 64) |
down4.DepthwiseSeparableConv1d.3 | Conv1d | (1, 128, 64) | (1, 128, 64) |
down4.DepthwiseSeparableConv1d.4 | BatchNorm1d | (1, 128, 64) | (1, 128, 64) |
down4.DepthwiseSeparableConv1d.5 | ReLU | (1, 128, 64) | (1, 128, 64) |
down5 | DepthwiseSeparableConv1d | (1, 128, 64) | (1, 128, 32) |
down5.DepthwiseSeparableConv1d | Sequential | (1, 128, 64) | (1, 128, 32) |
down5.DepthwiseSeparableConv1d.0 | Conv1d | (1, 128, 64) | (1, 128, 64) |
down5.DepthwiseSeparableConv1d.1 | BatchNorm1d | (1, 128, 64) | (1, 128, 64) |
down5.DepthwiseSeparableConv1d.2 | ReLU | (1, 128, 64) | (1, 128, 64) |
down5.DepthwiseSeparableConv1d.3 | Conv1d | (1, 128, 64) | (1, 128, 32) |
down5.DepthwiseSeparableConv1d.4 | BatchNorm1d | (1, 128, 32) | (1, 128, 32) |
down5.DepthwiseSeparableConv1d.5 | ReLU | (1, 128, 32) | (1, 128, 32) |
down6 | DepthwiseSeparableConv1d | (1, 128, 32) | (1, 128, 16) |
down6.DepthwiseSeparableConv1d | Sequential | (1, 128, 32) | (1, 128, 16) |
down6.DepthwiseSeparableConv1d.0 | Conv1d | (1, 128, 32) | (1, 128, 32) |
down6.DepthwiseSeparableConv1d.1 | BatchNorm1d | (1, 128, 32) | (1, 128, 32) |
down6.DepthwiseSeparableConv1d.2 | ReLU | (1, 128, 32) | (1, 128, 32) |
down6.DepthwiseSeparableConv1d.3 | Conv1d | (1, 128, 32) | (1, 128, 16) |
down6.DepthwiseSeparableConv1d.4 | BatchNorm1d | (1, 128, 16) | (1, 128, 16) |
down6.DepthwiseSeparableConv1d.5 | ReLU | (1, 128, 16) | (1, 128, 16) |
FGRU | GRUBlock | (1, 16, 128) | (1, 64, 16) |
FGRU.GRU | GRU | (1, 16, 128) | ((1, 16, 128), (2, 1, 64)) |
FGRU.conv | Sequential | (1, 128, 16) | (1, 64, 16) |
FGRU.conv.0 | Conv1d | (1, 128, 16) | (1, 64, 16) |
FGRU.conv.1 | BatchNorm1d | (1, 64, 16) | (1, 64, 16) |
FGRU.conv.2 | ReLU | (1, 64, 16) | (1, 64, 16) |
TGRU | GRUBlock | (1, 16, 64) | (1, 64, 16) |
TGRU.GRU | GRU | (1, 16, 64) | ((1, 16, 128), (1, 1, 128)) |
TGRU.conv | Sequential | (1, 128, 16) | (1, 64, 16) |
TGRU.conv.0 | Conv1d | (1, 128, 16) | (1, 64, 16) |
TGRU.conv.1 | BatchNorm1d | (1, 64, 16) | (1, 64, 16) |
TGRU.conv.2 | ReLU | (1, 64, 16) | (1, 64, 16) |
up1 | FirstTrCNN | (1, 64, 16) | (1, 64, 31) |
up1.FirstTrCNN | Sequential | (1, 64, 16) | (1, 64, 31) |
up1.FirstTrCNN.0 | Conv1d | (1, 64, 16) | (1, 64, 16) |
up1.FirstTrCNN.1 | BatchNorm1d | (1, 64, 16) | (1, 64, 16) |
up1.FirstTrCNN.2 | ReLU | (1, 64, 16) | (1, 64, 16) |
up1.FirstTrCNN.3 | ConvTranspose1d | (1, 64, 16) | (1, 64, 31) |
up1.FirstTrCNN.4 | BatchNorm1d | (1, 64, 31) | (1, 64, 31) |
up1.FirstTrCNN.5 | ReLU | (1, 64, 31) | (1, 64, 31) |
up2 | TrCNN | (1, 64, 31) | (1, 64, 65) |
up2.TrCNN | Sequential | (1, 192, 32) | (1, 64, 65) |
up2.TrCNN.0 | Conv1d | (1, 192, 32) | (1, 64, 32) |
up2.TrCNN.1 | BatchNorm1d | (1, 64, 32) | (1, 64, 32) |
up2.TrCNN.2 | ReLU | (1, 64, 32) | (1, 64, 32) |
up2.TrCNN.3 | ConvTranspose1d | (1, 64, 32) | (1, 64, 65) |
up2.TrCNN.4 | BatchNorm1d | (1, 64, 65) | (1, 64, 65) |
up2.TrCNN.5 | ReLU | (1, 64, 65) | (1, 64, 65) |
up3 | TrCNN | (1, 64, 65) | (1, 64, 66) |
up3.TrCNN | Sequential | (1, 192, 64) | (1, 64, 66) |
up3.TrCNN.0 | Conv1d | (1, 192, 64) | (1, 64, 64) |
up3.TrCNN.1 | BatchNorm1d | (1, 64, 64) | (1, 64, 64) |
up3.TrCNN.2 | ReLU | (1, 64, 64) | (1, 64, 64) |
up3.TrCNN.3 | ConvTranspose1d | (1, 64, 64) | (1, 64, 66) |
up3.TrCNN.4 | BatchNorm1d | (1, 64, 66) | (1, 64, 66) |
up3.TrCNN.5 | ReLU | (1, 64, 66) | (1, 64, 66) |
up4 | TrCNN | (1, 64, 66) | (1, 64, 129) |
up4.TrCNN | Sequential | (1, 192, 64) | (1, 64, 129) |
up4.TrCNN.0 | Conv1d | (1, 192, 64) | (1, 64, 64) |
up4.TrCNN.1 | BatchNorm1d | (1, 64, 64) | (1, 64, 64) |
up4.TrCNN.2 | ReLU | (1, 64, 64) | (1, 64, 64) |
up4.TrCNN.3 | ConvTranspose1d | (1, 64, 64) | (1, 64, 129) |
up4.TrCNN.4 | BatchNorm1d | (1, 64, 129) | (1, 64, 129) |
up4.TrCNN.5 | ReLU | (1, 64, 129) | (1, 64, 129) |
up5 | TrCNN | (1, 64, 129) | (1, 64, 130) |
up5.TrCNN | Sequential | (1, 192, 128) | (1, 64, 130) |
up5.TrCNN.0 | Conv1d | (1, 192, 128) | (1, 64, 128) |
up5.TrCNN.1 | BatchNorm1d | (1, 64, 128) | (1, 64, 128) |
up5.TrCNN.2 | ReLU | (1, 64, 128) | (1, 64, 128) |
up5.TrCNN.3 | ConvTranspose1d | (1, 64, 128) | (1, 64, 130) |
up5.TrCNN.4 | BatchNorm1d | (1, 64, 130) | (1, 64, 130) |
up5.TrCNN.5 | ReLU | (1, 64, 130) | (1, 64, 130) |
up6 | LastTrCNN | (1, 64, 130) | (1, 5, 257) |
up6.LastTrCNN | Sequential | (1, 128, 128) | (1, 5, 257) |
up6.LastTrCNN.0 | Conv1d | (1, 128, 128) | (1, 5, 128) |
up6.LastTrCNN.1 | BatchNorm1d | (1, 5, 128) | (1, 5, 128) |
up6.LastTrCNN.2 | ReLU | (1, 5, 128) | (1, 5, 128) |
up6.LastTrCNN.3 | ConvTranspose1d | (1, 5, 128) | (1, 5, 257) |
I also have a question about the TGRU along the same lines. According to the paper:
"The decoder is composed of a Time-axis Gated Recurrent Unit (TGRU) block and 1D Transposed Convolutional Neural Network (1D-TrCNN) blocks. The output of the encoder is passed into a unidirectional GRU layer to aggregate the information along the time axis."
But the input to this layer is (1, 16, 64). According to PyTorch's GRU documentation, when `batch_first=True` the 2nd dimension is the sequence length. That is the case here, because `batch_first` defaults to `True` and is not changed when the `TGRU` layer is defined: https://github.com/YangangCao/TRUNet/blob/main/TRUNet.py#LL131C26-L131C26
To my understanding (please correct me if I'm wrong), the `TGRU` layer will not really aggregate information along the time axis. Instead, it plays a similar role to the `FGRU`, but with a unidirectional layer. I first assumed that `batch_first` should be set to `False` in order to apply the `nn.GRU` along the first dimension, which is the original time dimension.
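A small sketch of the `batch_first=False` option. The sizes loosely follow the table above, but the module itself is an illustrative assumption and not the repository's `GRUBlock`. With frames stacked as (T, C, F), permuting to (T, F, C) and using `nn.GRU` with `batch_first=False` makes dimension 0 (time) the sequence axis and treats the frequency bins as the batch. Information then flows forward across frames only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

T, C, F = 50, 64, 16       # frames, channels, downsampled frequency bins
x = torch.randn(T, C, F)   # encoder output with frames on the Conv1d batch axis

# batch_first=False expects (seq_len, batch, features): here (T, F, C),
# so the recurrence runs over frames and each frequency bin is a batch entry.
tgru = nn.GRU(input_size=C, hidden_size=128, batch_first=False)
y, _ = tgru(x.permute(0, 2, 1))            # y: (T, F, 128)

# Causality check: changing only the last frame leaves earlier outputs unchanged.
x2 = x.clone()
x2[-1] += 1.0
y2, _ = tgru(x2.permute(0, 2, 1))
print(y.shape)                              # torch.Size([50, 16, 128])
print(torch.allclose(y[:-1], y2[:-1]))      # True
```

With `batch_first=True` on the same (T, F, C) tensor, the recurrence would instead run over the 16 frequency positions, which is the behaviour described in the question.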
I'll answer this one myself. The paper's configuration listing for the decoder says:
DecoderConfig = {1-th: (3,2,64), 2-th: (5,2,64), 3-th: (3,1,64), 4-th: (5,2,64), 5-th: (3,1,64), 6-th: (5,2,10)}
where the last number is the number of channels. Therefore, you're right: it should be 10 channels instead.
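Here is a sketch of a final decoder stage with 10 output channels, so it can carry two 5-channel groups. Several details are assumptions: reading each config tuple as (kernel size, stride, channels), the paddings, the kernel-3 `Conv1d`, and the 128×128 input, which follows the `up6.LastTrCNN` row of the table above. With dilation 1 and no output padding, `ConvTranspose1d` gives L_out = (L_in - 1)·stride - 2·padding + kernel_size = 127·2 - 2 + 5 = 257 bins.

```python
import torch
import torch.nn as nn

last_stage = nn.Sequential(
    nn.Conv1d(128, 10, kernel_size=3, padding=1),                    # 128 channels -> 10
    nn.ConvTranspose1d(10, 10, kernel_size=5, stride=2, padding=1),  # 128 bins -> 257
)

x = torch.randn(7, 128, 128)   # (frames, channels, bins) entering the last stage
out = last_stage(x)
print(out.shape)               # torch.Size([7, 10, 257])
```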
https://github.com/YangangCao/TRUNet/issues/4#issuecomment-1182544756
Hi @atabakp ,
Not sure if my interpretation of the outputs is correct, but I'm trying to follow the paper, and even when the model trains, it may become unstable after some epochs. I believe that `cos_phase` is causing this, because with the law of cosines I sometimes get values marginally outside the expected range. How are you dealing with this, and how are you obtaining the respective `sin_phase`? I believe I'm missing something somewhere. I already tried clamping values that could potentially make the values explode, with no luck.
import torch
import torch.nn.functional as F

# Control random seed
torch.manual_seed(0)
# Let's assume it has shape (1, 5, 257) (the expected output for a single source).
# Since the activation function is ReLU, values are greater than or equal to 0.
x_features = torch.rand((1, 5, 257), dtype=torch.float32)
# Extract z_tf for target and residual
z_tf = x_features[:, 0:1, :]
z_tf_residual = x_features[:, 1:2, :]
# Extract phi
phi = x_features[:, 2:3, :]
# Estimate beta (due to softplus it will be one or greater)
beta = 1.0 + F.softplus(phi)
# Estimate sigmoid of target and residual
sigmoid_tf = torch.sigmoid(z_tf - z_tf_residual)
sigmoid_tf_residual = torch.sigmoid(z_tf_residual - z_tf)
# Estimate upper bound for beta
beta_upper_bound = 1.0 / torch.abs(sigmoid_tf - sigmoid_tf_residual)
# Because of the absolute value in the denominator, the same upper bound
# can be applied to both betas
beta = torch.clip(beta, max=beta_upper_bound)
# Compute both target and residual masks using eq. (1)
mask_tf = beta * sigmoid_tf
mask_tf_residual = beta * sigmoid_tf_residual
# Now that we have both masks, apply the law of cosines
cos_phase = (
    (1.0 + mask_tf.square() - mask_tf_residual.square())
    / (2.0 * mask_tf))
# Use the trigonometric identity to obtain the sine
sin_phase = torch.sqrt(1.0 - cos_phase.square())
# Now estimate the sign
q0 = x_features[:, 3:4, :]
q1 = x_features[:, 4:5, :]
gamma0 = F.gumbel_softmax(q0, tau=1.0)
gamma1 = F.gumbel_softmax(q1, tau=1.0)
sign = torch.where(gamma0 > gamma1, -1.0, 1.0)
# Finally, estimate the complex mask
complex_mask = mask_tf * (cos_phase + sign * 1j * sin_phase)
# Then it should be applied to the STFT and inverted using the iSTFT
...
1- I didn't apply ReLU to the last layer (x_features).
2- sigmoid_tf_residual = 1 - sigmoid_tf always holds, so I think two outputs are not needed for this. Based on the paper, though, your code is correct.
3- You should always guard divisions, logs, square roots, etc. in your code. For example, in cos_phase there is a chance that mask_tf becomes zero.
4- To obtain the sine, I am using acos: torch.sin(torch.acos(torch.clamp(cos_phase, min=-1 + eps, max=1 - eps)))
5- Estimating the sign is a bit confusing, and I believe there is a typo in the paper's formula. (I also believe the sign does not matter much for performance.) This is how I implemented it:
gamma = torch.nn.functional.gumbel_softmax(
    torch.stack([q0, q1], dim=-1),
    tau=0.5,
    hard=False,
)
gamma_0 = gamma[..., 0]
gamma_1 = gamma[..., 1]
sign = torch.where(gamma_0 > gamma_1, -1.0, 1.0)
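Combining atabakp's five points gives a numerically guarded mask computation for one 5-channel group, sketched below. The `eps` value, the clamping, `tau=0.5` and the name `phm_mask` are assumptions for illustration, not the paper's or the repository's code. The residual sigmoid uses the complement identity from point 2, and the sign uses the stacked Gumbel-softmax from point 5.

```python
import torch
import torch.nn.functional as F

def phm_mask(out, eps=1e-6, tau=0.5):
    """out: (T, 5, F) raw decoder output (no final ReLU). Returns a complex (T, F) mask."""
    z_k, z_not_k, phi, q0, q1 = out.unbind(dim=1)          # each (T, F)

    sig_k = torch.sigmoid(z_k - z_not_k)
    sig_not_k = 1.0 - sig_k                                 # sigmoid(-a) = 1 - sigmoid(a)

    beta = 1.0 + F.softplus(phi)                            # beta >= 1
    beta = torch.minimum(beta, 1.0 / (torch.abs(sig_k - sig_not_k) + eps))

    mag_k = beta * sig_k
    mag_not_k = beta * sig_not_k

    # Law of cosines on the triangle with sides 1, |M_k|, |M_not_k|; eps guards |M_k| -> 0.
    cos_phase = (1.0 + mag_k**2 - mag_not_k**2) / (2.0 * mag_k + eps)
    cos_phase = cos_phase.clamp(-1.0 + eps, 1.0 - eps)     # keeps acos gradients finite
    sin_phase = torch.sin(torch.acos(cos_phase))

    # Sign from a 2-way Gumbel-softmax over (q0, q1); the hard comparison passes no gradient.
    gamma = F.gumbel_softmax(torch.stack([q0, q1], dim=-1), tau=tau, hard=False)
    sign = torch.where(gamma[..., 0] > gamma[..., 1],
                       -torch.ones_like(q0), torch.ones_like(q0))

    return mag_k * torch.complex(cos_phase, sign * sin_phase)

torch.manual_seed(0)
raw = (3.0 * torch.randn(100, 5, 257)).requires_grad_()
mask = phm_mask(raw)
(mask.real.sum() + mask.imag.sum()).backward()
print(mask.shape, mask.dtype)             # torch.Size([100, 257]) torch.complex64
print(torch.isfinite(raw.grad).all())     # tensor(True)
```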
Thank you very much @atabakp!
`sigmoid_tf_residual` and it works with the simpler version.

Hi again @atabakp,
When training the model, are you using 2s audio as the paper claims or are you using gradient accumulation or something like that to pass more data between steps?
I'm currently trying to train the model for dereverberation only, but using 2 s of audio per example in all cases is very slow to train. So far I haven't reached the point where I can evaluate how successful the model is at the task, but it doesn't seem to be learning quickly.
I am using random-length sequences, a single sequence per iteration (batch size = 1).
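Below is a sketch of the sampling scheme atabakp describes: one random-length excerpt per iteration, with batch size 1. The length bounds, the sample rate and the `random_crop` helper are hypothetical, since the thread does not state the range used.

```python
import torch

def random_crop(clean, noisy, sr=16000, min_sec=1.0, max_sec=4.0, generator=None):
    """Cut the same random-length, random-offset excerpt from aligned clean/noisy waves.
    sr, min_sec and max_sec are hypothetical defaults."""
    total = clean.shape[-1]
    lo = int(min_sec * sr)
    hi = min(int(max_sec * sr), total)
    if hi <= lo:                                   # clip too short: use it whole
        return clean, noisy
    length = int(torch.randint(lo, hi + 1, (1,), generator=generator))
    start = int(torch.randint(0, total - length + 1, (1,), generator=generator))
    return clean[..., start:start + length], noisy[..., start:start + length]

g = torch.Generator().manual_seed(0)
clean = torch.randn(10 * 16000)
noisy = clean + 0.1 * torch.randn(10 * 16000)
c, n = random_crop(clean, noisy, generator=g)   # one pair per iteration (batch size 1)
print(c.shape == n.shape)                        # True
```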
Sorry for necroposting here, but I'm trying to train this model, with no luck so far. I managed to add trainable PCEN (as described in the paper) and am training on spectrograms. I construct the input features from PCEN (the output of the trainable layer), the log magnitude, and the real and imaginary parts of the STFT, and feed them to the rest of the model described here. I also implemented 2D convs, since I wanted to train on batches. The losses are the same as in the paper: multi-resolution cosine similarity + multi-resolution spectrum MSE. The model trains very weirdly: the loss decreases for the first couple of hundred steps, then increases, then decreases again. @atabakp @eagomez2 I assume you managed to train this model. Could you tell me what you changed compared to the paper, or maybe share your pipeline? Thanks in advance.
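For reference, here is one way the two multi-resolution losses could be written. Their exact form is an assumption, not the paper's definition. The cosine term splits the waveforms into non-overlapping segments at several lengths and takes the negative mean cosine similarity between estimate and reference segments. The spectral term is an MSE between STFT magnitudes at several FFT sizes. The resolutions and the equal weighting are hypothetical.

```python
import torch
import torch.nn.functional as F

def multires_cosine_loss(est, ref, seg_lens=(256, 512, 1024)):
    """Negative mean cosine similarity between waveform segments of several lengths."""
    loss = 0.0
    for seg in seg_lens:
        n = est.shape[-1] // seg * seg                     # drop the ragged tail
        e = est[..., :n].reshape(-1, seg)
        r = ref[..., :n].reshape(-1, seg)
        loss = loss - F.cosine_similarity(e, r, dim=-1, eps=1e-8).mean()
    return loss / len(seg_lens)

def multires_spec_mse(est, ref, n_ffts=(256, 512, 1024)):
    """MSE between STFT magnitudes computed at several FFT sizes."""
    loss = 0.0
    for n_fft in n_ffts:
        kw = dict(n_fft=n_fft, hop_length=n_fft // 4,
                  window=torch.hann_window(n_fft, device=est.device), return_complex=True)
        loss = loss + F.mse_loss(torch.stft(est, **kw).abs(), torch.stft(ref, **kw).abs())
    return loss / len(n_ffts)

torch.manual_seed(0)
ref = torch.randn(16000)
est = ref + 0.1 * torch.randn(16000)
total = multires_cosine_loss(est, ref) + multires_spec_mse(est, ref)
```

A perfect estimate gives -1 for the cosine term and 0 for the spectral term, so the combined loss is bounded below by -1.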
Hi @JBloodless ,
In my case I am using Conv1d, but I decided to replace the original losses with a GAN setup, since they weren't quite working for me. The model was converging, but not to the expected quality, and it sometimes exploded afterwards.
Could you elaborate a bit on what you mean by GAN loss?
Check section 2 of this paper: https://arxiv.org/pdf/2010.10677.pdf
@eagomez2 I'm assuming you are using your code above to calculate the complex mask and then multiplying it with the complex input to produce the output waveform. Is this correct?
Yes. You will need to repeat the process to obtain both the direct speech waveform and the residual waveform.
What do you mean by repeating? I thought that the network (in this implementation) returns one set of features for the PHM, of shape (time, 5, bins), and that the corresponding PHM is the mask for the direct source. Since I only need the direct source (clean speech), I just multiply this PHM with the input spectrum and get the clean output. What did I get wrong?
Reading the conversation above, I assume that you changed the output layer to 10 channels (as in the paper). How should I apply this PHM function then? Are channels 1-5 the mask for direct speech, and 6-10 the residual?
Yes, that's correct
I think I got it wrong. It seems that your function technically calculates a pair of masks:

```python
# Compute both target and residual masks using eq. (1)
mask_tf = beta * sigmoid_tf
mask_tf_residual = beta * sigmoid_tf_residual
```

and features 6-10 are for reverberation and noise separation. This means that from the first 5 features we can calculate only the direct-source mask (which is exactly what your function is doing) to perform dereverberation. That leads to even more frustration for me, since I'm doing everything "right", but the model doesn't train at all.
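Why one 5-channel group yields a *pair* of masks can be written out. Assuming the target and residual masks are constrained to add up to the mixture (a mask of 1), the magnitudes beta * sigma_t and beta * (1 - sigma_t) and the number 1 form a triangle, so each phase follows from the law of cosines. A sketch of that derivation; the notation (sigma_t as the target share, theta as the phase relative to the mixture) is ours, not the paper's equation numbering:

```latex
M_t + M_r = 1, \qquad |M_t| = \beta\,\sigma_t, \qquad |M_r| = \beta\,(1 - \sigma_t)

|M_r|^2 = |1 - M_t|^2 = 1 + |M_t|^2 - 2\,|M_t|\cos\theta_t
\quad\Longrightarrow\quad
\cos\theta_t = \frac{1 + |M_t|^2 - |M_r|^2}{2\,|M_t|},
\qquad
\cos\theta_r = \frac{1 + |M_r|^2 - |M_t|^2}{2\,|M_r|}

\text{valid only if}\quad |M_t| + |M_r| = \beta \ge 1
\quad\text{and}\quad \bigl|\,|M_t| - |M_r|\,\bigr| = \beta\,|2\sigma_t - 1| \le 1
```

The law of cosines fixes only the size of each angle; the imaginary parts of M_t and M_r cancel only if the two phases have opposite signs, which is the choice the sign-prediction channels make.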
Are you using the paper's losses or a different one? Is the model training but with the loss values "all over the place", or is it exploding?
If I remember correctly (this was months ago), I double- and triple-checked that every function that could potentially explode, e.g. by dividing by zero, had an epsilon (eps) term or something similar to prevent such issues before I managed to get meaningful results.
It's not exploding; the loss just stays stable around some value. Yes, I'm trying to use the paper's loss (since it looks right to me and I don't see why it shouldn't work).
The purple curve is the latest try (with the 10-channel output).
I already fixed a bunch of NaNs, so I don't think zero handling is the problem here. Maybe I'm interpreting the outputs wrong; I'll try to log what's going on.
@eagomez2 maybe you can help me double-check my output interpretation.

```python
x16 = self.up6(x15, x1)  # x16 is the output of this implementation, except self.up6 has 10 channels instead of 5
mask_direct = calculate_PHM(x16[:, :5, :])  # calculate_PHM is your function from the comments above (with fixes from @atabakp); returns complex_mask
result = x * mask_direct  # x is the input complex spectrogram
out_wave = torch.istft(result,
                       n_fft=self.nfft,
                       hop_length=self.hop,
                       onesided=True,
                       window=self.window.to(wave.device),
                       center=True)
```

Is this roughly the same as yours? My main concern for now is `mask_direct = calculate_PHM(x16[:, :5, :])`.
To verify that the rest of the code works correctly, skip the calculate_PHM function and multiply one channel of x16 with x.
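A related sanity check that does not involve the network at all: with a mask of all ones, STFT, masking and inverse STFT must give back the input. Below is a NumPy sketch with a hand-rolled weighted overlap-add inverse (our own helper standing in for torch.stft/torch.istft, without centering); if the equivalent check fails in the real pipeline, the problem is in the STFT settings (window, hop, center/padding) rather than in the PHM.

```python
import numpy as np

N_FFT, HOP = 512, 128
WINDOW = 0.5 - 0.5 * np.cos(2 * np.pi * np.arange(N_FFT) / N_FFT)  # periodic Hann


def stft(wave):
    n_frames = 1 + (len(wave) - N_FFT) // HOP
    frames = np.stack([wave[i * HOP:i * HOP + N_FFT] * WINDOW for i in range(n_frames)])
    return np.fft.rfft(frames, axis=-1)  # (frames, bins)


def istft(spec):
    """Weighted overlap-add: sum of windowed frames divided by the summed squared window."""
    frames = np.fft.irfft(spec, n=N_FFT, axis=-1) * WINDOW
    length = (spec.shape[0] - 1) * HOP + N_FFT
    out = np.zeros(length)
    norm = np.zeros(length)
    for i, frame in enumerate(frames):
        out[i * HOP:i * HOP + N_FFT] += frame
        norm[i * HOP:i * HOP + N_FFT] += WINDOW ** 2
    return out / np.maximum(norm, 1e-8)


wave = np.random.default_rng(0).standard_normal(16000)
spec = stft(wave)
mask = np.ones_like(spec)                   # identity mask: should change nothing
rebuilt = istft(spec * mask)
inner = slice(N_FFT, len(rebuilt) - N_FFT)  # skip the edges, where the window sum is tiny
print(np.allclose(rebuilt[inner], wave[inner]))  # True
```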
By "correct" I meant that it produces the expected output (clean speech). This function is "correct" in terms of data and shapes, if that's what you mean. If I'm not mistaken, one channel of x16 is just one of the features needed to calculate the mask, and multiplying the input with this feature won't produce a clean spectrum.
@atabakp @eagomez2 I found out that my losses were completely wrong, so I'd like to ask you about the outputs of this model. The paper doesn't make it at all clear in which order the second set of masks (channels 6-10) comes. In Fig. 2 of the paper it's non-noise and noise, but at the end of section 3.2 it's noise and non-noise. Did you figure it out?
A single channel can also produce clean speech: you only need to multiply the output mask (bounded to [0, 1] with a sigmoid) with the magnitude of the noisy signal (x) and use the noisy phase (x.angle()) to construct out_wave.
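A minimal NumPy sketch of that single-channel recipe, assuming the network's output channel is an unbounded logit per time-frequency bin: squash it with a sigmoid, scale the noisy magnitude, and reattach the noisy phase. Because the mask is real, this equals multiplying the complex spectrum by the mask directly, which makes a handy cross-check.

```python
import numpy as np


def apply_magnitude_mask(noisy_spec, logits):
    """noisy_spec: complex (frames, bins); logits: real (frames, bins), one output channel."""
    mask = 1.0 / (1.0 + np.exp(-logits))   # sigmoid -> mask in (0, 1)
    magnitude = mask * np.abs(noisy_spec)   # masked magnitude
    phase = np.angle(noisy_spec)            # reuse the noisy phase
    return magnitude * np.exp(1j * phase)


rng = np.random.default_rng(0)
noisy = rng.standard_normal((4, 257)) + 1j * rng.standard_normal((4, 257))
logits = rng.standard_normal((4, 257))
enhanced = apply_magnitude_mask(noisy, logits)
print(np.allclose(enhanced, noisy / (1.0 + np.exp(-logits))))  # True
```

In PyTorch the same steps map onto torch.sigmoid, Tensor.abs(), Tensor.angle() and torch.polar(abs, angle).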
The order doesn't matter; just pick one, and the network will adapt to assign the corresponding outputs correctly, irrespective of the initial order. My suggestion is to skip the PHM for now and make sure the rest of the code is OK.
Seems fine to me. Maybe the problem really is the PHM computation.
For now I settled on:

```python
mask_direct = calculate_PHM(x16[:, :5, :])            # direct-path mask
result_direct = torch.view_as_complex(x) * mask_direct.squeeze(1)
mask_nonnoise = calculate_PHM(x16[:, 5:, :])          # non-noise (direct + reverb) mask
result_nonnoise = torch.view_as_complex(x) * mask_nonnoise.squeeze(1)
result_noise = torch.view_as_complex(x) - result_nonnoise
mask_revpath = mask_nonnoise - mask_direct            # reverberant path, as in Fig. 2
result_revpath = torch.view_as_complex(x) * mask_revpath.squeeze(1)
```

Since you mentioned that the order doesn't matter, I assumed that non-noise comes first in the second pair, so I directly calculate the masks for the direct path and the non-noise signal, and then obtain the reverberation mask as in Fig. 2 of the paper.
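The decomposition above can be checked numerically: whatever masks the network outputs, the three estimates built this way (direct, reverberant path, noise) always add back up to the input mixture. This identity therefore only confirms that the bookkeeping is consistent, not that the masks are good. A NumPy sketch with random complex masks as hypothetical stand-ins for the two calculate_PHM outputs:

```python
import numpy as np

rng = np.random.default_rng(0)
shape = (100, 257)  # (frames, bins)
mixture = rng.standard_normal(shape) + 1j * rng.standard_normal(shape)

# Hypothetical stand-ins for the PHM outputs: direct-path mask and non-noise mask.
mask_direct = rng.uniform(0, 1, shape) * np.exp(1j * rng.uniform(-np.pi, np.pi, shape))
mask_nonnoise = rng.uniform(0, 1, shape) * np.exp(1j * rng.uniform(-np.pi, np.pi, shape))

result_direct = mixture * mask_direct
result_nonnoise = mixture * mask_nonnoise
result_noise = mixture - result_nonnoise      # everything that is not (reverberant) speech
mask_revpath = mask_nonnoise - mask_direct     # reverberant path = non-noise minus direct
result_revpath = mixture * mask_revpath

total = result_direct + result_revpath + result_noise
print(np.allclose(total, mixture))  # True
```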
So is it working now @JBloodless?
I double-checked the code I used and it is very similar to my initial post. The only changes I see are that I got rid of the randomness of the Gumbel softmax and simply replaced it with a softmax with temperature, and that I added some `eps` to stabilize the `cos_phase` term, and it worked. Also, `sigmoid_tf_residual` can be simplified to `1.0 - sigmoid_tf_target`.
All in all, I'm inclined to think that even though the sign-prediction math in the paper makes sense, in practice it is not that crucial for the network's performance.
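A NumPy sketch of a PHM-style mask pair that incorporates the two changes described above (a deterministic softmax with temperature for the sign instead of Gumbel-softmax, and a residual weight of 1 - target), plus an eps guard on the cos_phase denominator. This is not the paper's or eagomez2's code: the 5-channel layout [z_target, z_residual, z_beta, z_sign_pos, z_sign_neg], the beta range [1, min(2, 1/|2*sigma - 1|)] and the name phm_sketch are assumptions for illustration.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def phm_sketch(logits, tau=1.0, beta_cap=2.0, eps=1e-8):
    """Hypothetical PHM-style mask pair from logits of shape (frames, 5, bins).

    Assumed channel layout: [z_target, z_residual, z_beta, z_sign_pos, z_sign_neg].
    Returns complex (target_mask, residual_mask), each (frames, bins).
    """
    z_t, z_r, z_beta, z_pos, z_neg = (logits[:, i] for i in range(5))
    # Two-way softmax over (target, residual); the residual weight is just 1 - target.
    sig_t = sigmoid(z_t - z_r)
    sig_r = 1.0 - sig_t
    # Assumption: beta in [1, 1/|2*sig_t - 1|] (capped) so |M_t|, |M_r| and 1 form a triangle.
    beta_max = np.minimum(beta_cap, 1.0 / (np.abs(2.0 * sig_t - 1.0) + eps))
    beta = 1.0 + (beta_max - 1.0) * sigmoid(z_beta)
    mag_t, mag_r = beta * sig_t, beta * sig_r
    # Law of cosines; eps guards the denominator, clip guards rounding outside [-1, 1].
    cos_t = np.clip((1.0 + mag_t**2 - mag_r**2) / np.maximum(2.0 * mag_t, eps), -1.0, 1.0)
    cos_r = np.clip((1.0 + mag_r**2 - mag_t**2) / np.maximum(2.0 * mag_r, eps), -1.0, 1.0)
    sin_t, sin_r = np.sqrt(1.0 - cos_t**2), np.sqrt(1.0 - cos_r**2)
    # Softmax with temperature instead of Gumbel-softmax: for two classes,
    # p_pos - p_neg = tanh((z_pos - z_neg) / (2 * tau)), a soft sign in (-1, 1).
    sign = np.tanh((z_pos - z_neg) / (2.0 * tau))
    # Opposite phase signs, so the imaginary parts cancel when |sign| = 1.
    m_target = mag_t * (cos_t + 1j * sign * sin_t)
    m_residual = mag_r * (cos_r - 1j * sign * sin_r)
    return m_target, m_residual


logits = np.random.default_rng(0).standard_normal((3, 5, 7))
logits[:, 3], logits[:, 4] = 50.0, -50.0  # confident "+" sign
m_t, m_r = phm_sketch(logits)
print(np.allclose(m_t + m_r, 1.0, atol=1e-3))  # True
```

With a 10-channel output, this would be called once on logits[:, :5] (direct vs. residual) and once on logits[:, 5:] (the second pair, in whichever order you pick). When the sign channels are confident, target + residual equals 1, which makes a quick unit test.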
I totally agree. Even the PHM is not very crucial; the network can directly output the clean speech mask.
@atabakp thanks for bringing this up. Have you tried training without PHM? I was also curious about doing this, but I haven't found the time so far.
Nope, it still doesn't work. The only thing that "worked" is skipping the PHM and multiplying one channel of the last output with the input, but I haven't waited for it to converge yet. I'll try these fixes (softmax with temperature, eps), thanks.
One more question for @atabakp: I mentioned that my losses were wrong. By that I meant that in the paper the loss is the sum of the losses for the direct source, the noise and the reverberant path (last equation of section 3.3). How did you calculate them: did you use this sum of three, or something else? I don't see a good way to calculate the target reverberant path, so I just subtract the clean signal from the reverberant signal, and I use a tensor of 1e-6 as the noise target (since I only train for dereverberation). Also for @eagomez2: I understand that you used a different loss. Did you calculate it only on the direct source?
I tried different variations, but I found that using the loss on the direct source only is good enough.
Yes, I tried it; I ended up using a single channel to mask the magnitude.
I've managed to train the model without the PHM and with a single loss on the direct source (same as in the paper: multi-resolution cosine similarity + multi-resolution spectrum MSE). The model converges, but the result is strange.
Maybe it's because of the PCEN feature (my implementation may not be ideal), but the voice harmonics in the spectrum do seem to be dereverberated, so I'll try to locate the cause of the noisiness.
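For reference, a NumPy sketch of a direct-source loss of the kind discussed here. The exact definitions of the paper's multi-resolution cosine similarity and spectral MSE terms are not spelled out in this thread, so this is one plausible reading, assumed rather than confirmed: cosine similarity between time-domain segments at several segment lengths, plus MSE between STFT magnitudes at several FFT sizes (hop = n_fft / 4). The resolutions are illustrative.

```python
import numpy as np


def stft_mag(wave, n_fft):
    hop = n_fft // 4
    window = 0.5 - 0.5 * np.cos(2 * np.pi * np.arange(n_fft) / n_fft)  # periodic Hann
    n_frames = 1 + (len(wave) - n_fft) // hop
    frames = np.stack([wave[i * hop:i * hop + n_fft] * window for i in range(n_frames)])
    return np.abs(np.fft.rfft(frames, axis=-1))


def segment_cosine(est, ref, seg_len, eps=1e-8):
    """Mean cosine similarity over non-overlapping segments of length seg_len."""
    n = len(ref) // seg_len * seg_len
    e = est[:n].reshape(-1, seg_len)
    r = ref[:n].reshape(-1, seg_len)
    cos = np.sum(e * r, axis=1) / (np.linalg.norm(e, axis=1) * np.linalg.norm(r, axis=1) + eps)
    return cos.mean()


def multires_loss(est, ref, resolutions=(256, 512, 1024)):
    loss = 0.0
    for n in resolutions:
        loss += 1.0 - segment_cosine(est, ref, n)                    # 0 when waveforms align
        loss += np.mean((stft_mag(est, n) - stft_mag(ref, n)) ** 2)  # spectral magnitude MSE
    return loss


ref = np.random.default_rng(0).standard_normal(16000)
print(multires_loss(ref, ref) < 1e-6)   # True: identical signals
print(multires_loss(-ref, ref) > 1.0)   # True: a sign flip is caught by the cosine term only
```

The second check shows why the two terms complement each other: magnitude MSE is blind to a polarity flip, while the time-domain cosine term penalises it.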
@atabakp did you extract the magnitude of the input spectrum with torch.abs() rather than torch.real?
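On the torch.abs vs torch.real question: for a complex STFT the magnitude is sqrt(re^2 + im^2), which is what torch.abs returns for complex tensors, whereas the real part alone is signed and drops the imaginary component, so it is not a magnitude. A NumPy sketch for a spectrum stored with a trailing real/imaginary axis (the layout torch.view_as_complex(x) expects); the small eps inside the square root is an assumption mirroring the "add eps everywhere" advice earlier in the thread, since the gradient of sqrt is unbounded at zero.

```python
import numpy as np


def magnitude(spec_ri, eps=1e-12):
    """spec_ri: real array (..., 2) holding [real, imag]; returns sqrt(re^2 + im^2 + eps)."""
    re, im = spec_ri[..., 0], spec_ri[..., 1]
    return np.sqrt(re ** 2 + im ** 2 + eps)


spec_ri = np.array([[3.0, 4.0], [-3.0, 4.0], [0.0, 0.0]])
mags = magnitude(spec_ri)      # 5, 5 and sqrt(eps) = 1e-6
real_part = spec_ri[..., 0]    # 3, -3, 0: signed, not a magnitude
```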
Hi,
As per the paper, 4 features must be concatenated as the input to TRUNet,
so the input becomes (batch size, 4 features, number of STFT frames, number of STFT bins), i.e. it is 4-dimensional.
But in the sample code the input is shown as 3-dimensional, (1, 4, 257), since the first layer is a Conv1d.
I'm confused: is the input to TRUNet 3-dimensional or 4-dimensional?
Regards, Yugesh
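To illustrate the answer given earlier in the thread (the convolutional input is (frames, 4, bins), and batching just stacks more frames): as long as a layer only looks along frequency, a (batch, frames, 4, bins) tensor can be reshaped to (batch * frames, 4, bins), processed frame by frame, and reshaped back. That is equivalent to a 2D convolution whose kernel spans a single frame. The NumPy sketch below uses a naive frequency-only convolution (conv_freq, a hypothetical stand-in for nn.Conv1d) to check that folding gives the same result as processing each clip separately. This covers only frame-independent layers; any layer that runs across time (if the model has one) still needs the frame order within each clip.

```python
import numpy as np


def conv_freq(x, weight):
    """Naive 1D convolution along frequency with 'same' zero padding.

    x: (N, C_in, F), weight: (C_out, C_in, K) with odd K -> (N, C_out, F).
    Like nn.Conv1d, this is cross-correlation (no kernel flip); no bias term.
    """
    c_out, c_in, k = weight.shape
    pad = k // 2
    xp = np.pad(x, ((0, 0), (0, 0), (pad, pad)))
    out = np.zeros((x.shape[0], c_out, x.shape[2]))
    for f in range(x.shape[2]):
        window = xp[:, :, f:f + k]  # (N, C_in, K)
        out[:, :, f] = np.einsum("nck,ock->no", window, weight)
    return out


rng = np.random.default_rng(0)
batch, frames, feats, bins = 2, 10, 4, 257
x = rng.standard_normal((batch, frames, feats, bins))  # 4D "batched" input
w = rng.standard_normal((16, feats, 3))

# Fold (batch, frames) into one sample axis, convolve, unfold.
folded = conv_freq(x.reshape(batch * frames, feats, bins), w)
y = folded.reshape(batch, frames, 16, bins)

# Reference: process each clip separately, treating its frames as the Conv1d batch.
y_ref = np.stack([conv_freq(clip, w) for clip in x])
print(y.shape, np.allclose(y, y_ref))  # (2, 10, 16, 257) True
```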