facebookresearch / fairseq

Facebook AI Research Sequence-to-Sequence Toolkit written in Python.

ADAM KeyError: 'max_exp_avg_sq' when further pretraining XLSR 53 on Portuguese Data #3965

Open fmobrj opened 3 years ago

fmobrj commented 3 years ago

šŸ› Bug

When I run the fairseq-hydra-train script with the large pretraining config but load the XLSR-53 checkpoint, I get a KeyError: 'max_exp_avg_sq' in fairseq/optim/adam.py after some training time. When I run without loading the checkpoint, the script runs without problems.

To Reproduce

Steps to reproduce the behavior (always include the command you ran):

fairseq-hydra-train \
    task.data=/media/hdd3tb/pyinstalls/fairseq/manifest/cv_pt_7 \
    checkpoint.restore_file=/media/hdd3tb/data/wav2vec_data/pretrained_ckpt/xlsr_53_56k.pt \
    distributed_training.distributed_world_size=1 +optimization.update_freq='[128]' \
    --config-dir /media/hdd3tb/pyinstalls/fairseq/examples/wav2vec/config/pretraining \
    --config-name wav2vec2_large_librivox

The initial log:

[2021-10-17 09:13:08,479][fairseq_cli.train][INFO] - task: AudioPretrainingTask
[2021-10-17 09:13:08,480][fairseq_cli.train][INFO] - model: Wav2Vec2Model
[2021-10-17 09:13:08,480][fairseq_cli.train][INFO] - criterion: Wav2vecCriterion
[2021-10-17 09:13:08,481][fairseq_cli.train][INFO] - num. shared model params: 317,390,592 (num. trained: 317,390,592)
[2021-10-17 09:13:08,482][fairseq_cli.train][INFO] - num. expert model params: 0 (num. trained: 0)
[2021-10-17 09:13:08,484][fairseq.data.audio.raw_audio_dataset][INFO] - loaded 942, skipped 1 samples
[2021-10-17 09:13:11,216][fairseq.utils][INFO] - ***********************CUDA enviroments for all 1 workers***********************
[2021-10-17 09:13:11,217][fairseq.utils][INFO] - rank   0: capabilities =  6.1  ; total memory = 10.915 GB ; name = GeForce GTX 1080 Ti
[2021-10-17 09:13:11,217][fairseq.utils][INFO] - ***********************CUDA enviroments for all 1 workers***********************
[2021-10-17 09:13:11,217][fairseq_cli.train][INFO] - training on 1 devices (GPUs/TPUs)
[2021-10-17 09:13:11,217][fairseq_cli.train][INFO] - max tokens per device = 1200000 and max sentences per device = 4
[2021-10-17 09:13:11,217][fairseq.trainer][INFO] - Preparing to load checkpoint /media/hdd3tb/data/wav2vec_data/pretrained_ckpt/xlsr_53_56k.pt
[2021-10-17 09:13:12,962][fairseq.trainer][INFO] - NOTE: your device does NOT support faster training with --fp16 or --amp, please switch to FP32 which is likely to be faster
[2021-10-17 09:13:14,480][fairseq.trainer][INFO] - Loaded checkpoint /media/hdd3tb/data/wav2vec_data/pretrained_ckpt/xlsr_53_56k.pt (epoch 19 @ 800000 updates)
[2021-10-17 09:13:14,481][fairseq.trainer][INFO] - loading train data for epoch 19
[2021-10-17 09:13:14,556][fairseq.data.audio.raw_audio_dataset][INFO] - loaded 93303, skipped 19 samples
[2021-10-17 09:13:16,619][fairseq.trainer][INFO] - begin training epoch 19
[2021-10-17 09:13:16,620][fairseq_cli.train][INFO] - Start iterating over samples

The error:

Traceback (most recent call last):
  File "/home/jamazzon/anaconda3/envs/wave2vec/lib/python3.9/site-packages/fairseq_cli/hydra_train.py", line 28, in hydra_main
    _hydra_main(cfg)
  File "/home/jamazzon/anaconda3/envs/wave2vec/lib/python3.9/site-packages/fairseq_cli/hydra_train.py", line 53, in _hydra_main
    distributed_utils.call_main(cfg, pre_main, **kwargs)
  File "/home/jamazzon/anaconda3/envs/wave2vec/lib/python3.9/site-packages/fairseq/distributed/utils.py", line 369, in call_main
    main(cfg, **kwargs)
  File "/home/jamazzon/anaconda3/envs/wave2vec/lib/python3.9/site-packages/fairseq_cli/train.py", line 180, in main
    valid_losses, should_stop = train(cfg, trainer, task, epoch_itr)
  File "/home/jamazzon/anaconda3/envs/wave2vec/lib/python3.9/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/jamazzon/anaconda3/envs/wave2vec/lib/python3.9/site-packages/fairseq_cli/train.py", line 291, in train
    log_output = trainer.train_step(samples)
  File "/home/jamazzon/anaconda3/envs/wave2vec/lib/python3.9/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/jamazzon/anaconda3/envs/wave2vec/lib/python3.9/site-packages/fairseq/trainer.py", line 872, in train_step
    self.task.optimizer_step(
  File "/home/jamazzon/anaconda3/envs/wave2vec/lib/python3.9/site-packages/fairseq/tasks/fairseq_task.py", line 506, in optimizer_step
    optimizer.step()
  File "/home/jamazzon/anaconda3/envs/wave2vec/lib/python3.9/site-packages/fairseq/optim/fp16_optimizer.py", line 215, in step
    self.fp32_optimizer.step(closure, groups=groups)
  File "/home/jamazzon/anaconda3/envs/wave2vec/lib/python3.9/site-packages/fairseq/optim/fairseq_optimizer.py", line 127, in step
    self.optimizer.step(closure)
  File "/home/jamazzon/anaconda3/envs/wave2vec/lib/python3.9/site-packages/torch/optim/optimizer.py", line 88, in wrapper
    return func(*args, **kwargs)
  File "/home/jamazzon/anaconda3/envs/wave2vec/lib/python3.9/site-packages/fairseq/optim/adam.py", line 203, in step
    state["max_exp_avg_sq"] = state["max_exp_avg_sq"].to(
KeyError: 'max_exp_avg_sq'
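
For reference, the line in the traceback sits in what looks like the AMSGrad branch of Adam.step in fairseq/optim/adam.py. The sketch below is paraphrased from memory to show the surrounding logic; it is not the exact source, which may differ across fairseq versions:

# Paraphrased sketch of the state handling in Adam.step; not the exact fairseq source.
state = self.state[p]
if len(state) == 0:
    # Fresh run: moment buffers are created here, including the AMSGrad buffer.
    state["step"] = 0
    state["exp_avg"] = torch.zeros_like(p_data_fp32)
    state["exp_avg_sq"] = torch.zeros_like(p_data_fp32)
    if amsgrad:
        state["max_exp_avg_sq"] = torch.zeros_like(p_data_fp32)
else:
    # State restored from a checkpoint: buffers are only cast, never created, so a
    # restored state that lacks "max_exp_avg_sq" triggers the KeyError seen above.
    state["exp_avg"] = state["exp_avg"].to(p_data_fp32)
    state["exp_avg_sq"] = state["exp_avg_sq"].to(p_data_fp32)
    if amsgrad:
        state["max_exp_avg_sq"] = state["max_exp_avg_sq"].to(p_data_fp32)

If that reading is right, the optimizer state restored from xlsr_53_56k.pt lands in the else branch with amsgrad enabled in its param groups but without max_exp_avg_sq tensors in the per-parameter state, so the cast on line 203 raises the KeyError.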

Additional context

When I run without loading the checkpoint downloaded from the fairseq link (xlsr_53_56k.pt), I don't get any errors.
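
A quick way to check what the released checkpoint actually stores is to look at its saved optimizer state directly. This is a hypothetical inspection script, not part of fairseq; the key name last_optimizer_state and the exact layout of the saved state may differ between fairseq versions:

import torch

# Load the released checkpoint on CPU; fairseq checkpoints are ordinary torch pickles.
ckpt = torch.load(
    "/media/hdd3tb/data/wav2vec_data/pretrained_ckpt/xlsr_53_56k.pt",
    map_location="cpu",
)
print(sorted(ckpt.keys()))

# If present, the optimizer state should follow the usual torch.optim state_dict
# layout: {"state": {param_idx: {...}}, "param_groups": [{...}]}.
opt_state = ckpt.get("last_optimizer_state", {})
for group in opt_state.get("param_groups", []):
    print("amsgrad in param group:", group.get("amsgrad"))

has_amsgrad_buffers = any(
    "max_exp_avg_sq" in s for s in opt_state.get("state", {}).values()
)
print("per-parameter max_exp_avg_sq present:", has_amsgrad_buffers)

If the param groups report amsgrad as True while no max_exp_avg_sq buffers exist (or the other way around, with the current config enabling AMSGrad), that mismatch would explain the error.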

fmobrj commented 3 years ago

I adjusted the arguments to reset parts of the checkpoint state (dataloader, LR scheduler, meters, and optimizer), and now I don't get any errors. But I am not sure this approach is correct or optimal for further pretraining the XLSR-53 checkpoint on my data.

Training now starts from epoch 1, instead of epoch 19 as it did without resetting this state.

Command line:

fairseq-hydra-train \
    task.data=/media/hdd3tb/pyinstalls/fairseq/manifest/cv_pt_7 \
    checkpoint.restore_file=/media/hdd3tb/data/wav2vec_data/pretrained_ckpt/xlsr_53_56k.pt \
    checkpoint.reset_dataloader=True \
    checkpoint.reset_lr_scheduler=True \
    checkpoint.reset_meters=True \
    checkpoint.reset_optimizer=True \
    distributed_training.distributed_world_size=1 +optimization.update_freq='[128]' \
    --config-dir /media/hdd3tb/pyinstalls/fairseq/examples/wav2vec/config/pretraining \
    --config-name wav2vec2_large_librivox

But now I get a lot of gradient overflow messages. I don't know if I am doing this right.

[2021-10-17 09:29:05,518][fairseq_cli.train][INFO] - task: AudioPretrainingTask
[2021-10-17 09:29:05,518][fairseq_cli.train][INFO] - model: Wav2Vec2Model
[2021-10-17 09:29:05,518][fairseq_cli.train][INFO] - criterion: Wav2vecCriterion
[2021-10-17 09:29:05,519][fairseq_cli.train][INFO] - num. shared model params: 317,390,592 (num. trained: 317,390,592)
[2021-10-17 09:29:05,520][fairseq_cli.train][INFO] - num. expert model params: 0 (num. trained: 0)
[2021-10-17 09:29:05,522][fairseq.data.audio.raw_audio_dataset][INFO] - loaded 942, skipped 1 samples
[2021-10-17 09:29:08,262][fairseq.utils][INFO] - ***********************CUDA enviroments for all 1 workers***********************
[2021-10-17 09:29:08,262][fairseq.utils][INFO] - rank   0: capabilities =  6.1  ; total memory = 10.915 GB ; name = GeForce GTX 1080 Ti
[2021-10-17 09:29:08,262][fairseq.utils][INFO] - ***********************CUDA enviroments for all 1 workers***********************
[2021-10-17 09:29:08,262][fairseq_cli.train][INFO] - training on 1 devices (GPUs/TPUs)
[2021-10-17 09:29:08,263][fairseq_cli.train][INFO] - max tokens per device = 1200000 and max sentences per device = 4
[2021-10-17 09:29:08,263][fairseq.trainer][INFO] - Preparing to load checkpoint /media/hdd3tb/data/wav2vec_data/pretrained_ckpt/xlsr_53_56k.pt
[2021-10-17 09:29:10,013][fairseq.trainer][INFO] - NOTE: your device does NOT support faster training with --fp16 or --amp, please switch to FP32 which is likely to be faster
[2021-10-17 09:29:10,025][fairseq.trainer][INFO] - Loaded checkpoint /media/hdd3tb/data/wav2vec_data/pretrained_ckpt/xlsr_53_56k.pt (epoch 19 @ 0 updates)
[2021-10-17 09:29:10,025][fairseq.trainer][INFO] - loading train data for epoch 1
[2021-10-17 09:29:10,100][fairseq.data.audio.raw_audio_dataset][INFO] - loaded 93303, skipped 19 samples
[2021-10-17 09:29:12,172][fairseq.trainer][INFO] - begin training epoch 1
[2021-10-17 09:29:12,172][fairseq_cli.train][INFO] - Start iterating over samples
[2021-10-17 09:30:19,792][fairseq.trainer][INFO] - NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 64.0
[2021-10-17 09:31:26,843][fairseq.trainer][INFO] - NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 32.0
[2021-10-17 09:32:32,447][fairseq.trainer][INFO] - NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 16.0
[2021-10-17 09:33:39,184][fairseq.trainer][INFO] - NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 8.0
[2021-10-17 09:34:48,172][fairseq.trainer][INFO] - NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 4.0
[2021-10-17 09:35:52,963][fairseq.trainer][INFO] - NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 2.0
[2021-10-17 09:36:59,297][fairseq.trainer][INFO] - NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 1.0
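
For what it's worth, a handful of decreasing loss-scale messages right after resetting the optimizer seems expected: the dynamic FP16 loss scaler restarts from its default value and halves it until it finds a scale that no longer overflows. If the overflows continue until the scale hits its minimum and training aborts, the NOTE in the log already suggests switching to FP32 on this GPU (a GTX 1080 Ti, compute capability 6.1, has no fast FP16 path). Assuming the config exposes the usual common.fp16 flag as in the example pretraining YAMLs, the override would look something like this:

fairseq-hydra-train \
    task.data=/media/hdd3tb/pyinstalls/fairseq/manifest/cv_pt_7 \
    checkpoint.restore_file=/media/hdd3tb/data/wav2vec_data/pretrained_ckpt/xlsr_53_56k.pt \
    checkpoint.reset_dataloader=True \
    checkpoint.reset_lr_scheduler=True \
    checkpoint.reset_meters=True \
    checkpoint.reset_optimizer=True \
    common.fp16=False \
    distributed_training.distributed_world_size=1 +optimization.update_freq='[128]' \
    --config-dir /media/hdd3tb/pyinstalls/fairseq/examples/wav2vec/config/pretraining \
    --config-name wav2vec2_large_librivox

FP32 uses more GPU memory than FP16, so the max-tokens setting may need to come down on an 11 GB card.
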
mzboito commented 3 years ago

Hello, do you have any update regarding this? I'm having the exact same problem trying to continue the XLSR-53 pretraining on some new data. While I don't doubt that adding checkpoint.reset_lr_scheduler=True and checkpoint.reset_optimizer=True might get training to launch, I wonder what impact resetting the LR scheduler and optimizer will have on the performance of the final model. Any insights?