facebookresearch / fairseq

Facebook AI Research Sequence-to-Sequence Toolkit written in Python.
MIT License

Wav2Vec2.0 state of saved model missing in checkpoint after pre-training and code breaks in fine-tuning. #2807

Closed amant555 closed 4 years ago

amant555 commented 4 years ago

🐛 Bug

I pre-trained the model, and when I used the best checkpoint for fine-tuning, I received the error below.

"/home/aman_tiwari/wav2vec2/fairseq/fairseq/models/wav2vec/wav2vec2_asr.py", line 337, in __init__
    args.normalize == w2v_args.normalize
AttributeError: 'NoneType' object has no attribute 'normalize'

After further investigation of the code in wav2vec2_asr.py, I found that the encoder state is fetched from the args entry of the saved checkpoint, which is None. This is the same for all the checkpoints I generate. I also inspected the checkpoint linked in the README on the wav2vec page; its args values are properly populated. In addition, the newly generated checkpoints contain an extra key (cfg).

The code starts working again once I change the state lookup in wav2vec2_asr.py from state["args"] to state["cfg"]["model"].
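
A quick way to check which key a given checkpoint actually carries is, for example (the checkpoint path below is a placeholder, and the fallback simply mirrors the change described above; this is a sketch, not the fairseq loading code):

import torch

# Placeholder path; point this at the pre-trained checkpoint passed to --w2v-path.
state = torch.load("checkpoint_best.pt", map_location="cpu")

# Older checkpoints store the model configuration under "args";
# checkpoints written after the Hydra migration store it under "cfg" -> "model".
w2v_args = state.get("args")
if w2v_args is None and "cfg" in state:
    w2v_args = state["cfg"]["model"]

print(type(w2v_args))
print(getattr(w2v_args, "normalize", None))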

Complete error traceback

Traceback (most recent call last):
  File "train.py", line 14, in <module>
    cli_main()
  File "/home/aman_tiwari/wav2vec2/fairseq/fairseq_cli/train.py", line 352, in cli_main
    distributed_utils.call_main(cfg, main)
  File "/home/aman_tiwari/wav2vec2/fairseq/fairseq/distributed_utils.py", line 317, in call_main
    main(cfg, **kwargs)
  File "/home/aman_tiwari/wav2vec2/fairseq/fairseq_cli/train.py", line 74, in main
    model = task.build_model(cfg.model)
  File "/home/aman_tiwari/wav2vec2/fairseq/fairseq/tasks/fairseq_task.py", line 548, in build_model
    model = models.build_model(args, self)
  File "/home/aman_tiwari/wav2vec2/fairseq/fairseq/models/__init__.py", line 56, in build_model
    return ARCH_MODEL_REGISTRY[cfg.arch].build_model(cfg, task)
  File "/home/aman_tiwari/wav2vec2/fairseq/fairseq/models/wav2vec/wav2vec2_asr.py", line 166, in build_model
    w2v_encoder = Wav2VecEncoder(args, task.target_dictionary)
  File "/home/aman_tiwari/wav2vec2/fairseq/fairseq/models/wav2vec/wav2vec2_asr.py", line 337, in __init__
    args.normalize == w2v_args.normalize
AttributeError: 'NoneType' object has no attribute 'normalize'

Command used for FINE-TUNING

python train.py --distributed-world-size 8 --distributed-port -1 /home/aman_tiwari/wav2vec2/fairseq/prep_scripts --save-dir /home/aman_tiwari/wav2vec2/fairseq/finetune_checkpoints --fp16 \
    --wer-args '("/home/aman_tiwari/wav2vec2/fairseq/prep_scripts/lm_3.binary","/home/aman_tiwari/wav2vec2/fairseq/prep_scripts/lexicon.lst",2,-1)' \
    --post-process letter --valid-subset valid --no-epoch-checkpoints --best-checkpoint-metric wer --num-workers 12 \
    --max-update 80000 --sentence-avg --task audio_pretraining --arch wav2vec_ctc --w2v-path /home/aman_tiwari/wav2vec2/fairseq/pretrainning_checkpoint/checkpoint_best.pt \
    --labels ltr --apply-mask --mask-selection static --mask-other 0 --mask-length 10 --mask-prob 0.5 --layerdrop 0.1 \
    --mask-channel-selection static --mask-channel-other 0 --mask-channel-length 64 --mask-channel-prob 0.5 --zero-infinity \
    --feature-grad-mult 0.0 --freeze-finetune-updates 10000 --validate-after-updates 10000 --optimizer adam \
    --adam-betas '(0.9, 0.98)' --adam-eps 1e-08 --lr 2e-05 --lr-scheduler tri_stage --warmup-steps 8000 --hold-steps 32000 \
    --decay-steps 40000 --final-lr-scale 0.05 --final-dropout 0.0 --dropout 0.0 --activation-dropout 0.1 --criterion ctc \
    --attention-dropout 0.0 --max-tokens 1280000 --seed 2337 --log-format json --log-interval 500 --ddp-backend no_c10d

Command used for PRE-TRAINING

python train.py --distributed-world-size 8 --distributed-port -1 /home/aman_tiwari/manifests_temp/ \
--save-dir /home/aman_tiwari/checkpoints/ --fp16 --device-id 0 --num-workers 32 --task audio_pretraining --criterion wav2vec --arch wav2vec2 \
--log-keys '["prob_perplexity","code_perplexity","temp"]' --quantize-targets --extractor-mode default \
--conv-feature-layers '[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] * 2' --final-dim 256 --latent-vars 320 \
--latent-groups 2 --latent-temp '(2,0.5,0.999995)' --infonce --optimizer adam \
--adam-betas '(0.9,0.98)' --adam-eps 1e-06 --lr-scheduler polynomial_decay --total-num-update 400000 \
--lr 0.00005 --warmup-updates 32000 --mask-length 10 --mask-prob 0.65 --mask-selection static --mask-other 0 \
--encoder-layerdrop 0.05 --dropout-input 0.1 --dropout-features 0.1 --feature-grad-mult 0.1 \
--loss-weights '[0.1, 10]' --conv-pos 128 --conv-pos-groups 16 --num-negatives 100 --cross-sample-negatives 0 \
--max-sample-size 250000 --min-sample-size 32000 --dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 \
--max-tokens 1400000 --max-update 400000 --skip-invalid-size-inputs-valid-test --ddp-backend no_c10d  --tensorboard-logdir /home/aman_tiwari/tensorboard

Questions

  1. Is it okay to use state["cfg"]["model"] to fetch the model state during fine-tuning?
  2. If not, what steps can I take to resolve this?

medabalimi commented 4 years ago

Does adding --normalize help?

myleott commented 4 years ago

Yes, I suspect that state["cfg"]["model"] is correct, since we recently migrated to Hydra configuration which uses "cfg" instead of "args". Can you please submit a PR with the fix?

amant555 commented 4 years ago

Hi @myleott, I have submitted the PR with the fix.

amant555 commented 4 years ago

@alexeib A small question about fine-tuning with an LM: I am getting Error in `python': corrupted double-linked list: 0x000055d8f2c35210. I have attached the full traceback in LM_error_traceback.txt. The LM was built with KenLM using bin/lmplz -o 5 < text > text.arpa followed by bin/build_binary trie text.arpa text.binary, and I have also tried one built without the trie. Can you take a look, tell me what I am doing wrong, and suggest a possible fix?

alexeib commented 4 years ago

@amant555 Thanks for the fix, I added a comment to your PR. Regarding your crash, I see there is a warning printed at the very top about installing the wav2letter bindings. Did you do this? If you go to a Python REPL and type "import wav2letter", does it work?

amant555 commented 4 years ago

Yes, import wav2letter works in the REPL. The warning is because the code is not able to import LexiconFreeDecoder from wav2letter. Since that decoder is not used with KenLM, I don't think it should be the reason for the error, right?

alexeib commented 4 years ago

Can you comment out the LexiconFreeDecoder import then?

amant555 commented 4 years ago

I did that too, but the error didn't change.

wahyubram82 commented 3 years ago

The corrupted double-linked list error when fine-tuning with an LM happens to me too...

I have already investigated it; it happens when the script executes line 156.

That line is part of an iteration over the lexicon word indices, the spelling indices, and the scorer. While iterating, the script inserts all of that into the trie with self.trie.insert(spelling_idxs, word_idx, score), and the error occurs when that call executes. It is not caused by the words or spellings in lexicon.txt.

Please read this to understand what happens... It is clearly not because of the words/spellings in the lexicon; it is something else...

Sometimes a word/spelling can be inserted into the trie, and at other times it cannot.
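
For reference, the loop being described appears to be the trie construction in W2lKenLMDecoder (examples/speech_recognition/w2l_decoder.py). A simplified sketch is below; it is a paraphrase rather than the verbatim fairseq code, the file names are placeholders, and the exact wav2letter binding API may differ between versions:

from fairseq.data import Dictionary
from wav2letter.common import create_word_dict, load_words
from wav2letter.decoder import KenLM, SmearingMode, Trie

# Illustrative inputs: the CTC target dictionary, the lexicon and the KenLM
# binary used for fine-tuning (the file names here are placeholders).
tgt_dict = Dictionary.load("dict.ltr.txt")
lexicon = load_words("lexicon.lst")              # word -> list of spellings
word_dict = create_word_dict(lexicon)
lm = KenLM("lm_3.binary", word_dict)
trie = Trie(len(tgt_dict), tgt_dict.index("|"))  # vocab size, silence token index

start_state = lm.start(False)
for word, spellings in lexicon.items():
    word_idx = word_dict.get_index(word)
    _, score = lm.score(start_state, word_idx)
    for spelling in spellings:
        # map each spelling token to its index in the target dictionary
        spelling_idxs = [tgt_dict.index(token) for token in spelling]
        # the crash reported above happens inside this native call
        trie.insert(spelling_idxs, word_idx, score)
trie.smear(SmearingMode.MAX)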

I have already tried upgrading fairseq to v0.10.0, but the problem still occurs...

Two different errors get reported: sometimes a segmentation fault (core dump), sometimes a corrupted double-linked list...

Or is it a memory problem?

pkadambi commented 3 years ago

@wahyubram82 were you able to figure out what the issue was? I'm running into the same problem here

osddeitf commented 3 years ago

@pkadambi Maybe your wav2letter installation was the culprit. See https://github.com/pytorch/fairseq/issues/2493#issuecomment-755035532.