🐛 Bug
After a HuBERT model is loaded, it cannot be fine-tuned with the base_10h config. Training fails with "TypeError: tuple indices must be integers or slices, not str".
To Reproduce
Steps to reproduce the behavior (always include the command you ran):
Without the unrelated flags, the command is:
fairseq-hydra-train --config-dir examples/wav2vec/config/finetuning --config-name base_10h task.data=MY_DATA_DIR model.w2v_path=MY_MODEL_PATH
For the record, my exact command was:
fairseq-hydra-train --config-dir examples/wav2vec/config/finetuning --config-name base_10h task.data=MY_DATA_DIR common.tensorboard_logdir=tb model.w2v_path=MY_MODEL_PATH distributed_training.distributed_world_size=1 common.fp16=False model.freeze_finetune_updates=10 dataset.valid_subset=valid
The model used was the HuBERT Base model (pretrained on LibriSpeech 960h).
stacktrace:
[2021-09-12 15:49:06,942][fairseq.trainer][INFO] - begin training epoch 1
[2021-09-12 15:49:06,943][fairseq_cli.train][INFO] - Start iterating over samples
Traceback (most recent call last):
File "/home/assafmushkin/fairseq/fairseq_cli/hydra_train.py", line 28, in hydra_main
_hydra_main(cfg)
File "/home/assafmushkin/fairseq/fairseq_cli/hydra_train.py", line 53, in _hydra_main
distributed_utils.call_main(cfg, pre_main, **kwargs)
File "/home/assafmushkin/fairseq/fairseq/distributed/utils.py", line 369, in call_main
main(cfg, **kwargs)
File "/home/assafmushkin/fairseq/fairseq_cli/train.py", line 180, in main
valid_losses, should_stop = train(cfg, trainer, task, epoch_itr)
File "/home/assafmushkin/anaconda3/envs/fairseq/lib/python3.9/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/home/assafmushkin/fairseq/fairseq_cli/train.py", line 291, in train
log_output = trainer.train_step(samples)
File "/home/assafmushkin/anaconda3/envs/fairseq/lib/python3.9/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/home/assafmushkin/fairseq/fairseq/trainer.py", line 761, in train_step
loss, sample_size_i, logging_output = self.task.train_step(
File "/home/assafmushkin/fairseq/fairseq/tasks/fairseq_task.py", line 492, in train_step
loss, sample_size, logging_output = criterion(model, sample)
File "/home/assafmushkin/anaconda3/envs/fairseq/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/home/assafmushkin/fairseq/fairseq/criterions/ctc.py", line 110, in forward
net_output = model(**sample["net_input"])
File "/home/assafmushkin/anaconda3/envs/fairseq/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/home/assafmushkin/fairseq/fairseq/models/wav2vec/wav2vec2_asr.py", line 209, in forward
x = self.w2v_encoder(**kwargs)
File "/home/assafmushkin/anaconda3/envs/fairseq/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/home/assafmushkin/fairseq/fairseq/models/wav2vec/wav2vec2_asr.py", line 402, in forward
x = res["x"]
TypeError: tuple indices must be integers or slices, not str
Code sample
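A minimal sketch (not the failing training path) of how the return-type difference could be checked directly. It assumes fairseq's checkpoint_utils.load_model_ensemble_and_task for loading, that extract_features accepts a raw waveform tensor as source, and that the checkpoint at MY_MODEL_PATH loads without extra task overrides; the dummy waveform is illustrative only.

```python
# Sketch: load the HuBERT checkpoint used above and inspect what
# extract_features returns, outside of the fine-tuning code path.
import torch
from fairseq import checkpoint_utils

models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(["MY_MODEL_PATH"])
model = models[0].eval()

with torch.no_grad():
    dummy_wav = torch.zeros(1, 16000)  # one second of 16 kHz silence, batch size 1
    res = model.extract_features(source=dummy_wav, padding_mask=None, mask=False)

# HuBERT returns a tuple here, while Wav2VecEncoder.forward assumes a dict
# and immediately does res["x"], which is what raises the TypeError.
print(type(res))
```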
Expected behavior
Features should be extracted correctly so that fine-tuning can proceed.
Environment
fairseq Version: master
PyTorch Version: 1.9.0
OS: Linux, Ubuntu 18.04.5
How you installed fairseq: source
Python version: 3.9.6
CUDA/cuDNN version: 11.2
Additional context
The problem seems to be a mismatch between the return types of HuBERT's extract_features and wav2vec 2.0's extract_features: the former returns a tuple (whose first element is a dictionary), while the latter returns a dictionary directly.
Wav2VecEncoder.forward in wav2vec2_asr.py expects the wav2vec 2.0 return type, so indexing the tuple with res["x"] raises the TypeError above.
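For illustration only, a rough sketch of the shape of a possible workaround, assuming the two return types described above; normalize_extract_features is a hypothetical helper, not an existing fairseq function or the actual fix.

```python
# Hypothetical shim: make the caller tolerant of both return types instead of
# assuming a dict and indexing with res["x"] directly.
def normalize_extract_features(res):
    """Return a dict exposing the features under "x", whichever model produced res."""
    if isinstance(res, dict):
        return res  # wav2vec 2.0 style: already a dict
    if isinstance(res, tuple):
        # HuBERT style (per this report): a tuple whose first element is a dict
        if isinstance(res[0], dict):
            return res[0]
        # otherwise treat the first element as the feature tensor itself
        return {"x": res[0]}
    raise TypeError(f"unexpected extract_features return type: {type(res)}")
```

With such a shim, the line that currently fails in Wav2VecEncoder.forward (x = res["x"]) would become x = normalize_extract_features(res)["x"]; alternatively, the two models could simply agree on a single return type for extract_features.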