flashlight / wav2letter

Facebook AI Research's Automatic Speech Recognition Toolkit
https://github.com/facebookresearch/wav2letter/wiki

Learning rate too low when using "continue" #639

Closed. TijsRozenbroek closed this issue 4 years ago.

TijsRozenbroek commented 4 years ago

I have trained the provided seq2seq TDS model (from https://github.com/facebookresearch/wav2letter/tree/master/recipes/models/seq2seq_tds) on my own data for 200 epochs with a learning rate of 0.2, using the following config parameters (irrelevant parameters such as rundir omitted):

--samplerate=8000
--criterion=seq2seq
--lr=0.2
--lrcrit=0.2
--momentum=0.0
--warmup=0
--stepsize=203740
--gamma=0.5
--maxgradnorm=5
--mfsc=true
--dataorder=output_spiral
--inputbinsize=25
--filterbanks=80
--attention=keyvalue
--encoderdim=512
--attnWindow=softPretrain
--softwstd=4
--trainWithWindow=true
--pretrainWindow=15281
--maxdecoderoutputlen=120
--usewordpiece=true
--wordseparator=_
--sampletarget=0.01
--target=ltr
--batchsize=2
--labelsmooth=0.05
--nthread=4
--memstepsize=4194304
--eostoken=true
--pcttraineval=1
--pctteacherforcing=99
--iter=1018700
#--enable_distributed=true

The last epoch is logged as follows:

I0504 10:23:52.138806 31888 Train.cpp:342] epoch:      200 | nupdates:      1018701 | lr: 0.005933 | lrcriterion: 0.005933 | runtime: 00:04:55 | bch(ms): 59.19 | smp(ms): 2.10 | fwd(ms): 20.14 | crit-fwd(ms): 2.82 | bwd(ms): 28.86 | optim(ms): 7.51 | loss:   75.10920 | train-LER: 33.17 | train-WER: 40.45 | atccvalid-loss:  186.03168 | atccvalid-LER: 79.31 | atccvalid-WER: 89.79 | atcosimvalid-loss:   15.85127 | atcosimvalid-LER: 18.56 | atcosimvalid-WER: 24.51 | avg-isz: 886 | avg-tsz: 143 | max-tsz: 5996 | hrs:   24.62 | thrpt(sec/sec): 299.70

After this, I want to train for another 100 epochs with a learning rate starting at 0.005, using the following config parameters (again omitting irrelevant parameters such as rundir):

--samplerate=8000
--criterion=seq2seq
--lr=0.005
--lrcrit=0.005
--momentum=0.0
--warmup=0
--stepsize=203740
--gamma=0.5
--maxgradnorm=5
--mfsc=true
--dataorder=output_spiral
--inputbinsize=25
--filterbanks=80
--attention=keyvalue
--encoderdim=512
--attnWindow=softPretrain
--softwstd=4
--trainWithWindow=true
--pretrainWindow=15281
--maxdecoderoutputlen=120
--usewordpiece=true
--wordseparator=_
--sampletarget=0.01
--target=ltr
--batchsize=2
--labelsmooth=0.05
--nthread=4
--memstepsize=4194304
--eostoken=true
--pcttraineval=1
--pctteacherforcing=99
--iter=1528050
#--enable_distributed=true

However, when I run wav2letter/build/Train continue atc/run/seq2seq_tds_distributed_atc_lr02it200/ --flagsfile /home/tijs/atc/config/train.cfg and inspect the training log, the learning rate is clearly wrong: it starts at 0.000182. See the log below:

