facebookresearch / fairseq

Facebook AI Research Sequence-to-Sequence Toolkit written in Python.

HuBERT - cannot finetune model #3864

Open assafmu opened 3 years ago

assafmu commented 3 years ago

🐛 Bug

After a HuBERT model is loaded, it cannot be fine-tuned using the base_10h config. Training fails with "TypeError: tuple indices must be integers or slices, not str".

To Reproduce

Steps to reproduce the behavior (always include the command you ran):

Without the unrelated flags, the command is:

```shell
fairseq-hydra-train --config-dir examples/wav2vec/config/finetuning --config-name base_10h \
  task.data=MY_DATA_DIR model.w2v_path=MY_MODEL_PATH
```

For the record, my exact command was:

```shell
fairseq-hydra-train --config-dir examples/wav2vec/config/finetuning --config-name base_10h \
  task.data=MY_DATA_DIR common.tensorboard_logdir=tb model.w2v_path=MY_MODEL_PATH \
  distributed_training.distributed_world_size=1 common.fp16=False \
  model.freeze_finetune_updates=10 dataset.valid_subset=valid
```

The model used was the base model (pretrained on LibriSpeech 960h).

Stack trace:

```
[2021-09-12 15:49:06,942][fairseq.trainer][INFO] - begin training epoch 1
[2021-09-12 15:49:06,943][fairseq_cli.train][INFO] - Start iterating over samples
Traceback (most recent call last):
  File "/home/assafmushkin/fairseq/fairseq_cli/hydra_train.py", line 28, in hydra_main
    _hydra_main(cfg)
  File "/home/assafmushkin/fairseq/fairseq_cli/hydra_train.py", line 53, in _hydra_main
    distributed_utils.call_main(cfg, pre_main, **kwargs)
  File "/home/assafmushkin/fairseq/fairseq/distributed/utils.py", line 369, in call_main
    main(cfg, **kwargs)
  File "/home/assafmushkin/fairseq/fairseq_cli/train.py", line 180, in main
    valid_losses, should_stop = train(cfg, trainer, task, epoch_itr)
  File "/home/assafmushkin/anaconda3/envs/fairseq/lib/python3.9/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/assafmushkin/fairseq/fairseq_cli/train.py", line 291, in train
    log_output = trainer.train_step(samples)
  File "/home/assafmushkin/anaconda3/envs/fairseq/lib/python3.9/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/assafmushkin/fairseq/fairseq/trainer.py", line 761, in train_step
    loss, sample_size_i, logging_output = self.task.train_step(
  File "/home/assafmushkin/fairseq/fairseq/tasks/fairseq_task.py", line 492, in train_step
    loss, sample_size, logging_output = criterion(model, sample)
  File "/home/assafmushkin/anaconda3/envs/fairseq/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/assafmushkin/fairseq/fairseq/criterions/ctc.py", line 110, in forward
    net_output = model(**sample["net_input"])
  File "/home/assafmushkin/anaconda3/envs/fairseq/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/assafmushkin/fairseq/fairseq/models/wav2vec/wav2vec2_asr.py", line 209, in forward
    x = self.w2v_encoder(**kwargs)
  File "/home/assafmushkin/anaconda3/envs/fairseq/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/assafmushkin/fairseq/fairseq/models/wav2vec/wav2vec2_asr.py", line 402, in forward
    x = res["x"]
TypeError: tuple indices must be integers or slices, not str
```

Code sample

Expected behavior

Features should be extracted correctly, allowing fine-tuning to proceed.

Environment

Additional context

It seems the problem is a mismatch between the return types of HuBERT's extract_features and wav2vec 2.0's extract_features: the former returns a tuple (the first element of which is a dictionary), while the latter returns a dictionary directly. The forward method in wav2vec2_asr.py expects the wav2vec 2.0 convention, so indexing HuBERT's tuple with the string key "x" raises the TypeError.
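A minimal sketch of the mismatch (simplified stand-ins for illustration only, not fairseq's actual classes):

```python
# Two return conventions for extract_features, as described above.

def wav2vec2_extract_features(source):
    # wav2vec 2.0-style: returns a dict, so res["x"] works
    return {"x": source, "padding_mask": None}

def hubert_extract_features(source):
    # HuBERT-style: returns a tuple whose first element is a dict
    return ({"x": source, "padding_mask": None}, None)

def encoder_forward(extract_features, source):
    res = extract_features(source)
    return res["x"]  # assumes the wav2vec 2.0 dict convention

encoder_forward(wav2vec2_extract_features, "feats")  # works
encoder_forward(hubert_extract_features, "feats")    # TypeError: tuple indices must be integers or slices, not str
```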

wnhsu commented 3 years ago

@assafmu you need to use HuBERT's fine-tuning config, as described in https://github.com/pytorch/fairseq/tree/main/examples/hubert#fine-tune-a-hubert-model-with-a-ctc-loss
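For reference, the linked instructions fine-tune with a command of roughly this shape (the paths below are placeholders; check the README for the exact arguments):

```shell
fairseq-hydra-train \
  --config-dir examples/hubert/config/finetune \
  --config-name base_10h \
  task.data=/path/to/data \
  task.label_dir=/path/to/transcriptions \
  model.w2v_path=/path/to/hubert_checkpoint.pt
```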