OpenNMT / OpenNMT-py

Open Source Neural Machine Translation and (Large) Language Models in PyTorch
https://opennmt.net/
MIT License

GRU performance with default params is bad (sensitive to `-max_grad_norm`)? #489

Closed howardyclo closed 6 years ago

howardyclo commented 6 years ago

Hello, my task is grammatical error correction, where the source and target sentences are in the same language. For example: the source is an erroneous tokenized sentence, e.g. "I loves opennmt .", and the target is the corrected tokenized sentence, e.g. "I love opennmt ."

I've tested several models with an LSTM-based encoder/decoder, and the performance is fine (training accuracy reaches 90% and the losses are generally low).

But when I use a GRU-based encoder/decoder, it seems the model cannot fit the data, which is very weird... I suspected there might be a bug in the GRU?

Here is my script:

python $OPENNMT_HOME/train.py \
  -data $OPENNMT_HOME/data/efcamdat2.changed \
  -save_model $OPENNMT_HOME/gec-experiment/checkpoints/wordnmt1 \
  -gpuid 0 \
  -seed 1 \
  -batch_size 64 \
  -max_generator_batches 32 \
  -epochs 16 \
  -optim 'adam' \
  -learning_rate 0.001 \
  -word_vec_size 300 \
  -max_grad_norm 2 \
  -encoder_type 'brnn' \
  -decoder_type 'rnn' \
  -enc_layers 2 \
  -dec_layers 2 \
  -rnn_size 512 \
  -rnn_type 'GRU' \
  -input_feed 1 \
  -global_attention 'general' \

(Here I only changed -rnn_type from LSTM to GRU, and the performance became a lot worse.)

And here is my training log:

Loading train data from '/home/howard/gec/nmt.gec/data/efcamdat2.changed'
 * number of train sentences: 2430112
Loading valid data from '/home/howard/gec/nmt.gec/data/efcamdat2.changed'
 * number of valid sentences: 2966
 * maximum batch size: 64
 * vocabulary size. source = 50004; target = 50004
Building model...
Intializing model parameters.
NMTModel (
  (encoder): RNNEncoder (
    (embeddings): Embeddings (
      (make_embedding): Sequential (
        (emb_luts): Elementwise (
          (0): Embedding(50004, 300, padding_idx=1)
        )
      )
    )
    (rnn): GRU(300, 256, num_layers=2, dropout=0.3, bidirectional=True)
  )
  (decoder): InputFeedRNNDecoder (
    (embeddings): Embeddings (
      (make_embedding): Sequential (
        (emb_luts): Elementwise (
          (0): Embedding(50004, 300, padding_idx=1)
        )
      )
    )
    (dropout): Dropout (p = 0.3)
    (rnn): StackedGRU (
      (dropout): Dropout (p = 0.3)
      (layers): ModuleList (
        (0): GRUCell(812, 512)
        (1): GRUCell(512, 512)
      )
    )
    (attn): GlobalAttention (
      (linear_in): Linear (512 -> 512)
      (linear_out): Linear (1024 -> 512)
      (sm): Softmax ()
      (tanh): Tanh ()
    )
  )
  (generator): Sequential (
    (0): Linear (512 -> 50004)
    (1): LogSoftmax ()
  )
)
* number of parameters: 62093364
encoder:  17041008
decoder:  45052356
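As a sanity check on the model printout above, the encoder parameter count can be reproduced by hand from PyTorch's GRU layout (per direction and layer: weight_ih of shape 3h×in, weight_hh of 3h×h, plus two 3h bias vectors). The second bidirectional layer sees 2×256 = 512 inputs; likewise, the decoder's GRUCell(812, 512) input is the 300-dim embedding concatenated with the 512-dim attention context from input feeding. A minimal sketch (function name is illustrative):

```python
def gru_layer_params(input_size, hidden_size):
    """Parameters of one torch.nn.GRU layer, one direction:
    weight_ih (3h x in) + weight_hh (3h x h) + bias_ih + bias_hh."""
    return 3 * hidden_size * (input_size + hidden_size) + 2 * 3 * hidden_size

emb = 50004 * 300  # source embedding table
# Bidirectional, hidden 256 per direction; layer 2 sees 2*256 = 512 inputs.
gru = 2 * gru_layer_params(300, 256) + 2 * gru_layer_params(512, 256)
print(emb + gru)  # 17041008, matching "encoder: 17041008" in the log
```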

/home/howard/.conda/envs/cedl2017/lib/python3.6/site-packages/torch/tensor.py:297: UserWarning: other is not broadcastable to self, but they have the same number of elements.  Falling back to deprecated pointwise behavior.
  return self.add_(other)
/home/howard/.conda/envs/cedl2017/lib/python3.6/site-packages/scipy/stats/_distn_infrastructure.py:1907: RuntimeWarning: invalid value encountered in multiply
  lower_bound = self.a * scale + loc
/home/howard/.conda/envs/cedl2017/lib/python3.6/site-packages/scipy/stats/_distn_infrastructure.py:1908: RuntimeWarning: invalid value encountered in multiply
  upper_bound = self.b * scale + loc

Epoch: 1
Train loss: 6.42267e+07
Train perplexity: 6.27111
Train accuracy: 73.3097
Validation loss: 212688
Validation perplexity: 42.5937
Validation accuracy: 53.2024
Validation GLEU avg score: 0.211474

Epoch  2, 37950/37971; loss: 298819.47; acc:  19.42; ppl: 539.59; 8311 src tok/s; 9084 tgt tok/s;   3912 s elapsed
Epoch: 2
Train loss: 1.72665e+08
Train perplexity: 139.174
Train accuracy: 34.1212
Validation loss: 400135
Validation perplexity: 1162.32
Validation accuracy: 14.893
Validation GLEU avg score: 0.0637938
Decaying learning rate to 0.0005

Epoch  3, 37950/37971; loss: 255667.19; acc:  22.78; ppl: 297.92; 8117 src tok/s; 8964 tgt tok/s;   3912 s elapsed
Epoch: 3
Train loss: 1.97887e+08
Train perplexity: 286.199
Train accuracy: 23.6517
Validation loss: 395204
Validation perplexity: 1065.5
Validation accuracy: 15.969
Validation GLEU avg score: 0.0693773
Decaying learning rate to 0.00025

Epoch  4, 37950/37971; loss: 247323.76; acc:  25.83; ppl: 215.78; 8373 src tok/s; 9183 tgt tok/s;   3915 s elapsed
Epoch: 4
Train loss: 1.90495e+08
Train perplexity: 231.691
Train accuracy: 25.0226
Validation loss: 379404
Validation perplexity: 806.328
Validation accuracy: 17.6113
Validation GLEU avg score: 0.0732429
Decaying learning rate to 0.000125

Epoch  5, 37950/37971; loss: 254888.83; acc:  25.76; ppl: 210.58; 8411 src tok/s; 9191 tgt tok/s;   3919 s elapsed
Epoch: 5
Train loss: 1.84918e+08
Train perplexity: 197.545
Train accuracy: 26.3549
Validation loss: 373275
Validation perplexity: 723.702
Validation accuracy: 18.0858
Validation GLEU avg score: 0.0721813
Decaying learning rate to 6.25e-05

Epoch  6, 37950/37971; loss: 257786.33; acc:  25.67; ppl: 210.85; 8459 src tok/s; 9210 tgt tok/s;   3917 s elapsed
Epoch: 6
Train loss: 1.82185e+08
Train perplexity: 182.703
Train accuracy: 26.8714
Validation loss: 369117
Validation perplexity: 672.518
Validation accuracy: 17.9641
Validation GLEU avg score: 0.0734583
Decaying learning rate to 3.125e-05

Epoch  7, 37950/37971; loss: 201073.04; acc:  29.94; ppl: 124.42; 8186 src tok/s; 9084 tgt tok/s;   3903 s elapsed
Epoch: 7
Train loss: 1.79903e+08
Train perplexity: 171.165
Train accuracy: 27.3309
Validation loss: 367061
Validation perplexity: 648.565
Validation accuracy: 18.5426
Validation GLEU avg score: 0.0740903
Decaying learning rate to 1.5625e-05

Epoch  8, 37950/37971; loss: 237486.40; acc:  27.29; ppl: 162.17; 8455 src tok/s; 9261 tgt tok/s;   3918 s elapsed
Epoch: 8
Train loss: 1.7893e+08
Train perplexity: 166.467
Train accuracy: 27.5335
Validation loss: 365904
Validation perplexity: 635.469
Validation accuracy: 18.5938
Validation GLEU avg score: 0.0751541
Decaying learning rate to 7.8125e-06

Epoch  9, 37950/37971; loss: 242571.93; acc:  26.84; ppl: 168.96; 8397 src tok/s; 9156 tgt tok/s;   3919 s elapsed
Epoch: 9
Train loss: 1.78053e+08
Train perplexity: 162.349
Train accuracy: 27.6563
Validation loss: 365670
Validation perplexity: 632.843
Validation accuracy: 18.6379
Validation GLEU avg score: 0.0753083
Decaying learning rate to 3.90625e-06

Epoch 10, 37950/37971; loss: 228999.42; acc:  28.09; ppl: 159.13; 8138 src tok/s; 8916 tgt tok/s;   3922 s elapsed
Epoch: 10
Train loss: 1.77839e+08
Train perplexity: 161.357
Train accuracy: 27.6894
Validation loss: 365342
Validation perplexity: 629.202
Validation accuracy: 18.622
Validation GLEU avg score: 0.0742992
Decaying learning rate to 1.95313e-06

Epoch 11, 37950/37971; loss: 207980.83; acc:  29.45; ppl: 129.24; 8064 src tok/s; 8940 tgt tok/s;   3924 s elapsed
Epoch: 11
Train loss: 1.77593e+08
Train perplexity: 160.226
Train accuracy: 27.7269
Validation loss: 365030
Validation perplexity: 625.748
Validation accuracy: 18.4879
Validation GLEU avg score: 0.0751217
Decaying learning rate to 9.76563e-07

The log above shows the accuracy getting lower and the perplexity getting higher...

And here is part of my model's output on the validation data (the decoding pipeline itself seems to run fine, though the predictions are poor):

0
SRC: So I think we can not live if old people could not find siences and tecnologies and they did not developped .
PRED: I think in the good people .
REF0: So I think we would not be alive if our ancestors did not develop sciences and technologies .
REF1: So I think we could not live if older people did not develop science and technologies .
REF2: So I think we can not live if old people could not find science and technologies and they did not develop .
REF3: So I think we can not live if old people can not find the science and technology that has not been developed .
GLEU score:['0.041978', '0.013859', '(0.015,0.069)']

1
SRC: For not use car .
PRED: I have a police .
REF0: Not for use with a car .
REF1: Do not use in the car .
REF2: Car not for use .
REF3: Can not use the car .
GLEU score:['0.248221', '0.035531', '(0.179,0.318)']

2
SRC: Here was no promise of morning except that we looked up through the trees we saw how low the forest had swung .
PRED: Here morning at the big .
REF0: Here was no promise of morning , except that we looked up through the trees , and we saw how low the forest had <unk> .
REF1: Here , there was no promise of morning , except that we looked up through the trees and saw how low the forest had <unk> .
REF2: Here was no promise of morning except that we looked up through the trees and we saw how low the forest had <unk> .
REF3: There was no promise of morning except when we looked up through the trees and saw how low the forest had <unk> .
GLEU score:['0.013847', '0.002265', '(0.009,0.018)']

3
SRC: Thus even today sex is considered as the least important topic in many parts of India .
PRED: I have a important .
REF0: Thus , even today , sex is considered as the least important topic in may parts of India .
REF1: Thus , even today , sex is considered the least important topic in many parts of India .
REF2: Thus , even today , sex is considered the least important topic in many parts of India .
REF3: Thus , even today sex is considered as the least important topic in many parts of India .
GLEU score:['0.025477', '0.002095', '(0.021,0.030)']

4
SRC: image you salf you are wark in factory just to do one thing like pot taire on car if they fire you you will destroy , becouse u dont know more than pot taire in car .
PRED: How is a great .
REF0: Imagine yourself you are working in factory just to do one thing like put air a on car if they fire you you will be destroyed , because you do n't know more than to put air a in car .
REF1: Imagine that you work in a factory and do just one thing , like put tires on cars ; if they fire you , they will destroy you because you do n't know how to do anything but put tires on cars .
REF2: image you <unk> you are wark in factory just to do one thing like pot <unk> on car if they fire you you will destroy , becouse u <unk> know more than pot <unk> in car .
REF3: Imagine yourself working in a <unk> You are to do just one thing , such as put a tire on a <unk> If you are fired , it will destroy you because you do not know how to do more than put tires on cars .
GLEU score:['0.000262', '0.000151', '(-0.000,0.001)']
helson73 commented 6 years ago

Did you try tuning the hyperparameters? The initial weights, optimizer, and learning rate used for GRU models in other work may help.

howardyclo commented 6 years ago

@helson73 OK, I found the problem. I re-ran the experiment with -max_grad_norm set back to its default value (5), and the GRU's performance is now fine (the training loss decreases). But I don't know why -max_grad_norm affects the GRU unit so much.

(LSTM is fine: in my experiment, I tuned -max_grad_norm from 5 down to 2 with the LSTM unit, and the performance became slightly better.)
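For context, -max_grad_norm controls gradient-norm clipping: when the global L2 norm of all gradients exceeds the threshold, the whole gradient vector is rescaled to that norm, so a tight value like 2 shrinks every update. A minimal pure-Python sketch of the behaviour behind torch.nn.utils.clip_grad_norm_ (the function name here is illustrative, and gradients are flattened into a plain list for clarity):

```python
import math

def clip_grad_norm(grads, max_norm):
    """Rescale a flat list of gradient values so their global L2 norm
    does not exceed max_norm; returns (clipped grads, original norm)."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm > max_norm:
        scale = max_norm / total_norm
        grads = [g * scale for g in grads]
    return grads, total_norm

# A gradient vector with norm 5.0 gets scaled by 2/5 when max_norm=2:
clipped, norm = clip_grad_norm([3.0, 4.0], 2.0)
```

With max_norm=5 the same gradient would pass through unchanged, which may explain why the default worked while 2 did not.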

helson73 commented 6 years ago

It's just my experience, but GRUs seem to be more sensitive to hyperparameters. For instance, vanilla SGD doesn't work for GRUs in many cases where it is fine for LSTMs.

howardyclo commented 6 years ago

@helson73 Thanks for the observation. The reason I tried the GRU unit here is that several grammatical error correction papers use GRU-based seq2seq models. But it seems that GRU needs more tuning. Um... I'll go with LSTM (lol)

helson73 commented 6 years ago

They use GRU because it's faster and simpler; using the hyperparameters they mention should work.

srush commented 6 years ago

This is an interesting discussion. Our default hyperparameters are for LSTM; let's add a note suggesting some changes for GRU.

Victoria-Pinzhen-Liao commented 6 years ago

Hi, sorry, but may I ask how you computed the Validation GLEU avg score, please?