thu-spmi / CAT

A CRF-based ASR Toolkit
Apache License 2.0

Hard to converge? Is it possible to release model configs or trained models? #1

Closed: glynpu closed this issue 4 years ago.

glynpu commented 4 years ago

Hi, I have tried different model topologies and found them really hard to train to convergence.

I am trying to reproduce the results reported in the paper "CAT: CRF-BASED ASR TOOLKIT" on the Aishell dataset; specifically, Table 4 of that paper. However, I was only able to reproduce the BLSTM result (marked by the red arrow in the following figure).

[Figure: Table 4 of the CAT paper, with the BLSTM row marked by a red arrow]

Here is the convergence curve of my successful BLSTM experiment. It takes only one epoch to reach a CER of 9.9%, but another 16 epochs to reach 7.42%.

[Figure: CER per epoch for the BLSTM run]

However, I failed to get a decent result with the LSTM, VGG-LSTM, and TDNN-LSTM models. I have tried different initial learning rates, decay strategies, and values of the CTC objective weight lambda, but they just fail to converge.

Could you release the associated configs for these models? Releasing the trained models as well would be much appreciated.

Thank you!

aky15 commented 4 years ago

Thanks for your interest in our work! Regarding convergence: as the paper "CAT: CRF-BASED ASR TOOLKIT" states, "we use batch normalization to speed up convergence". We also found in our experiments that batchnorm benefits model accuracy. The relevant code will be released soon. For the Aishell experiment, the hyperparameters are: hdim=320, lambda=0.01, initial learning rate=0.001.
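For reference, the settings quoted above can be collected in one place. This is a minimal sketch: `TrainConfig` is a hypothetical name, not a class in the CAT repository, and the field names simply mirror the `hdim`, `lamb`, and `lr` settings mentioned in this thread.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class TrainConfig:
    """Hypothetical container for the Aishell hyper-parameters quoted above."""
    hdim: int = 320     # hidden units per LSTM direction (BLSTM output is 2 * hdim)
    lamb: float = 0.01  # weight of the CTC objective ("lambda")
    lr: float = 0.001   # initial learning rate


aishell = TrainConfig()
# steps/train.py reportedly defaults to hdim=512; overriding one field keeps the rest.
train_py_default = replace(aishell, hdim=512)
print(aishell)
print(train_py_default)
```

Per the later reply in this thread, the same values were used for every architecture in the paper, so only the model definition changes between runs.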

glynpu commented 4 years ago

Thanks for the helpful information! I still have two questions. Question 1: Do the BLSTM / LSTM / VGG-LSTM / TDNN-LSTM models share the SAME hyper-parameter config (hdim=320, lambda=0.01, initial lr=0.001)? After successfully training the BLSTM model on Aishell, I only replaced the model definition and never touched these hyper-parameters. The BLSTM converges, while LSTM / VGG-LSTM / TDNN-LSTM do not.

Question 2: Empirically, how many epochs does it take to train the LSTM / VGG-LSTM / TDNN-LSTM models on Aishell data?
I wonder whether I stopped training too early at 30~50 epochs. Maybe they were still converging and I killed them.

aky15 commented 4 years ago

For question 1, the answer is "yes". Although in later experiments we found that a more elaborate learning-rate schedule can potentially improve performance, we used the same hyper-parameters for the experiments reported in the paper. For question 2, empirically it takes about 20~25 epochs to finish training. Have you applied offline CMVN to the raw fbank features? It could make training a little faster (maybe less than 20 epochs). Since CMVN requires the full sentence to compute the mean and variance, we excluded it from our latency-control experiments and used batch-norm instead.
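To see why offline CMVN needs the whole sentence, here is a minimal per-utterance CMVN sketch in plain Python. It assumes the features are a list of frames (each a list of fbank values) and normalizes each dimension with the mean and standard deviation of the entire utterance; Kaldi-style recipes may instead accumulate statistics per speaker. `utterance_cmvn` is a hypothetical helper, not CAT code.

```python
import math
from typing import List

Frame = List[float]


def utterance_cmvn(frames: List[Frame], eps: float = 1e-8) -> List[Frame]:
    """Normalize every feature dimension to zero mean and unit variance,
    using statistics computed over ALL frames of the utterance."""
    if not frames:
        return []
    n = len(frames)
    dim = len(frames[0])
    # The statistics depend on every frame, including future ones, so no
    # output frame can be produced until the sentence has ended.
    means = [sum(f[d] for f in frames) / n for d in range(dim)]
    stds = [
        math.sqrt(sum((f[d] - means[d]) ** 2 for f in frames) / n)
        for d in range(dim)
    ]
    # eps guards against division by zero for constant dimensions.
    return [[(f[d] - means[d]) / (stds[d] + eps) for d in range(dim)] for f in frames]


feats = [[1.0, 10.0], [3.0, 10.0]]  # two frames, two fbank dimensions (toy values)
print(utterance_cmvn(feats))
```

Batch normalization sidesteps this dependency at inference time because it uses running statistics collected during training rather than statistics of the current sentence.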

glynpu commented 4 years ago

CMVN is used in my experiments. I think that's why it reaches a CER of 7.4% after 17 epochs with BLSTM, just as you said: "maybe less than 20 epochs". I will run some experiments without CMVN to compare its effect. By the way, which method do you use to add batch normalization to the LSTM? Could you share some tutorials or reference examples of your solution?

glynpu commented 4 years ago

TDNN-LSTM still fails to converge. Questions:

   - How do I define the TDNN-LSTM model reported in the paper?
   - Is there something wrong with my model definition method (step 5 below)?

Here are my experiment steps:

  1. Re-clone and rebuild this repository.

  2. Copy the previously generated data/ folder to CAT/egs/aishell.

  3. Change hdim in steps/train.py from the default 512 to 320. (By default, lr is 0.001 in steps/train.py and lamb is 0.01 in run.sh; both are the suggested values.)

  4. Run stage 6 in run.sh:

    if [ $stage -eq 6 ]; then 
       echo "nn training."
       python steps/train.py --output_unit=218 --lamb=0.01 --data_path=$dir
    fi

    With the previous four steps (i.e. the BLSTM model), tr_real_loss decreases rapidly: within only a few hundred steps it drops below 100. However, I failed to train a TDNN-LSTM with the following steps 5-6.

  5. Modify steps/train.py to define a TDNN-LSTM. The definition of the BLSTM in steps/train.py is:

    self.net = BLSTM(idim,  hdim, n_layers, dropout=0.5)  
    self.linear = nn.Linear(hdim*2, K)

    To define the TDNN-LSTM model, I modified the above definition to the following:

    self.net = LSTMrowCONV(idim,  hdim, n_layers, dropout=0.5)
    self.linear = nn.Linear(hdim, K)

  6. Run stage 6 in run.sh to train the newly defined TDNN-LSTM model.
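The only shape-related change in step 5 is the input size of the output layer. A bidirectional encoder concatenates forward and backward states (2 * hdim features per frame), whereas the unidirectional definition produces hdim. A small sketch of that bookkeeping; the `NUM_DIRECTIONS` table and `classifier_input_dim` are illustrative assumptions based on the two definitions above, not code from steps/train.py.

```python
# Hypothetical table: number of directions for each encoder named in this thread.
NUM_DIRECTIONS = {"BLSTM": 2, "LSTM": 1, "LSTMrowCONV": 1}


def classifier_input_dim(arch: str, hdim: int) -> int:
    """Feature size fed to the final linear layer: hdim per direction times directions."""
    if arch not in NUM_DIRECTIONS:
        raise ValueError(f"unknown architecture: {arch!r}")
    return hdim * NUM_DIRECTIONS[arch]


hdim = 320
for arch in ("BLSTM", "LSTMrowCONV"):
    print(arch, classifier_input_dim(arch, hdim))
```

A mismatch here would raise a shape error at the first forward pass rather than cause a silent plateau, so the stuck loss described below is unlikely to come from this dimension alone.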

However, with steps 5-6 the TDNN-LSTM failed to converge. To be exact, tr_real_loss gets stuck at around 75, as shown in the following figures.

The initial tr_real_loss is around 111.7. [Figure: training log at the start of training]

It takes only about 15 steps to drop to 75. [Figure: training log after about 15 steps]

However, it stays stuck there no matter how long I wait! [Figure: training log with the loss stuck around 75]
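One way to tell a genuine plateau from slow progress is to compare the mean loss over two consecutive windows of logged steps. This is a hypothetical helper, not part of CAT; the window size and the 1% threshold are arbitrary assumptions to tune.

```python
from statistics import mean
from typing import Sequence


def is_plateau(losses: Sequence[float], window: int = 100,
               min_rel_improvement: float = 0.01) -> bool:
    """Return True if the mean of the latest `window` losses improved by less than
    `min_rel_improvement` relative to the mean of the window before it.
    Assumes losses are positive, as a negative log-likelihood is here."""
    if len(losses) < 2 * window:
        return False  # not enough history to judge
    previous = mean(losses[-2 * window:-window])
    latest = mean(losses[-window:])
    return (previous - latest) / previous < min_rel_improvement


# Made-up values shaped like the log described above.
logged = [111.7, 95.0, 82.0, 75.2, 74.9, 75.1, 75.0, 74.8, 75.1]
print(is_plateau(logged, window=3))
```

The check only says that progress has stalled; it does not tell you why.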

Is there something wrong with my model definition method? Could you give me some clues on how to solve this issue?

aky15 commented 4 years ago

We have added an implementation of TDNN_LSTM; see steps/model.py for details. By the way, we found in our experiments that LSTMrowCONV is difficult to train to convergence; so far we have only trained it successfully on WSJ.
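As background, a TDNN layer is essentially an affine transform applied to each frame spliced together with its neighbours at fixed time offsets. The sketch below shows only the splicing step in plain Python. The offsets (-1, 0, 1) are an illustrative assumption, not the context used by TDNN_LSTM in steps/model.py, and clamping out-of-range indices to the first/last frame is just one common edge-handling choice.

```python
from typing import List, Sequence

Frame = List[float]


def splice(frames: List[Frame], offsets: Sequence[int]) -> List[Frame]:
    """Concatenate each frame with its neighbours at the given time offsets.
    Indices that fall outside the utterance are clamped to the first/last frame."""
    T = len(frames)
    out: List[Frame] = []
    for t in range(T):
        spliced: Frame = []
        for o in offsets:
            idx = min(max(t + o, 0), T - 1)  # clamp to [0, T - 1]
            spliced.extend(frames[idx])
        out.append(spliced)
    return out


frames = [[0.0], [1.0], [2.0]]  # three 1-dimensional toy frames
print(splice(frames, (-1, 0, 1)))
```

A TDNN layer would then apply the same weight matrix to every spliced vector; in a TDNN-LSTM such layers are typically interleaved with LSTM layers.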

glynpu commented 4 years ago

Thanks for the update. My mistake: I thought LSTMrowCONV was TDNN_LSTM.

However, with the TDNN_LSTM model defined in steps/model.py, which you added last night, tr_real_loss is still stuck at around 70 after more than 4 epochs of training. [Figure: training log with the loss stuck around 70]

This experiment was conducted on Aishell. I have two questions. Q1: Is it normal for tr_real_loss to stay stuck around 70 for the first several epochs (epochs 1~4)? Q2: Should I wait until it converges, or do something else?

aky15 commented 4 years ago

We have added a batchnorm layer implementation to TDNN_LSTM; see steps/model.py for details. With batchnorm, convergence becomes much easier. Before using it, make sure you have compiled the source code:

    cd src/batchnorm_src
    # change the path in Makefile to your local path, then:
    make
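The thread does not describe what src/batchnorm_src implements. As a generic illustration of one issue that batch normalization over padded, variable-length utterances has to handle, the sketch below computes per-dimension statistics from real frames only and leaves padding frames at zero; a naive batch norm over the padded batch would mix the padding into the mean and variance. `masked_batch_norm` is a hypothetical plain-Python helper, and it omits the learnable scale/shift and running statistics of a full batch-norm layer.

```python
import math
from typing import List

Batch = List[List[List[float]]]  # batch[b][t][d], padded to a common length


def masked_batch_norm(batch: Batch, lengths: List[int], eps: float = 1e-5) -> Batch:
    """Training-mode normalization using only the first lengths[b] frames of each
    utterance; padding frames are excluded from the statistics and output as zeros."""
    real = [batch[b][t] for b in range(len(batch)) for t in range(lengths[b])]
    if not real:
        raise ValueError("batch contains no real frames")
    n = len(real)
    dim = len(real[0])
    means = [sum(f[d] for f in real) / n for d in range(dim)]
    variances = [sum((f[d] - means[d]) ** 2 for f in real) / n for d in range(dim)]
    out: Batch = []
    for b, utt in enumerate(batch):
        normed = []
        for t, f in enumerate(utt):
            if t < lengths[b]:
                normed.append([(f[d] - means[d]) / math.sqrt(variances[d] + eps)
                               for d in range(dim)])
            else:
                normed.append([0.0] * dim)  # keep padding at zero
        out.append(normed)
    return out


# Two utterances padded to length 2; the second has only one real frame.
batch = [[[1.0], [3.0]], [[2.0], [0.0]]]
print(masked_batch_norm(batch, lengths=[2, 1]))
```

Here the real frames (1, 3, 2) have mean 2; including the padded 0 would shift the mean to 1.5.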

glynpu commented 4 years ago

The problems are solved after applying batchnorm! Thanks for your kindness and support, @aky15.