breizhn / DTLN

TensorFlow 2.x implementation of the DTLN real-time speech denoising model, with TF-lite, ONNX and real-time audio processing support.
MIT License

Transfer Learning with DTLN model weights to remove the block shift #11

Closed hchintada closed 4 years ago

hchintada commented 4 years ago

To do away with block processing at inference time, I'm trying to use your pretrained weights and retrain the network after replacing the stftLayer with an fftLayer.

I'm using the magnitude normalization as is.

Can I set stateful=False for both separation kernels while training?

I plan to train for around 20 epochs.

Does this idea make sense, @breizhn? Or should I start from scratch to make this change?

Also, could you please provide guidance on how to use data augmentation/pre-processing to train the network with only 40 hours of data?

breizhn commented 4 years ago

Hi,

I'm not sure I understand you correctly: you would like to get rid of the block shift? If that is the case, you have to retrain the model. But for training you still need the STFT layer, because you will probably train on sequences/whole utterances.

During training, stateful should be False, so that the network resets its states between batches.

It's actually pretty easy. The order of the noise and speech files is shuffled after each epoch, and they are mixed online within the given SNR range. In my case the samples/files had a length of 4 s instead of 15/30 s. At the moment I don't have the time to integrate this procedure into the code. It will be part of the code refactoring, but it will probably take some time until it can be pushed. Training with 40 h of data still takes around 14 h on a Titan V or 2080 Ti.
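The online mixing described above can be sketched roughly as follows. This is a minimal NumPy illustration, not the repo's training code: the function names, the in-memory list of equal-length clips, and the per-epoch `seed` are assumptions, and the SNR range is left as a parameter because the thread does not state it.

```python
import numpy as np

def mix_at_snr(speech, noise, snr_db, eps=1e-12):
    """Scale `noise` so that the speech-to-noise power ratio is snr_db, then add it."""
    p_speech = np.mean(speech ** 2)
    p_noise = np.mean(noise ** 2)
    scale = np.sqrt(p_speech / (p_noise * 10.0 ** (snr_db / 10.0) + eps))
    scaled_noise = scale * noise
    return speech + scaled_noise, scaled_noise

def epoch_mixtures(speech_clips, noise_clips, snr_range, seed):
    """Yield (noisy, clean) pairs for one epoch.

    speech_clips / noise_clips: lists of equal-length 1-D arrays (hypothetically
    4 s clips already loaded into memory). Pass a different seed per epoch
    (e.g. the epoch index) so the pairing and the SNRs change every epoch.
    """
    rng = np.random.default_rng(seed)
    speech_order = rng.permutation(len(speech_clips))
    noise_order = rng.permutation(len(noise_clips))
    for i, s_idx in enumerate(speech_order):
        speech = speech_clips[s_idx]
        # Reuse noise clips cyclically if there are fewer noise than speech clips.
        noise = noise_clips[noise_order[i % len(noise_order)]]
        snr_db = rng.uniform(snr_range[0], snr_range[1])
        noisy, _ = mix_at_snr(speech, noise, snr_db)
        yield noisy.astype(np.float32), speech.astype(np.float32)
```

Because every epoch draws a new pairing and a new SNR per pair, 40 hours of clean speech yields far more distinct mixtures than a fixed, pre-mixed set would.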

hchintada commented 4 years ago

OK. I'm able to run this model in tfjs, but inference takes close to 15 ms for a 32 ms block, and the block-shift overlap-add adds further delay. For a 4096-sample chunk, numblocks is 29, which means 29 calls to model.predict(). A lot of audio is lost during the lag. So I'm trying to avoid the block-shift operations at inference time in my tfjs implementation.

To avoid the block shift, I want to retrain the network. During training, can I pass the input signal directly to tf.signal.stft() without using overlap in the stftLayer?
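A quick way to see why the block shift sets the real-time budget, as a sketch that assumes a 16 kHz sampling rate (which is what 512 samples = 32 ms implies) and counts only full blocks. The 29 blocks quoted for a 4096-sample chunk are consistent with a block shift of 128 samples. In a streaming loop a new block becomes available every shift, so each model call has one shift of audio as its time budget, not one block.

```python
def num_blocks(num_samples: int, block_len: int, block_shift: int) -> int:
    """Count the full blocks of length block_len taken every block_shift samples."""
    if num_samples < block_len:
        return 0
    return (num_samples - block_len) // block_shift + 1

def budget_ms(block_shift: int, fs: int = 16000) -> float:
    """Wall-clock time available per model call when one call is made per shift."""
    return 1000.0 * block_shift / fs

for shift in (128, 512):
    print(shift, num_blocks(4096, 512, shift), budget_ms(shift))
# prints:
# 128 29 8.0
# 512 8 32.0
```

Under these assumptions, a 15 ms inference exceeds the 8 ms budget of a 128-sample shift, while a shift equal to the block length leaves the full 32 ms.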

breizhn commented 4 years ago

In tf.signal.stft() you have to set the block shift to the same value as the block length (frame_step equal to frame_length). But the newly trained network will probably not perform as well as the pretrained one in the repo.
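A NumPy sketch of what "block shift = block length" means for the analysis/synthesis pair. It stands in for framing plus an rfft and is not the repo's stftLayer; the helper names are hypothetical. It uses no analysis window, whereas tf.signal.stft applies a Hann window by default unless window_fn=None is passed. With no overlap the frames tile the signal, so the inverse is a plain concatenation.

```python
import numpy as np

def frame(x, block_len, block_shift):
    """Split a 1-D signal into full blocks of block_len taken every block_shift samples."""
    n = (len(x) - block_len) // block_shift + 1
    idx = np.arange(block_len)[None, :] + block_shift * np.arange(n)[:, None]
    return x[idx]

def analysis(x, block_len=512, block_shift=512):
    """Magnitude and phase of each block (no window applied)."""
    spec = np.fft.rfft(frame(x, block_len, block_shift), axis=-1)
    return np.abs(spec), np.angle(spec)

def synthesis(mag, phase, block_len=512):
    """Inverse rfft per block, then concatenate.

    Valid only because block_shift == block_len: the blocks do not overlap,
    so there is nothing to overlap-add.
    """
    frames = np.fft.irfft(mag * np.exp(1j * phase), n=block_len, axis=-1)
    return frames.reshape(-1)
```

For a 4096-sample signal this gives 8 blocks of 257 frequency bins, and synthesis(*analysis(x)) returns x unchanged.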

shilsircar commented 4 years ago

To perform stateless inference without doing the block shift outside the model, all you have to do is use DTLN in stateless mode and pass in 32 ms blocks of 512 samples. Why do you have to retrain? Results without states may not be great, but from a tfjs performance perspective it will run well within 32 ms. I don't see what retraining would achieve in this case. @breizhn, is this understanding accurate?

hchintada commented 4 years ago

@shilsircar I'm not sure that's correct. I'm new to audio processing, and my plan to retrain was based on what @breizhn said: if you want to get rid of the block_shift, you have to retrain the model. Does block_shift have anything to do with the LSTMs' states?

breizhn commented 4 years ago

No, the block shift has nothing to do with the states.

To use a new shift value, the model must be retrained, because it was fitted to the shift it was trained with. If it is used with a different shift, the audio will be bumpy.

Stateless models are used during training because training is performed on sequences. For blockwise real-time inference, a stateful model is required. The model can either be made stateful, as in the SavedModel, or the states can be handled outside the model, as in the TF-lite and ONNX models.
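A toy NumPy sketch of why stateless training on whole sequences and stateful blockwise inference fit together. It uses a single tanh RNN cell as a stand-in for the LSTMs (an assumption: it is not the DTLN layer). Running the whole sequence in one stateless pass gives the same outputs as streaming it block by block while carrying the state between calls; resetting the state for every block does not.

```python
import numpy as np

def rnn(x, W, U, h0):
    """Toy tanh RNN over x of shape (T, D); returns outputs (T, H) and the last state."""
    h = h0
    out = []
    for x_t in x:
        h = np.tanh(W @ x_t + U @ h)
        out.append(h)
    return np.stack(out), h

rng = np.random.default_rng(0)
D, H, T, B = 4, 3, 12, 4  # feature dim, hidden size, sequence length, frames per block
W = 0.5 * rng.standard_normal((H, D))
U = 0.5 * rng.standard_normal((H, H))
x = rng.standard_normal((T, D))

# Training-style: one stateless pass over the whole sequence, starting from zeros.
full, _ = rnn(x, W, U, np.zeros(H))

# Streaming with the state carried between calls (what a stateful model, or
# explicit state inputs/outputs as in the TF-lite/ONNX models, provide).
h = np.zeros(H)
chunks = []
for start in range(0, T, B):
    y, h = rnn(x[start:start + B], W, U, h)
    chunks.append(y)
carried = np.concatenate(chunks)

# Streaming with the state reset for every block (stateless blockwise inference).
reset = np.concatenate([rnn(x[s:s + B], W, U, np.zeros(H))[0] for s in range(0, T, B)])
```

Only the first block of `reset` matches `full`; every later block starts from a zero state and diverges, which is why stateless blockwise inference tends to degrade quality.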

hchintada commented 4 years ago

In terms of performance, it looks like we have to compromise a bit to get it running in a browser. On a server, though, your model works wonders!

If I remove the shift from the STFT altogether, I don't need to add the overlapAdd layer at the end, right?

breizhn commented 4 years ago

Yes, that is correct. For training you can just use the setup as is; the overlapAdd layer can handle block_shift = block_len. In a real-time setup, you can pass the output of the last Conv1D layer directly to the sound card (or wherever) without any further processing.
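To see why overlap-add can handle block_shift = block_len, here is a small NumPy sketch of generic overlap-add (TensorFlow provides the same operation as tf.signal.overlap_and_add(signal, frame_step)). When the shift equals the block length, no two blocks overlap, so overlap-add reduces to concatenating the blocks.

```python
import numpy as np

def overlap_add(frames, block_shift):
    """Overlap-add frames of shape (num_frames, block_len) with hop block_shift."""
    num_frames, block_len = frames.shape
    out = np.zeros((num_frames - 1) * block_shift + block_len)
    for i, f in enumerate(frames):
        # Each frame is added into the output starting at its own offset.
        out[i * block_shift : i * block_shift + block_len] += f
    return out
```

For example, overlap_add(np.ones((3, 4)), 2) gives [1, 1, 2, 2, 2, 2, 1, 1], while with a shift of 4 the result is just twelve ones.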

hchintada commented 4 years ago

I was having issues downloading the DNS Challenge dataset via git-lfs. Meanwhile, with the current model, I set the buffer size to 512 and the block shift also to 512 and checked the inference time in JS. It is still 17 ms (the denoising effect is still there, but the speech flickers). I then replaced the FFT and IFFT layers with dummy layers to gauge the inference time of the separation kernels and Conv layers: it is now 9.5 ms. My FFT/IFFT calculations in JS certainly need optimization, but I think the model's inference time is still not ideal for live audio processing in a browser?

shilsircar commented 4 years ago

Each layer takes about 1 ms in JS and there are 11 of them, so this is expected given the number of weights. What do you mean by the audio flickering with the current model? I don't think the current model will produce intelligible audio without the block shift, let alone flicker, so I'm curious what you mean by that.

For the rfft and irfft you can use the tf library; if you do, they will run on the GPU and are sufficiently fast. dataSync() is the main problem: it consumes a good 4-5 ms.

hchintada commented 4 years ago

The flicker happened during speech. During silence, the static noise in the original was suppressed. Anyway, it was only an experiment to check the inference times, not the noise suppression quality, which we know won't work properly because the model was trained with a block shift.

The main drag was the ifftLayer. TF doesn't support exp of complex numbers on the CPU or even the WebGL backend, so it had to be calculated on JS arrays, for which .data() or .dataSync() were essential. A good 5-6 ms were lost there.
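The complex exponential in the ifftLayer only rebuilds the spectrum from magnitude and phase, and Euler's formula lets you do that with real-valued cos and sin instead. Going one step further, multiplying the magnitude by a real mask and reusing the noisy phase is the same as scaling the real and imaginary parts of the noisy spectrum, so neither angle nor exp is needed. The NumPy sketch below checks both identities; the function names are hypothetical. In tfjs the real and imaginary parts could then, assuming the backend in use supports it, be combined with tf.complex(real, imag) before the irfft; the thread does not confirm how that behaves on WebGL.

```python
import numpy as np

def to_real_imag(mag, phase):
    # Euler's formula: mag * e^(j*phase) = mag*cos(phase) + j*mag*sin(phase)
    return mag * np.cos(phase), mag * np.sin(phase)

def mask_real_imag(mask, real, imag):
    # A real-valued mask on the magnitude, combined with the unchanged noisy
    # phase, equals scaling the real and imaginary parts of the noisy spectrum.
    return mask * real, mask * imag
```

Both functions use only real elementwise operations, which keeps every intermediate result a real-valued tensor.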

hchintada commented 4 years ago

I've tried shrinking the network by replacing the two LSTM layers in each separation core with a single Conv1D layer of 128 filters. This network had fewer than 0.4 million parameters. When I trained it on around 9000 audio samples (15 s each on average), it stopped learning after the 6th epoch.

Maybe I was too greedy for fast inference.

@breizhn

hchintada commented 4 years ago

@breizhn I'd like to know more about the B4 model you mentioned in the paper, which doesn't use the STFT analysis-synthesis basis. Does it also use the same signal transformation proposed by Luo and colleagues in the first separation kernel? If so, can we pass the time-domain data directly to the separation kernel and get rid of the FFT and IFFT calculations by building that architecture?

hchintada commented 4 years ago

This is the architecture of B4, from what I gather from the paper:

Please correct me if I'm wrong:

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_1 (InputLayer)            [(None, None)]       0
__________________________________________________________________________________________________
lambda (Lambda)                 (None, None, 512)    0           input_1[0][0]
__________________________________________________________________________________________________
conv1d (Conv1D)                 (None, None, 256)    131072      lambda[0][0]
__________________________________________________________________________________________________
lstm (LSTM)                     (None, None, 95)     133760      conv1d[0][0]
__________________________________________________________________________________________________
dropout (Dropout)               (None, None, 95)     0           lstm[0][0]
__________________________________________________________________________________________________
lstm_1 (LSTM)                   (None, None, 95)     72580       dropout[0][0]
__________________________________________________________________________________________________
dense (Dense)                   (None, None, 256)    24576       lstm_1[0][0]
__________________________________________________________________________________________________
activation (Activation)         (None, None, 256)    0           dense[0][0]
__________________________________________________________________________________________________
multiply (Multiply)             (None, None, 256)    0           conv1d[0][0]
                                                                 activation[0][0]
__________________________________________________________________________________________________
conv1d_1 (Conv1D)               (None, None, 512)    131072      multiply[0][0]
__________________________________________________________________________________________________
conv1d_2 (Conv1D)               (None, None, 256)    131072      conv1d_1[0][0]
__________________________________________________________________________________________________
instant_layer_normalization (In (None, None, 256)    512         conv1d_2[0][0]
__________________________________________________________________________________________________
lstm_2 (LSTM)                   (None, None, 95)     133760      instant_layer_normalization[0][0]
__________________________________________________________________________________________________
dropout_1 (Dropout)             (None, None, 95)     0           lstm_2[0][0]
__________________________________________________________________________________________________
lstm_3 (LSTM)                   (None, None, 95)     72580       dropout_1[0][0]
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, None, 256)    24576       lstm_3[0][0]
__________________________________________________________________________________________________
activation_1 (Activation)       (None, None, 256)    0           dense_1[0][0]
__________________________________________________________________________________________________
multiply_1 (Multiply)           (None, None, 256)    0           conv1d_2[0][0]
                                                                 activation_1[0][0]
__________________________________________________________________________________________________
conv1d_3 (Conv1D)               (None, None, 512)    131072      multiply_1[0][0]
__________________________________________________________________________________________________
lambda_1 (Lambda)               (None, None)         0           conv1d_3[0][0]
==================================================================================================
Total params: 986,632
Trainable params: 986,632
Non-trainable params: 0
breizhn commented 4 years ago

You have to add a layer normalization between the first Conv1D layer and the first LSTM layer, in the same fashion as in the second part of the model. Yes, in this manner you can get rid of the STFT, but the model also loses its original idea of combining the STFT with a learned feature representation.
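A NumPy sketch of the kind of per-frame ("instant") layer normalization meant here, under stated assumptions: each frame is normalized over its own feature axis with learnable gamma and beta, and the epsilon value is a guess rather than the repo's setting. Two learnable vectors of 256 features each match the 512 parameters shown for instant_layer_normalization in the summaries.

```python
import numpy as np

def instant_layer_norm(x, gamma, beta, eps=1e-7):
    """Normalize every frame independently over its feature (last) axis.

    x: (batch, frames, features); gamma, beta: (features,) learnable vectors.
    """
    mean = x.mean(axis=-1, keepdims=True)
    var = ((x - mean) ** 2).mean(axis=-1, keepdims=True)
    return gamma * (x - mean) / np.sqrt(var + eps) + beta
```

Because the statistics come from a single frame, no look-ahead across time is needed, which keeps the layer usable in blockwise real-time inference.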

hchintada commented 4 years ago

OK.

Yes, the original idea of combining STFT with learned features is lost, but we will have something to work with in the browser domain, thanks to the approach described in your paper.

hchintada commented 4 years ago

Here is the updated B4:

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_1 (InputLayer)            [(None, None)]       0
__________________________________________________________________________________________________
lambda (Lambda)                 (None, None, 512)    0           input_1[0][0]
__________________________________________________________________________________________________
conv1d (Conv1D)                 (None, None, 256)    131072      lambda[0][0]
__________________________________________________________________________________________________
instant_layer_normalization (In (None, None, 256)    512         conv1d[0][0]
__________________________________________________________________________________________________
lstm (LSTM)                     (None, None, 95)     133760      instant_layer_normalization[0][0]
__________________________________________________________________________________________________
dropout (Dropout)               (None, None, 95)     0           lstm[0][0]
__________________________________________________________________________________________________
lstm_1 (LSTM)                   (None, None, 95)     72580       dropout[0][0]
__________________________________________________________________________________________________
dense (Dense)                   (None, None, 256)    24576       lstm_1[0][0]
__________________________________________________________________________________________________
activation (Activation)         (None, None, 256)    0           dense[0][0]
__________________________________________________________________________________________________
multiply (Multiply)             (None, None, 256)    0           conv1d[0][0]
                                                                 activation[0][0]
__________________________________________________________________________________________________
conv1d_1 (Conv1D)               (None, None, 512)    131072      multiply[0][0]
__________________________________________________________________________________________________
conv1d_2 (Conv1D)               (None, None, 256)    131072      conv1d_1[0][0]
__________________________________________________________________________________________________
instant_layer_normalization_1 ( (None, None, 256)    512         conv1d_2[0][0]
__________________________________________________________________________________________________
lstm_2 (LSTM)                   (None, None, 95)     133760      instant_layer_normalization_1[0][
__________________________________________________________________________________________________
dropout_1 (Dropout)             (None, None, 95)     0           lstm_2[0][0]
__________________________________________________________________________________________________
lstm_3 (LSTM)                   (None, None, 95)     72580       dropout_1[0][0]
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, None, 256)    24576       lstm_3[0][0]
__________________________________________________________________________________________________
activation_1 (Activation)       (None, None, 256)    0           dense_1[0][0]
__________________________________________________________________________________________________
multiply_1 (Multiply)           (None, None, 256)    0           conv1d_2[0][0]
                                                                 activation_1[0][0]
__________________________________________________________________________________________________
conv1d_3 (Conv1D)               (None, None, 512)    131072      multiply_1[0][0]
__________________________________________________________________________________________________
lambda_1 (Lambda)               (None, None)         0           conv1d_3[0][0]
==================================================================================================
Total params: 987,144
Trainable params: 987,144
Non-trainable params: 0
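The parameter counts in the summary above can be reproduced with standard Keras counting rules, which is a quick way to check that a hand-built B4 matches it. This is a sketch: the 131,072 Conv1D counts imply kernel size 1 and no bias (512 × 256), which is inferred from the summary rather than stated in the thread. Dropout, Activation, Multiply and Lambda layers add no parameters.

```python
def conv1d_params(c_in, c_out, kernel_size=1, bias=False):
    return kernel_size * c_in * c_out + (c_out if bias else 0)

def lstm_params(d_in, units):
    # Four gates, each with input weights, recurrent weights and a bias.
    return 4 * (units * (d_in + units) + units)

def dense_params(d_in, d_out):
    return d_in * d_out + d_out  # weights plus bias

def layer_norm_params(features):
    return 2 * features  # gamma and beta

def b4_params(block_len=512, enc=256, units=95):
    core = lstm_params(enc, units) + lstm_params(units, units) + dense_params(units, enc)
    return (
        conv1d_params(block_len, enc)    # conv1d: time-domain encoder
        + layer_norm_params(enc)         # instant_layer_normalization
        + core                           # first separation core
        + conv1d_params(enc, block_len)  # conv1d_1
        + conv1d_params(block_len, enc)  # conv1d_2
        + layer_norm_params(enc)         # instant_layer_normalization_1
        + core                           # second separation core
        + conv1d_params(enc, block_len)  # conv1d_3: decoder
    )

print(b4_params())  # 987144
```

Dropping one of the two layer_norm_params terms gives 986,632, the total of the first version of B4 without the added normalization.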
hchintada commented 4 years ago

Also, could you please confirm whether the B2 architecture is similar to this:

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_1 (InputLayer)            [(None, None)]       0
__________________________________________________________________________________________________
lambda (Lambda)                 (None, None, 512)    0           input_1[0][0]
__________________________________________________________________________________________________
conv1d (Conv1D)                 (None, None, 256)    131072      lambda[0][0]
__________________________________________________________________________________________________
instant_layer_normalization (In (None, None, 256)    512         conv1d[0][0]
__________________________________________________________________________________________________
lstm (LSTM)                     (None, None, 139)    220176      instant_layer_normalization[0][0]
__________________________________________________________________________________________________
dropout (Dropout)               (None, None, 139)    0           lstm[0][0]
__________________________________________________________________________________________________
lstm_1 (LSTM)                   (None, None, 139)    155124      dropout[0][0]
__________________________________________________________________________________________________
lstm_2 (LSTM)                   (None, None, 139)    155124      lstm_1[0][0]
__________________________________________________________________________________________________
dropout_1 (Dropout)             (None, None, 139)    0           lstm_2[0][0]
__________________________________________________________________________________________________
lstm_3 (LSTM)                   (None, None, 139)    155124      dropout_1[0][0]
__________________________________________________________________________________________________
dense (Dense)                   (None, None, 256)    35840       lstm_3[0][0]
__________________________________________________________________________________________________
activation (Activation)         (None, None, 256)    0           dense[0][0]
__________________________________________________________________________________________________
multiply (Multiply)             (None, None, 256)    0           conv1d[0][0]
                                                                 activation[0][0]
__________________________________________________________________________________________________
conv1d_1 (Conv1D)               (None, None, 512)    131072      multiply[0][0]
__________________________________________________________________________________________________
lambda_1 (Lambda)               (None, None)         0           conv1d_1[0][0]
==================================================================================================
Total params: 984,044
Trainable params: 984,044
Non-trainable params: 0
__________________________________________________________________________________________________
breizhn commented 4 years ago

A Dropout layer is missing between lstm_1 and lstm_2.

Because this topic has now shifted to the baseline networks and is no longer about transfer learning, I will close the issue. If you would like to discuss the baseline networks, please open a new issue.

ghost commented 3 years ago

Hi @hchintada, do you know more details about model B3 in the paper?