facebookresearch / fairseq

Facebook AI Research Sequence-to-Sequence Toolkit written in Python.
MIT License

wav2vec training crashes with AssertionError #1610

Closed alpoktem closed 4 years ago

alpoktem commented 4 years ago

🐛 Bug

I started training by following the wav2vec training guidelines. At 84% through the first epoch, training quits with an AssertionError.

To Reproduce

python $FAIRSEQ/train.py $WORKDIR/manifest --save-dir $WORKDIR/models/ \
--num-workers 6 --fp16 --max-update 400000 --save-interval 1 --no-epoch-checkpoints \
--arch wav2vec --task audio_pretraining --lr 1e-06 --min-lr 1e-09 --optimizer adam \
--max-lr 0.005 --lr-scheduler cosine --conv-feature-layers "[(512, 10, 5), (512, 8, 4), (512, 4, 2), (512, 4, 2), (512, 4, 2), (512, 1, 1), (512, 1, 1)]" \
--conv-aggregator-layers "[(512, 2, 1), (512, 3, 1), (512, 4, 1), (512, 5, 1), (512, 6, 1), (512, 7, 1), (512, 8, 1), (512, 9, 1), (512, 10, 1), (512, 11, 1), (512, 12, 1), (512, 13, 1)]" \
--skip-connections-agg --residual-scale 0.5 --log-compression \
--warmup-updates 500 --warmup-init-lr 1e-07 --criterion binary_cross_entropy \
--num-negatives 10 --max-sample-size 150000 --max-tokens 1500000 \
--skip-invalid-size-inputs-valid-test

Error output

Traceback (most recent call last):
  File "/home/twbgmy/extSW/fairseq/train.py", line 363, in <module>
    cli_main()
  File "/home/twbgmy/extSW/fairseq/train.py", line 355, in cli_main
    nprocs=args.distributed_world_size,
  File "/home/twbgmy/anaconda3/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 171, in spawn
    while not spawn_context.join():
  File "/home/twbgmy/anaconda3/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 118, in join
    raise Exception(msg)
Exception:

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/home/twbgmy/anaconda3/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 19, in _wrap
    fn(i, *args)
  File "/home/twbgmy/extSW/fairseq/train.py", line 322, in distributed_main
    main(args, init_distributed=True)
  File "/home/twbgmy/extSW/fairseq/train.py", line 89, in main
    train(args, trainer, task, epoch_itr)
  File "/home/twbgmy/extSW/fairseq/train.py", line 153, in train
    log_output = trainer.train_step(samples)
  File "/home/twbgmy/extSW/fairseq/fairseq/trainer.py", line 327, in train_step
    sample, self.model, self.criterion, self.optimizer, ignore_grad
  File "/home/twbgmy/extSW/fairseq/fairseq/tasks/fairseq_task.py", line 290, in train_step
    loss, sample_size, logging_output = criterion(model, sample)
  File "/home/twbgmy/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/twbgmy/extSW/fairseq/fairseq/criterions/binary_cross_entropy.py", line 30, in forward
    net_output = model(**sample['net_input'])
  File "/home/twbgmy/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/twbgmy/anaconda3/lib/python3.7/site-packages/torch/nn/parallel/distributed.py", line 442, in forward
    output = self.module(*inputs[0], **kwargs[0])
  File "/home/twbgmy/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/twbgmy/extSW/fairseq/fairseq/models/wav2vec.py", line 183, in forward
    x, targets = self.wav2vec_predictions(x, features)
  File "/home/twbgmy/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/twbgmy/extSW/fairseq/fairseq/models/wav2vec.py", line 429, in forward
    assert end == predictions.numel(), '{} != {}'.format(end, predictions.numel())
AssertionError: 0 != 517

Environment

ahmedalbahnasawy commented 4 years ago

I tried the pretrained model from https://github.com/pytorch/fairseq/blob/master/examples/wav2vec/README.md and have a question: did you try concatenating the context vector with the spectrogram and feeding the result to the ASR network? Or did you try replacing the spectrogram with the context vector? Thank you.

alpoktem commented 4 years ago

I also used the pre-trained model, and it works fine. I plugged the h5 output into wav2letter, disabling the MFCCs. My problem is with training my own model. Training actually runs without problems on Librispeech data, so I suppose the problem is in my data.

ahmedalbahnasawy commented 4 years ago

@alpoktem I don't think it's a problem in your data. Have you tried using other acoustic models such as DeepSpeech, Kaldi and ESPnet? I think the problem comes from the large number of features that the wav2vec model passes to any acoustic model: [batch_size, audio frame, 512]. Please correct me if I'm wrong. Thanks.

Oktai15 commented 4 years ago

@alpoktem Hi! Can you provide the loss curve of your training?

jfhou commented 4 years ago

@alpoktem I hit the same error. In my case, the AssertionError is caused by a very short utterance, for which the "steps" variable in wav2vec.py ("steps = min(steps, tsz - self.offset)") becomes negative, so "end" stays 0. After I filtered out short utterances, the AssertionError disappeared. I think this may be a bug in the code.
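
For anyone who wants to try the same workaround, here is a minimal sketch of filtering short utterances out of a manifest before training. It is not code from fairseq. It assumes the manifest's first line is the audio root directory and every following line is `relative/path<TAB>num_samples`, and that the feature encoder applies unpadded convolutions with the `--conv-feature-layers` value from the command above. The names `feature_frames`, `filter_manifest` and the `min_frames` threshold are hypothetical.

```python
# (channels, kernel, stride), copied from --conv-feature-layers in the command above.
CONV_FEATURE_LAYERS = [
    (512, 10, 5), (512, 8, 4), (512, 4, 2), (512, 4, 2),
    (512, 4, 2), (512, 1, 1), (512, 1, 1),
]


def feature_frames(num_samples, layers=CONV_FEATURE_LAYERS):
    """Estimate how many feature frames an utterance yields.

    Assumption: each layer is an unpadded 1-D convolution, so its
    output length is floor((n - kernel) / stride) + 1.
    """
    frames = num_samples
    for _, kernel, stride in layers:
        if frames < kernel:
            return 0  # too short to produce even one output frame
        frames = (frames - kernel) // stride + 1
    return frames


def filter_manifest(lines, min_frames):
    """Keep the root-directory line plus entries with enough frames.

    Assumption: every entry line is "relative/path<TAB>num_samples".
    Blank or malformed lines without a tab are dropped.
    """
    lines = list(lines)
    kept = lines[:1]  # first line is assumed to be the root directory
    for line in lines[1:]:
        path, _, count = line.rstrip("\n").rpartition("\t")
        if path and feature_frames(int(count)) >= min_frames:
            kept.append(line)
    return kept


if __name__ == "__main__":
    manifest = ["/data/root\n", "a.wav\t16000\n", "b.wav\t400\n"]
    # min_frames=20 is an illustrative value, not a fairseq default.
    print("".join(filter_manifest(manifest, min_frames=20)), end="")
```

The threshold should sit comfortably above the model's prediction offset, since the failure described above appears when "tsz - self.offset" goes negative. Writing the filtered lines back to a new manifest file and pointing train.py at that file leaves the original data untouched.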