flashlight / wav2letter

Facebook AI Research's Automatic Speech Recognition Toolkit
https://github.com/facebookresearch/wav2letter/wiki

ASG Criterion Training Speed #259

Closed apldev3 closed 5 years ago

apldev3 commented 5 years ago

Hey all, just wanted to ask whether anyone else has seen training slow down dramatically after taking the config from the Tutorial and changing the criterion from "CTC" to "ASG".

Specifically, my epochs have gone from training in 1 minute on CTC to ~20 minutes on ASG, which makes quite a difference when trying to do the 100 iterations of training the tutorial recommends. This leads me to think that either ASG doesn't/shouldn't require as many iterations as CTC, or my architecture isn't well suited to ASG.
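For reference, the only change I made to the tutorial's flags file is the criterion line (sketch below; wav2letter reads gflags-style flagfiles, the "train.cfg" name is just illustrative, and everything else is unchanged from the tutorial):

# train.cfg (illustrative name) -- tutorial flagfile with only this line changed
--criterion=asg
# previously: --criterion=ctc
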

Does anyone have any quick recommendations?

lunixbochs commented 5 years ago

I was able to train their librispeech recipe to 5% TER in 5 epochs on ASG. What kind of sec/sec throughput are you seeing? I'm getting >40 sec/sec on a single V100 with ASG.

apldev3 commented 5 years ago

So I'm running ASG on a single Tesla M60 and seeing ~35.9 sec/sec. Each epoch takes ~22 minutes (21:55), and this has been consistent for pretty much all 61 epochs so far.

This is the output from the log for the latest epoch:

epoch:       61 | lr: 0.100000 | lrcriterion: 0.000000 | runtime: 00:21:55 | bch(ms): 895.79 | smp(ms): 0.30 | fwd(ms): 276.49 | crit-fwd(ms): 270.98 | bwd(ms): 567.39 | optim(ms): 2.04 | loss:    0.38910 | train-TER:  1.32 | validation-TER:  8.70 | avg-isz: 804 | avg-tsz: 108 | max-tsz: 143 | hrs:   13.12 | thrpt(sec/sec): 35.91

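As a sanity check on that number: thrpt(sec/sec) looks like seconds of audio processed per second of wall time, and under that assumption the hrs and runtime fields from the line above agree with it:

# assumes thrpt(sec/sec) = seconds of audio / seconds of wall time (my reading of the log)
audio_sec = 13.12 * 3600     # "hrs: 13.12" -> seconds of audio per epoch
wall_sec = 21 * 60 + 55      # "runtime: 00:21:55" -> 1315 s of wall time
print(audio_sec / wall_sec)  # ~35.92, matching "thrpt(sec/sec): 35.91"
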
lunixbochs commented 5 years ago

Your loss is very low and your training error rate is much lower than your validation error rate, so this network might already be overfit. What sort of loss, TER, and sec/sec did you see on CTC?

apldev3 commented 5 years ago

This is from the final (100th) epoch of CTC training on the same dataset:

epoch:      100 | lr: 0.100000 | lrcriterion: 0.000000 | runtime: 00:01:07 | bch(ms): 45.64 | smp(ms): 2.87 | fwd(ms): 27.46 | crit-fwd(ms): 21.71 | bwd(ms): 11.05 | optim(ms): 2.40 | loss:    0.15229 | train-TER:  0.47 | validation-TER:  6.94 | avg-isz: 804 | avg-tsz: 108 | max-tsz: 143 | hrs:   13.12 | thrpt(sec/sec): 704.83

Based on your comment, I'd imagine this is even worse in terms of overfitting. What would you say is an ideal loss/TER?

lunixbochs commented 5 years ago

My librispeech training is around 4% TER on test and 7% on train. It's fine to have a low loss; I think a very low loss just means the network is pretty much done, without much room for improvement. So when your loss is very low, your training TER is basically zero, and your test TER isn't very good, you've likely overfit on the training set.

At that point I think you either need more diverse training data, some data augmentation, or to just stop training before it starts to overfit. If you look at the log, does the test TER plateau long before the train TER stops decreasing?
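If it helps, a rough sketch like this can pull that out of the epoch lines (it assumes you've saved the per-epoch log lines, like the ones pasted above, to a text file; "train.log" is a hypothetical path):

import re

# find where validation-TER bottoms out, one "epoch: ..." log line per epoch
val_ter = []
with open("train.log") as f:  # hypothetical path to the saved epoch lines
    for line in f:
        m = re.search(r"validation-TER:\s*([\d.]+)", line)
        if m:
            val_ter.append(float(m.group(1)))

if val_ter:
    best = min(range(len(val_ter)), key=val_ter.__getitem__)
    print(f"best validation-TER {val_ter[best]:.2f} at epoch {best + 1} "
          f"of {len(val_ter)}; epochs past that point are likely overfitting")
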

apldev3 commented 5 years ago

At 100 iterations, yeah, there are definitely diminishing returns, but the test TER doesn't plateau until around epoch 40-50; it decreases a small bit more after that, but I see what you mean. I've dropped the iterations from 100 down to 20, and this is what I'm seeing from training on 25% of our dataset:

epoch:       20 | lr: 0.100000 | lrcriterion: 0.000000 | runtime: 00:21:52 | bch(ms): 893.97 | smp(ms): 0.27 | fwd(ms): 275.29 | crit-fwd(ms): 269.79 | bwd(ms): 566.84 | optim(ms): 2.00 | loss:    2.46040 | train-TER:  7.34 | validation-TER: 12.51 | avg-isz: 804 | avg-tsz: 108 | max-tsz: 143 | hrs:   13.12 | thrpt(sec/sec): 35.99

And this is using 50% of our dataset:

epoch:       20 | lr: 0.100000 | lrcriterion: 0.000000 | runtime: 00:43:39 | bch(ms): 892.49 | smp(ms): 0.26 | fwd(ms): 275.48 | crit-fwd(ms): 270.00 | bwd(ms): 565.46 | optim(ms): 1.99 | loss:    1.28024 | train-TER:  3.58 | validation-TER:  6.93 | avg-isz: 801 | avg-tsz: 108 | max-tsz: 143 | hrs:   26.13 | thrpt(sec/sec): 35.92

Hopefully these are more reasonable results that aren't quite as indicative of overfitting.

apldev3 commented 5 years ago

Alrighty, well, I do appreciate the help with this ticket. I'm going to mark it as closed, since 100 iterations was clearly too many and this was a simple misunderstanding on my part!