Log file created at: 2020/05/06 17:41:00
Running on machine: speedy
Log line format: [IWEF]mmdd hh:mm:ss.uuuuuu threadid file:line] msg
I0506 17:41:00.887352 10008 Train.cpp:70] reload path is atc/run/seq2seq_tds_distributed_atc_lr02it200/001_model_last.bin
I0506 17:41:00.887476 10008 Train.cpp:77] Reading flags from config file atc/run/seq2seq_tds_distributed_atc_lr02it200/001_model_last.bin
I0506 17:41:00.887661 10008 Train.cpp:80] Parsing command line flags
I0506 17:41:00.887665 10008 Train.cpp:81] Overriding flags should be mutable when using `continue`
I0506 17:41:00.887675 10008 Train.cpp:85] Reading flags from file atc/config/train.cfg
I0506 17:41:01.462204 10008 Train.cpp:148] Gflags after parsing 
--flagfile=; --fromenv=; --tryfromenv=; --undefok=; --tab_completion_columns=80; --tab_completion_word=; --help=false; --helpfull=false; --helpmatch=; --helpon=; --helppackage=false; --helpshort=false; --helpxml=false; --version=false; --adambeta1=0.90000000000000002; --adambeta2=0.999; --am=; --am_decoder_tr_dropout=0; --am_decoder_tr_layerdrop=0; --am_decoder_tr_layers=1; --arch=network.arch; --archdir=/home/tijs/atc/config; --attention=keyvalue; --attentionthreshold=2147483647; --attnWindow=softPretrain; --attnconvchannel=0; --attnconvkernel=0; --attndim=0; --batchsize=2; --beamsize=2500; --beamsizetoken=250000; --beamthreshold=25; --blobdata=false; --channels=1; --criterion=seq2seq; --critoptim=sgd; --datadir=; --dataorder=output_spiral; --decoderattnround=1; --decoderdropout=0; --decoderrnnlayer=1; --decodertype=wrd; --devwin=0; --emission_dir=; --emission_queue_size=3000; --enable_distributed=false; --encoderdim=512; --eosscore=0; --eostoken=true; --everstoredb=false; --fftcachesize=1; --filterbanks=80; --flagsfile=atc/config/train.cfg; --framesizems=25; --framestridems=10; --gamma=0.5; --gumbeltemperature=1; --input=flac; --inputbinsize=25; --inputfeeding=false; --isbeamdump=false; --iter=1528050; --itersave=false; --labelsmooth=0.050000000000000003; --leftWindowSize=50; --lexicon=/home/tijs/atc/am/fulllexicon.txt; --linlr=-1; --linlrcrit=-1; --linseg=0; --lm=; --lm_memory=5000; --lm_vocab=; --lmtype=kenlm; --lmweight=0; --localnrmlleftctx=0; --localnrmlrightctx=0; --logadd=false; --lr=0.0050000000000000001; --lr_decay=9223372036854775807; --lr_decay_step=9223372036854775807; --lrcosine=false; --lrcrit=0.0050000000000000001; --maxdecoderoutputlen=120; --maxgradnorm=5; --maxisz=9223372036854775807; --maxload=-1; --maxrate=10; --maxsil=50; --maxtsz=9223372036854775807; --maxword=-1; --melfloor=1; --memstepsize=4194304; --mfcc=false; --mfcccoeffs=13; --mfsc=true; --minisz=0; --minrate=3; --minsil=0; --mintsz=0; --momentum=0; --netoptim=sgd; --noresample=false; 
--nthread=4; --nthread_decoder=1; --nthread_decoder_am_forward=1; --numattnhead=8; --onorm=none; --optimepsilon=1e-08; --optimrho=0.90000000000000002; --outputbinsize=5; --pctteacherforcing=99; --pcttraineval=1; --pow=false; --pretrainWindow=15281; --replabel=0; --reportiters=0; --rightWindowSize=50; --rndv_filepath=; --rundir=/home/tijs/atc/run; --runname=seq2seq_tds_distributed_atc_lr02it200; --samplerate=8000; --sampletarget=0.01; --samplingstrategy=rand; --saug_fmaskf=27; --saug_fmaskn=2; --saug_start_update=-1; --saug_tmaskn=2; --saug_tmaskp=1; --saug_tmaskt=100; --sclite=; --seed=0; --show=false; --showletters=false; --silscore=0; --smearing=none; --smoothingtemperature=1; --softwoffset=10; --softwrate=5; --softwstd=4; --sqnorm=false; --stepsize=203740; --surround=; --tag=; --target=ltr; --test=; --tokens=tokens.txt; --tokensdir=/home/tijs/atc/am; --train=/home/tijs/atc/lists/atcctrain.lst,/home/tijs/atc/lists/atcosimtrain.lst; --trainWithWindow=true; --transdiag=0; --unkscore=-inf; --use_memcache=false; --use_saug=false; --uselexicon=true; --usewordpiece=true; --valid=atccvalid:/home/tijs/atc/lists/atccvalid.lst,atcosimvalid:/home/tijs/atc/lists/atcosimvalid.lst; --warmup=0; --weightdecay=0; --wordscore=0; --wordseparator=_; --world_rank=0; --world_size=1; --alsologtoemail=; --alsologtostderr=false; --colorlogtostderr=false; --drop_log_memory=true; --log_backtrace_at=; --log_dir=; --log_link=; --log_prefix=true; --logbuflevel=0; --logbufsecs=30; --logemaillevel=999; --logfile_mode=436; --logmailer=/bin/mail; --logtostderr=false; --max_log_size=1800; --minloglevel=0; --stderrthreshold=2; --stop_logging_if_full_disk=false; --symbolize_stacktrace=true; --v=0; --vmodule=; 
I0506 17:41:01.462496 10008 Train.cpp:149] Experiment path: atc/run/seq2seq_tds_distributed_atc_lr02it200/
I0506 17:41:01.462499 10008 Train.cpp:150] Experiment runidx: 2
I0506 17:41:01.462786 10008 Train.cpp:196] Number of classes (network): 37
I0506 17:41:01.464396 10008 Train.cpp:203] Number of words: 1928
I0506 17:41:02.013710 10008 Train.cpp:249] [Network] Sequential [input -> (0) -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> (7) -> (8) -> (9) -> (10) -> (11) -> (12) -> (13) -> (14) -> (15) -> (16) -> (17) -> (18) -> (19) -> (20) -> (21) -> (22) -> (23) -> (24) -> (25) -> (26) -> output]
    (0): View (-1 80 1 0)
    (1): Conv2D (1->10, 21x1, 2,1, SAME,SAME, 1, 1) (with bias)
    (2): ReLU
    (3): Dropout (0.200000)
    (4): LayerNorm ( axis : { 0 1 2 } , size : -1)
    (5): Time-Depth Separable Block (21, 80, 10) [800 -> 800 -> 800]
    (6): Time-Depth Separable Block (21, 80, 10) [800 -> 800 -> 800]
    (7): Conv2D (10->14, 21x1, 2,1, SAME,SAME, 1, 1) (with bias)
    (8): ReLU
    (9): Dropout (0.200000)
    (10): LayerNorm ( axis : { 0 1 2 } , size : -1)
    (11): Time-Depth Separable Block (21, 80, 14) [1120 -> 1120 -> 1120]
    (12): Time-Depth Separable Block (21, 80, 14) [1120 -> 1120 -> 1120]
    (13): Time-Depth Separable Block (21, 80, 14) [1120 -> 1120 -> 1120]
    (14): Conv2D (14->18, 21x1, 2,1, SAME,SAME, 1, 1) (with bias)
    (15): ReLU
    (16): Dropout (0.200000)
    (17): LayerNorm ( axis : { 0 1 2 } , size : -1)
    (18): Time-Depth Separable Block (21, 80, 18) [1440 -> 1440 -> 1440]
    (19): Time-Depth Separable Block (21, 80, 18) [1440 -> 1440 -> 1440]
    (20): Time-Depth Separable Block (21, 80, 18) [1440 -> 1440 -> 1440]
    (21): Time-Depth Separable Block (21, 80, 18) [1440 -> 1440 -> 1440]
    (22): Time-Depth Separable Block (21, 80, 18) [1440 -> 1440 -> 1440]
    (23): Time-Depth Separable Block (21, 80, 18) [1440 -> 1440 -> 1440]
    (24): View (0 1440 1 0)
    (25): Reorder (1,0,3,2)
    (26): Linear (1440->1024) (with bias)
