k2-fsa / snowfall

Moved to https://github.com/k2-fsa/icefall
Apache License 2.0

Mixed precision training #171

Closed pzelasko closed 3 years ago

pzelasko commented 3 years ago

My initial estimate is that these changes let us use about 1.5x larger batches per GPU. I was using max_duration=450, and so far, 200 batches into training with max_duration=700, it seems to be working.

I'll train a full model and report the results. The WERs shouldn't be affected.

pzelasko commented 3 years ago

Update: 3 epochs into training, it seems max_duration=650 is appropriate for a 32GB GPU, and 4xV100 + AMP finishes one epoch in ~1h10min (each of these epochs is really 3 passes over the original data, since we use speed perturbation in prepare.py). This is a ~45% increase in effective batch size and ~30% faster training than without AMP. I will report the results once it finishes.
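The batch-size figures are simple ratios of max_duration values. A quick check, assuming max_duration is the total seconds of audio in one batch on each GPU and that data-parallel training draws one such batch per GPU per step (both assumptions, not stated in this thread):

```python
def relative_increase(new: float, old: float) -> float:
    """Fractional increase of `new` over `old` (0.5 means 1.5x)."""
    return new / old - 1.0


BASELINE = 450  # max_duration without AMP
NUM_GPUS = 4    # 4x V100 in the run above

for label, with_amp in (("first estimate", 700), ("settled value", 650)):
    inc = relative_increase(with_amp, BASELINE)
    print(f"{label}: {with_amp}/{BASELINE} = {with_amp / BASELINE:.2f}x ({inc:+.0%}); "
          f"{NUM_GPUS * with_amp} s vs {NUM_GPUS * BASELINE} s of audio per step")
# first estimate: 700/450 = 1.56x (+56%); 2800 s vs 1800 s of audio per step
# settled value: 650/450 = 1.44x (+44%); 2600 s vs 1800 s of audio per step
```

The ~45% quoted above is this 1.44x ratio, rounded.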

BTW, your mileage may vary depending on the GPU: I think V100, RTX 2080 Ti and newer have hardware acceleration for fp16 (Tensor Cores), whereas a GTX 1080 Ti will allow bigger batches but you won't see as much speed-up.
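One rough way to tell which case applies is the CUDA compute capability: Tensor Cores first appeared with Volta (7.0), so a V100 reports 7.0, an RTX 2080 Ti reports 7.5 and a GTX 1080 Ti reports 6.1. The helper below is a hypothetical heuristic, not part of snowfall; note that some Turing cards (the GTX 16-series) report 7.5 without having Tensor Cores, so treat the answer as a hint.

```python
from typing import Tuple


def likely_has_tensor_cores(capability: Tuple[int, int]) -> bool:
    """Heuristic: True if the CUDA compute capability is 7.0 (Volta) or newer.

    Examples: V100 -> (7, 0), RTX 2080 Ti -> (7, 5), GTX 1080 Ti -> (6, 1).
    Caveat: GTX 16-series cards report (7, 5) but lack Tensor Cores.
    """
    major, minor = capability
    return (major, minor) >= (7, 0)


for name, cap in (("V100", (7, 0)), ("RTX 2080 Ti", (7, 5)), ("GTX 1080 Ti", (6, 1))):
    print(f"{name}: expect a large AMP speed-up? {likely_has_tensor_cores(cap)}")
```

With PyTorch installed, `likely_has_tensor_cores(torch.cuda.get_device_capability())` checks the current device.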

pzelasko commented 3 years ago

Without AMP (max_duration=450):

Epoch 1:
2021-04-20 12:41:07,604 INFO [common.py:373] [test-clean] %WER 7.95% [4179 / 52576, 541 ins, 366 del, 3272 sub ]
2021-04-20 12:42:05,229 INFO [common.py:373] [test-other] %WER 17.39% [9104 / 52343, 1183 ins, 802 del, 7119 sub ]
Epoch 2:
2021-04-20 14:11:25,993 INFO [common.py:373] [test-clean] %WER 6.51% [3425 / 52576, 348 ins, 323 del, 2754 sub ]
2021-04-20 14:12:18,618 INFO [common.py:373] [test-other] %WER 13.81% [7227 / 52343, 743 ins, 789 del, 5695 sub ]
Epoch 3:
2021-04-20 16:13:06,685 INFO [common.py:373] [test-clean] %WER 5.68% [2985 / 52576, 317 ins, 271 del, 2397 sub ]
2021-04-20 16:13:58,784 INFO [common.py:373] [test-other] %WER 12.29% [6431 / 52343, 708 ins, 600 del, 5123 sub ]
Epoch 4:
2021-04-20 19:30:19,240 INFO [common.py:373] [test-clean] %WER 5.56% [2922 / 52576, 319 ins, 288 del, 2315 sub ]
2021-04-20 19:31:07,311 INFO [common.py:373] [test-other] %WER 11.71% [6129 / 52343, 722 ins, 535 del, 4872 sub ]
Epoch 5:
2021-04-20 19:32:31,436 INFO [common.py:373] [test-clean] %WER 5.26% [2767 / 52576, 327 ins, 230 del, 2210 sub ]
2021-04-20 19:33:20,105 INFO [common.py:373] [test-other] %WER 11.02% [5770 / 52343, 729 ins, 442 del, 4599 sub ]
Epoch 6:
2021-04-20 20:54:47,643 INFO [common.py:373] [test-clean] %WER 5.26% [2764 / 52576, 312 ins, 248 del, 2204 sub ]
2021-04-20 20:55:39,609 INFO [common.py:373] [test-other] %WER 10.90% [5706 / 52343, 667 ins, 501 del, 4538 sub ]

Average (epochs 4, 5, 6)
2021-04-20 20:55:05,346 INFO [common.py:373] [test-clean] %WER 4.99% [2624 / 52576, 303 ins, 227 del, 2094 sub ]
2021-04-20 20:55:51,756 INFO [common.py:373] [test-other] %WER 10.21% [5346 / 52343, 662 ins, 412 del, 4272 sub ]

Average (epochs 4, 5, 6) with rescoring
2021-04-20 21:08:11,565 INFO [common.py:373] [test-clean] %WER 4.35% [2285 / 52576, 399 ins, 128 del, 1758 sub ]
2021-04-20 21:15:56,691 INFO [common.py:373] [test-other] %WER 8.90% [4656 / 52343, 804 ins, 234 del, 3618 sub ]
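Each summary line reports the WER together with its components: WER = (ins + del + sub) / N, where N is the number of reference words (52576 for test-clean, 52343 for test-other). A small sketch that parses lines in exactly the format shown above and re-derives the percentage; the regex and the `WerSummary` type are assumptions written against these logs, not snowfall's own tooling:

```python
import re
from typing import NamedTuple

# Matches e.g. "[test-clean] %WER 7.95% [4179 / 52576, 541 ins, 366 del, 3272 sub ]"
_WER_RE = re.compile(
    r"\[(?P<test_set>[\w-]+)\] %WER (?P<wer>[\d.]+)% "
    r"\[(?P<errors>\d+) / (?P<ref>\d+), (?P<ins>\d+) ins, "
    r"(?P<dels>\d+) del, (?P<subs>\d+) sub \]"
)


class WerSummary(NamedTuple):
    test_set: str
    reported_wer: float  # percentage as printed in the log
    errors: int
    ref_words: int
    ins: int
    dels: int
    subs: int


def parse_wer_line(line: str) -> WerSummary:
    """Extract the WER summary from one log line."""
    m = _WER_RE.search(line)
    if m is None:
        raise ValueError(f"no WER summary found in: {line!r}")
    return WerSummary(
        test_set=m["test_set"],
        reported_wer=float(m["wer"]),
        errors=int(m["errors"]),
        ref_words=int(m["ref"]),
        ins=int(m["ins"]),
        dels=int(m["dels"]),
        subs=int(m["subs"]),
    )


def recomputed_wer(s: WerSummary) -> float:
    """WER in percent, after checking that errors == ins + del + sub."""
    if s.errors != s.ins + s.dels + s.subs:
        raise ValueError(f"inconsistent error counts for {s.test_set}")
    return 100.0 * s.errors / s.ref_words


line = ("2021-04-20 12:41:07,604 INFO [common.py:373] [test-clean] "
        "%WER 7.95% [4179 / 52576, 541 ins, 366 del, 3272 sub ]")
summary = parse_wer_line(line)
print(f"{summary.test_set}: logged {summary.reported_wer:.2f}%, "
      f"recomputed {recomputed_wer(summary):.2f}%")
# test-clean: logged 7.95%, recomputed 7.95%
```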

With AMP (max_duration=650; I actually forgot to adjust the LR for the larger batch size, but it ended up fine):
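A common heuristic when growing the batch is to scale the learning rate with it: linearly (the usual rule for SGD) or by the square root of the batch-size ratio (sometimes preferred for Adam-style optimizers). The sketch below only illustrates those two rules; the base learning rate is a made-up placeholder, and the run reported here did not apply either rule, since the LR was left unchanged.

```python
import math


def scale_lr(base_lr: float, base_batch: float, new_batch: float,
             rule: str = "linear") -> float:
    """Rescale a learning rate when the batch size changes.

    rule="linear": lr grows proportionally with the batch size.
    rule="sqrt":   lr grows with the square root of the batch-size ratio.
    """
    ratio = new_batch / base_batch
    if rule == "linear":
        return base_lr * ratio
    if rule == "sqrt":
        return base_lr * math.sqrt(ratio)
    raise ValueError(f"unknown rule: {rule!r}")


BASE_LR = 1e-3  # hypothetical placeholder, not the value used in snowfall
for rule in ("linear", "sqrt"):
    print(rule, f"{scale_lr(BASE_LR, 450, 650, rule):.3e}")
# linear 1.444e-03
# sqrt 1.202e-03
```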

Epoch 1:
2021-04-21 17:09:31,498 INFO [common.py:380] [test-clean] %WER 8.92% [4692 / 52576, 549 ins, 450 del, 3693 sub ]
2021-04-21 17:10:36,991 INFO [common.py:380] [test-other] %WER 18.96% [9925 / 52343, 1202 ins, 887 del, 7836 sub ]
Epoch 2:
2021-04-21 19:29:23,523 INFO [common.py:380] [test-clean] %WER 6.36% [3346 / 52576, 361 ins, 310 del, 2675 sub ]
2021-04-21 19:30:22,366 INFO [common.py:380] [test-other] %WER 14.44% [7556 / 52343, 780 ins, 821 del, 5955 sub ]
Epoch 3:
2021-04-21 22:11:10,664 INFO [common.py:380] [test-clean] %WER 5.86% [3079 / 52576, 368 ins, 258 del, 2453 sub ]
2021-04-21 22:12:02,592 INFO [common.py:380] [test-other] %WER 12.83% [6716 / 52343, 820 ins, 512 del, 5384 sub ]
Epoch 4:
2021-04-21 22:13:41,972 INFO [common.py:380] [test-clean] %WER 5.51% [2896 / 52576, 323 ins, 256 del, 2317 sub ]
2021-04-21 22:14:36,864 INFO [common.py:380] [test-other] %WER 11.75% [6152 / 52343, 717 ins, 508 del, 4927 sub ]
Epoch 5:
2021-04-21 22:16:00,040 INFO [common.py:380] [test-clean] %WER 5.44% [2859 / 52576, 348 ins, 270 del, 2241 sub ]
2021-04-21 22:16:49,753 INFO [common.py:380] [test-other] %WER 11.70% [6124 / 52343, 798 ins, 485 del, 4841 sub ]
Epoch 6:
2021-04-22 07:37:04,487 INFO [common.py:380] [test-clean] %WER 5.30% [2786 / 52576, 296 ins, 302 del, 2188 sub ]
2021-04-22 07:38:26,079 INFO [common.py:380] [test-other] %WER 10.78% [5645 / 52343, 621 ins, 510 del, 4514 sub ]

Average (epochs 4, 5, 6)
2021-04-22 07:36:42,281 INFO [common.py:380] [test-clean] %WER 4.93% [2593 / 52576, 310 ins, 223 del, 2060 sub ]
2021-04-22 07:37:35,828 INFO [common.py:380] [test-other] %WER 10.21% [5343 / 52343, 673 ins, 411 del, 4259 sub ]

Average (epochs 4, 5, 6) with rescoring
2021-04-22 07:42:46,456 INFO [common.py:380] [test-clean] %WER 4.25% [2234 / 52576, 395 ins, 123 del, 1716 sub ]
2021-04-22 07:50:16,896 INFO [common.py:380] [test-other] %WER 9.02% [4720 / 52343, 801 ins, 265 del, 3654 sub ]
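To put the final numbers side by side, the following compares the rescored, epoch-averaged WERs of the two runs using the error counts from the logs above. Both runs share the same reference word counts, so the relative WER change equals the relative change in error count.

```python
def wer_percent(errors: int, ref_words: int) -> float:
    return 100.0 * errors / ref_words


# (errors, reference words) from the "Average (epochs 4, 5, 6) with rescoring" lines
RESULTS = {
    "test-clean": {"no_amp": (2285, 52576), "amp": (2234, 52576)},
    "test-other": {"no_amp": (4656, 52343), "amp": (4720, 52343)},
}


def compare(results):
    """Return (test_set, fp32 WER, AMP WER, relative change) per test set."""
    rows = []
    for test_set, runs in results.items():
        base = wer_percent(*runs["no_amp"])
        amp = wer_percent(*runs["amp"])
        rows.append((test_set, base, amp, (amp - base) / base))
    return rows


for test_set, base, amp, rel in compare(RESULTS):
    print(f"{test_set}: {base:.2f}% -> {amp:.2f}% ({rel:+.1%} relative)")
# test-clean: 4.35% -> 4.25% (-2.2% relative)
# test-other: 8.90% -> 9.02% (+1.4% relative)
```

Both differences are within about 2% relative of the fp32 baseline and point in opposite directions, in line with the expectation at the start of the thread that AMP should not affect the results.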

pzelasko commented 3 years ago

I resolved the conflicts with the ali model PR; it is ready to merge.

danpovey commented 3 years ago

Cool! Merging, LGTM!