khaotik / DaNet-Tensorflow

Tensorflow implementation of "Speaker-independent Speech Separation with Deep Attractor Network"
MIT License

Constant training loss? #5

Closed: menon92 closed this issue 7 years ago

menon92 commented 7 years ago

We are using this implementation to train a model on our custom dataset with the default network configuration, but the training loss stays constant even after 1600 epochs.

NB: We are using a CPU machine.

khaotik commented 7 years ago

Hi,

> What might be the possible cause?

Does training on the TIMIT dataset with the bilstm-orig encoder work for you on CPU? (You should see convergence within ~100 epochs.)

If yes, then check your data preprocessing and feeding pipeline.

> And how is the loss calculated?

We compute a permutation-invariant MSE loss between the original signals and the estimated signals. For example:

If the input mixture is A + B and the estimate is (X, Y), the loss term is:

min( mean(||A - X||^2) + mean(||B - Y||^2), mean(||A - Y||^2) + mean(||B - X||^2) )

Here A, B, X, and Y are all complex-valued STFT spectra.

Unfortunately, the original paper is a bit unclear about the loss term, so I'm not sure whether this accurately matches their implementation.

menon92 commented 7 years ago

@khaotik You are right. I was using the toy encoder on the TIMIT dataset. Now I am trying the bilstm-orig encoder.

Thanks for your quick reply :)

menon92 commented 7 years ago

With the bilstm-orig encoder, training started well:

Preparing dataset "timit" ... done
Encoder type: "bilstm-orig"
Separator type: "dot-sigmoid-orig"
Training estimator type: "truth-weighted"
Inference estimator type: "anchor"
Building model ... done
Set learning rate to 0.000300
:::::::::::::::::::::::::::S
Epoch 1/120 loss=939.576533565 SNR=3.20550452338
..........
Valid  1/120 loss=1280.94570313 SNR=2.1614107132
:::::::::::::::::::::::::::S
Epoch 2/120 loss=709.838252315 SNR=3.49441754376
..........
Valid  2/120 loss=1361.55566406 SNR=1.90109233856
:::::::::::::::::::::::::::S
Epoch 3/120 loss=611.47229456 SNR=3.62280669036
..........
Valid  3/120 loss=1548.0 SNR=1.51264925003

But after 120 epochs, why is the loss not decreasing? Here is my terminal output:

Epoch 119/120 loss=368.527162905 SNR=6.23708654333
..........
Valid  119/120 loss=2095.86699219 SNR=0.275619983673
:::::::::::::::::::::::::::S
Epoch 120/120 loss=367.290364583 SNR=5.97982957628
..........
Valid  120/120 loss=2115.9171875 SNR=0.259754371643
Saving parameters into models_timit/model_timit.ckpt ... done

What could be the cause?
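In the output above, the training loss goes from 939.6 at epoch 1 to 367.3 at epoch 120, while the validation loss goes from 1280.9 to 2115.9. To compare the two curves over every epoch rather than by eye, a small parser can help. This is a sketch that assumes the log lines look exactly like the ones pasted here; `parse_log` and `LINE_RE` are hypothetical names, not part of the repo.

```python
import re

# Matches training lines such as "Epoch 1/120 loss=939.576533565 SNR=3.20550452338"
# and validation lines such as "Valid  1/120 loss=1280.94570313 SNR=2.1614107132".
LINE_RE = re.compile(
    r"^(Epoch|Valid)\s+(\d+)/\d+\s+loss=([-+.\deE]+)\s+SNR=([-+.\deE]+)"
)


def parse_log(lines):
    """Collect (epoch, loss, snr) tuples per split from log lines in the format above.

    Progress-bar lines (":::S", "....") and any other output are skipped.
    """
    history = {"Epoch": [], "Valid": []}
    for line in lines:
        match = LINE_RE.match(line.strip())
        if match:
            split, epoch, loss, snr = match.groups()
            history[split].append((int(epoch), float(loss), float(snr)))
    return history


if __name__ == "__main__":
    sample = """\
Epoch 1/120 loss=939.576533565 SNR=3.20550452338
..........
Valid  1/120 loss=1280.94570313 SNR=2.1614107132
Epoch 120/120 loss=367.290364583 SNR=5.97982957628
..........
Valid  120/120 loss=2115.9171875 SNR=0.259754371643
"""
    history = parse_log(sample.splitlines())
    # Pair each training epoch with the validation pass that follows it.
    for (ep, train_loss, _), (_, valid_loss, _) in zip(history["Epoch"], history["Valid"]):
        print(f"epoch {ep}: train loss {train_loss:.1f}, valid loss {valid_loss:.1f}")
```

Feeding the full terminal output through `parse_log` gives one list per split that can be plotted or diffed epoch by epoch.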

khaotik commented 7 years ago

I would really appreciate help from anyone who knows how to fully reproduce the results of the original paper.