I0506 17:41:02.013763 10008 Train.cpp:250] [Network Params: 36538460]
I0506 17:41:02.013774 10008 Train.cpp:251] [Criterion] Seq2SeqCriterion
I0506 17:41:02.013777 10008 Train.cpp:259] [Network Optimizer] SGD
I0506 17:41:02.013782 10008 Train.cpp:260] [Criterion Optimizer] SGD
I0506 17:41:02.042336 10008 W2lListFilesDataset.cpp:141] 2125 files found. 
I0506 17:41:02.066437 10008 W2lListFilesDataset.cpp:141] 8062 files found. 
I0506 17:41:02.066478 10008 Utils.cpp:102] Filtered 0/10187 samples
I0506 17:41:02.067742 10008 W2lListFilesDataset.cpp:62] Total batches (i.e. iters): 5094
I0506 17:41:02.071892 10008 W2lListFilesDataset.cpp:141] 266 files found. 
I0506 17:41:02.071899 10008 Utils.cpp:102] Filtered 0/266 samples
I0506 17:41:02.071928 10008 W2lListFilesDataset.cpp:62] Total batches (i.e. iters): 133
I0506 17:41:02.075953 10008 W2lListFilesDataset.cpp:141] 1008 files found. 
I0506 17:41:02.075968 10008 Utils.cpp:102] Filtered 0/1008 samples
I0506 17:41:02.076076 10008 W2lListFilesDataset.cpp:62] Total batches (i.e. iters): 504
I0506 17:41:02.076458 10008 Train.cpp:557] Shuffling trainset
I0506 17:41:02.076850 10008 Train.cpp:564] Epoch 201 started!
I0506 17:46:02.718643 10008 Train.cpp:342] epoch:      201 | nupdates:      1023795 | lr: 0.000182 | lrcriterion: 0.000182 | runtime: 00:04:28 | bch(ms): 52.64 | smp(ms): 0.22 | fwd(ms): 19.22 | crit-fwd(ms): 2.44 | bwd(ms): 26.82 | optim(ms): 5.84 | loss:   75.34830 | train-LER: 32.29 | train-WER: 39.03 | atccvalid-loss:  187.12608 | atccvalid-LER: 79.32 | atccvalid-WER: 89.83 | atcosimvalid-loss:   15.90086 | atcosimvalid-LER: 18.68 | atcosimvalid-WER: 24.64 | avg-isz: 892 | avg-tsz: 144 | max-tsz: 5996 | hrs:   25.25 | thrpt(sec/sec): 338.95

I also tried passing the learning rates on the command line (--lr=0.005 and --lrcrit=0.005), to no avail.

When I use fork, the learning rate is correct. However, I think continue mode would be the better choice, since according to the wiki it is intended for exactly this purpose:

Continue training a saved model. This can be used for example to fine-tune with a smaller learning rate. The continue option makes a best effort to resume training from the most recent checkpoint of a given model as if there were no interruptions.

Could you please tell me whether I am overlooking something or whether something else is wrong? Thanks in advance!

deepspiking commented 4 years ago

I have a similar issue, except that the lr does not change at all, even after changing the configuration. In my case, I changed lr and lrcrit from 0.2 to 0.02 in the config file:

--lr=0.02 
--lrcrit=0.02

I confirmed that the 00X_config file has the new value, 0.02,

but the log still prints 0.2, like this:

epoch:      181 | nupdates:       123000 | lr: 0.200000 | lrcriterion: 0.200000 | runtime: 00:05:06 | bch(ms): 306.27 | smp(ms): 0.85 | fwd(ms): 36.51 | crit-fwd(ms): 0.50 | bwd(ms): 250.96 | optim(ms): 15.03 | loss:    2.07057 | train-TER: 14.47 | train-WER: 25.12 | dev-loss:    0.10144 | dev-TER:  1.71 | dev-WER:  3.55 | vali-loss:    0.24045 | vali-TER:  2.42 | vali-WER:  5.27 | avg-isz: 274 | avg-tsz: 018 | max-tsz: 072 | hrs:   24.38 | thrpt(sec/sec): 286.61

I think "continue" mode still has a bug. Looking into it...

deepspiking commented 4 years ago

I found that the following worked in my case. In Train.cpp, change this:

if (runStatus == kTrainMode || runStatus == kForkMode) {
  netoptim = initOptimizer(
      {network}, FLAGS_netoptim, FLAGS_lr, FLAGS_momentum, FLAGS_weightdecay);
  critoptim =
      initOptimizer({criterion}, FLAGS_critoptim, FLAGS_lrcrit, 0.0, 0.0);
}

to

if (runStatus == kTrainMode || runStatus == kForkMode || runStatus == kContinueMode) {
  netoptim = initOptimizer(
      {network}, FLAGS_netoptim, FLAGS_lr, FLAGS_momentum, FLAGS_weightdecay);
  critoptim =
      initOptimizer({criterion}, FLAGS_critoptim, FLAGS_lrcrit, 0.0, 0.0);
}

Then rebuild:

$ cd build && make -j 8

Could you try this?

TijsRozenbroek commented 4 years ago

Glad it fixed your problem; unfortunately, it didn't fix mine.

Now the log for epoch 201 is as follows:

epoch:      201 | nupdates:      1023795 | lr: 0.000154 | lrcriterion: 0.000154 | runtime: 00:04:28 | bch(ms): 52.64 | smp(ms): 0.23 | fwd(ms): 19.23 | crit-fwd(ms): 2.45 | bwd(ms): 26.86 | optim(ms): 5.80 | loss:   75.34230 | train-LER: 32.27 | train-WER: 38.92 | atccvalid-loss:  187.23566 | atccvalid-LER: 79.39 | atccvalid-WER: 89.88 | atcosimvalid-loss:   15.91293 | atcosimvalid-LER: 18.58 | atcosimvalid-WER: 24.52 | avg-isz: 892 | avg-tsz: 144 | max-tsz: 5996 | hrs:   25.25 | thrpt(sec/sec): 338.95

deepspiking commented 4 years ago

Good to hear that :) I'd appreciate it if you could clarify what the problem was in your case.

TijsRozenbroek commented 4 years ago

I'm afraid you misread my comment, the problem described in my opening comment still persists after trying your fix.

deepspiking commented 4 years ago

Oh, I misread ;) Have you tried fork mode instead of continue mode?

TijsRozenbroek commented 4 years ago

Yes, as I mentioned in my initial post, I did. See the quote below.

When I use fork, the learning rate is correct. However, I think continue mode would be the better choice, since according to the wiki it is intended for exactly this purpose:

Continue training a saved model. This can be used for example to fine-tune with a smaller learning rate. The continue option makes a best effort to resume training from the most recent checkpoint of a given model as if there were no interruptions.

tlikhomanenko commented 4 years ago

@TijsRozenbroek,

Try adding the following after this line: https://github.com/facebookresearch/wav2letter/blob/master/Train.cpp#L250

netoptim->setLr(FLAGS_lr);

However, for this use case we intend fork to be used, because it makes it simpler to reproduce which model was trained, for how long, and with which lr (without parsing the logs to check the lr), and whether other parameters were changed as well (a different lr/momentum schedule, etc.).

TijsRozenbroek commented 4 years ago

Hi, thanks for coming to help out.

When I try your fix by adding that line to Train.cpp and rebuilding, the 201st epoch log is as follows:

epoch:      201 | nupdates:      1023795 | lr: 0.000154 | lrcriterion: 0.000182 | runtime: 00:04:30 | bch(ms): 53.17 | smp(ms): 0.19 | fwd(ms): 19.43 | crit-fwd(ms): 2.47 | bwd(ms): 27.03 | optim(ms): 5.97 | loss:   75.33963 | train-LER: 32.27 | train-WER: 38.92 | atccvalid-loss:  187.23625 | atccvalid-LER: 79.39 | atccvalid-WER: 89.88 | atcosimvalid-loss:   15.93031 | atcosimvalid-LER: 18.59 | atcosimvalid-WER: 24.54 | avg-isz: 892 | avg-tsz: 144 | max-tsz: 5996 | hrs:   25.25 | thrpt(sec/sec): 335.61

This is the same learning rate as with the fix suggested by @deepspiking, and thus still not correct. (As a side note, you can also see that lrcriterion now differs from lr; I suppose critoptim->setLr(FLAGS_lrcrit); should be added to fix that.)

So unfortunately the issue still persists. I will continue using fork for the time being.

tlikhomanenko commented 4 years ago

Yep, with critoptim->setLr(FLAGS_lrcrit); you should be able to change the criterion learning rate too (I forgot that you have a s2s criterion, not CTC; CTC has no parameters to learn).

@vineelpratap, any idea why the two fixes above don't help?

padentomasello commented 4 years ago

Hi @TijsRozenbroek, I see how this is confusing. I think continue was not really designed for changing configs, and we should update our wiki accordingly.

For your case, I think the problem is that we're correctly setting the initial learning rate, but we're still using the learning rate schedule to calculate the lr for this particular update (1023795).

epoch: 201 | nupdates: 1023795 | lr: 0.000154 | lrcriterion: 0.000182 | runtime: 00:04:30 | bch(ms): 53.17 | smp(ms): 0.19 | fwd(ms): 19.43 | crit-fwd(ms): 2.47 | bwd(ms): 27.03 | optim(ms): 5.97 | loss: 75.33963 | train-LER: 32.27 | train-WER: 38.92 | atccvalid-loss: 187.23625 | atccvalid-LER: 79.39 | atccvalid-WER: 89.88 | atcosimvalid-loss: 15.93031 | atcosimvalid-LER: 18.59 | atcosimvalid-WER: 24.54 | avg-isz: 892 | avg-tsz: 144 | max-tsz: 5996 | hrs: 25.25 | thrpt(sec/sec): 335.61

lr = init_lr * gamma^(nupdates / stepsize) = 0.005 * 0.5^(1023795 / 203740) ≈ 0.000154. So the initial learning rate is set correctly, but we're applying the learning rate schedule as if we were at update 1023795.

For your case, I think fork is more appropriate. However, if you really want to use continue, I suppose you could set --gamma to 1 so that the learning rate does not decay, or set the initial learning rate so that the learning rate equals 0.005 at update 1023795.

TijsRozenbroek commented 4 years ago

Hi @padentomasello, thanks for your answer! It really cleared up my confusion.

I'll close this issue now. It would indeed be great if the wiki could be updated accordingly to prevent confusion for others in the future.