facebookresearch / fairseq

Facebook AI Research Sequence-to-Sequence Toolkit written in Python.
MIT License

AssertionError: Could not infer task type #3683

Closed soroushhashemifar closed 2 years ago

soroushhashemifar commented 3 years ago

I'm training unsupervised wav2vec on Vietnamese. Everything runs perfectly until the GAN training step. When I run the following command:

PREFIX=w2v_unsup_gan_xp
TASK_DATA=./output_dir/precompute_pca512_cls128_mean_pooled
TEXT_DATA=./output_dir/phones  # path to fairseq-preprocessed GAN data (phones dir)
KENLM_PATH=./output_dir/phones/lm.phones.filtered.04.bin  # KenLM 4-gram phoneme language model (LM data = GAN data here)

PYTHONPATH=$FAIRSEQ_ROOT PREFIX=$PREFIX fairseq-hydra-train -m --config-dir ./fairseq/examples/wav2vec/unsupervised/config/gan --config-name w2vu task.data=${TASK_DATA} task.text_data=${TEXT_DATA} task.kenlm_path=${KENLM_PATH}

I get the following error:

[2021-07-02 15:16:45,388][HYDRA] Launching 1 jobs locally
[2021-07-02 15:16:45,389][HYDRA]    #0 : task.data=/content/output_dir/precompute_pca512_cls128_mean_pooled task.text_data=/content/output_dir/phones task.kenlm_path=/content/output_dir/phones/lm.phones.filtered.04.bin
[2021-07-02 15:16:49,198][fairseq_cli.train][INFO] - {'common': {'fp16': False, 'fp16_no_flatten_grads': True, 'log_format': 'json', 'log_interval': 100, 'tensorboard_logdir': 'tb', 'reset_logging': False, 'suppress_crashes': False, 'profile': True, 'tpu': False, 'log_file': None, 'seed': 4}, 'checkpoint': {'save_interval': 1000, 'save_interval_updates': 1000, 'no_epoch_checkpoints': True, 'best_checkpoint_metric': 'weighted_lm_ppl', 'save_dir': '.', 'write_checkpoints_asynchronously': False}, 'distributed_training': {'distributed_world_size': 1, 'distributed_init_method': None, 'tpu': False, 'pipeline_model_parallel': False, 'distributed_port': 0, 'distributed_no_spawn': True, 'distributed_rank': 0}, 'task': {'_name': 'unpaired_audio_text', 'data': '/content/output_dir/precompute_pca512_cls128_mean_pooled', 'text_data': '/content/output_dir/phones', 'labels': 'phn', 'sort_by_length': False, 'unfiltered': False, 'max_length': None, 'append_eos': False, 'kenlm_path': '/content/output_dir/phones/lm.phones.filtered.04.bin'}, 'dataset': {'num_workers': 6, 'batch_size': 160, 'skip_invalid_size_inputs_valid_test': True, 'valid_subset': 'valid', 'validate_interval': 1000, 'validate_interval_updates': 1000, 'max_tokens': 1000}, 'criterion': {'_name': 'model', 'loss_weights': {}, 'log_keys': ['accuracy_dense', 'accuracy_token', 'temp', 'code_ppl']}, 'optimization': {'max_update': 150000, 'clip_norm': 5.0, 'lr': [0]}, 'optimizer': {'_name': 'composite', 'groups': {'generator': {'lr': [0.0004], 'lr_float': None, 'optimizer': {'_name': 'adam', 'adam_betas': [0.5, 0.98], 'adam_eps': 1e-06, 'weight_decay': 0, 'amsgrad': False}, 'lr_scheduler': {'_name': 'fixed', 'warmup_updates': 0}}, 'discriminator': {'lr': [0.0005], 'lr_float': None, 'optimizer': {'_name': 'adam', 'adam_betas': [0.5, 0.98], 'adam_eps': 1e-06, 'weight_decay': 0.0001, 'amsgrad': False}, 'lr_scheduler': {'_name': 'fixed', 'warmup_updates': 0}}}}, 'lr_scheduler': {'_name': 'pass_through'}, 'model': {'_name': 
'wav2vec_u', 'discriminator_dim': 384, 'discriminator_depth': 2, 'discriminator_kernel': 6, 'discriminator_linear_emb': False, 'discriminator_causal': True, 'discriminator_max_pool': False, 'discriminator_act_after_linear': False, 'discriminator_dropout': 0.0, 'discriminator_weight_norm': False, 'generator_stride': 1, 'generator_kernel': 4, 'generator_bias': False, 'generator_dropout': 0.1, 'smoothness_weight': 0.5, 'smoothing': 0, 'smoothing_one_sided': False, 'gumbel': False, 'hard_gumbel': False, 'gradient_penalty': 1.5, 'code_penalty': 4.0, 'temp': [2, 0.1, 0.99995], 'input_dim': 512, 'segmentation': {'type': 'JOIN', 'mean_pool_join': False, 'remove_zeros': False}}, 'job_logging_cfg': {'version': 1, 'formatters': {'simple': {'format': '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s'}}, 'handlers': {'console': {'class': 'logging.StreamHandler', 'formatter': 'simple', 'stream': 'ext://sys.stdout'}, 'file': {'class': 'logging.FileHandler', 'formatter': 'simple', 'filename': 'hydra_train.log'}}, 'root': {'level': 'INFO', 'handlers': ['console', 'file']}, 'disable_existing_loggers': False}}
Traceback (most recent call last):
  File "/content/fairseq/fairseq_cli/hydra_train.py", line 43, in hydra_main
    distributed_utils.call_main(cfg, pre_main)
  File "/content/fairseq/fairseq/distributed/utils.py", line 369, in call_main
    main(cfg, **kwargs)
  File "/content/fairseq/fairseq_cli/train.py", line 88, in main
    task = tasks.setup_task(cfg.task)
  File "/content/fairseq/fairseq/tasks/__init__.py", line 44, in setup_task
    ), f"Could not infer task type from {cfg}. Available tasks: {TASK_REGISTRY.keys()}"
AssertionError: Could not infer task type from {'_name': 'unpaired_audio_text', 'data': '/content/output_dir/precompute_pca512_cls128_mean_pooled', 'text_data': '/content/output_dir/phones', 'labels': 'phn', 'sort_by_length': False, 'unfiltered': False, 'max_length': None, 'append_eos': False, 'kenlm_path': '/content/output_dir/phones/lm.phones.filtered.04.bin'}. Available tasks: dict_keys(['translation_multi_simple_epoch', 'audio_pretraining', 'denoising', 'speech_to_text', 'translation', 'hubert_pretraining', 'simul_speech_to_text', 'simul_text_to_text', 'legacy_masked_lm', 'online_backtranslation', 'multilingual_denoising', 'translation_lev', 'multilingual_translation', 'translation_from_pretrained_bart', 'language_modeling', 'cross_lingual_lm', 'masked_lm', 'sentence_ranking', 'sentence_prediction', 'semisupervised_translation', 'translation_from_pretrained_xlm', 'multilingual_masked_lm', 'dummy_lm', 'dummy_masked_lm', 'dummy_mt'])

Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.
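For context on what this assertion checks: `tasks.setup_task` resolves `cfg._name` against a registry that is populated as a side effect of importing each task's module, so a task defined under `examples/` (like `unpaired_audio_text`) is invisible unless its package is actually imported. A minimal sketch of that registry pattern (simplified, not fairseq's actual code):

```python
# Simplified sketch of fairseq's task-registry pattern. A task class is only
# resolvable if the module containing its @register_task call was imported,
# which is why tasks living under examples/ are absent from the registry
# unless that package is on the import path.
TASK_REGISTRY = {}

def register_task(name):
    """Decorator that records a task class under its config `_name`."""
    def wrapper(cls):
        TASK_REGISTRY[name] = cls
        return cls
    return wrapper

@register_task("audio_pretraining")
class AudioPretrainingTask:
    pass

def setup_task(cfg):
    name = cfg.get("_name")
    assert name in TASK_REGISTRY, (
        f"Could not infer task type from {cfg}. "
        f"Available tasks: {TASK_REGISTRY.keys()}"
    )
    return TASK_REGISTRY[name]()

# "unpaired_audio_text" was never registered because its defining module
# was not imported, so this would raise the AssertionError seen above:
#   setup_task({"_name": "unpaired_audio_text"})
```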

My w2vu.yaml contains the following content:

# @package _group_

common:
  fp16: false
  fp16_no_flatten_grads: true
  log_format: json
  log_interval: 100
  tensorboard_logdir: tb
  reset_logging: false
  suppress_crashes: false
  profile: true
  tpu: false
  log_file: null
  seed: 4

checkpoint:
  save_interval: 1000
  save_interval_updates: 1000
  no_epoch_checkpoints: true
  best_checkpoint_metric: weighted_lm_ppl
  save_dir: .
  write_checkpoints_asynchronously: false

distributed_training:
  distributed_world_size: 1
  distributed_init_method: null
  tpu: false
  pipeline_model_parallel: false
  distributed_port: 0
  distributed_no_spawn: true
  distributed_rank: 0

task:
  _name: unpaired_audio_text
  data: ???
  text_data: ???
  labels: phn
  sort_by_length: false
  unfiltered: false
  max_length: null
  append_eos: false
  kenlm_path: ???

dataset:
  num_workers: 6
  batch_size: 160
  skip_invalid_size_inputs_valid_test: true
  valid_subset: valid
  validate_interval: 1000
  validate_interval_updates: 1000
  max_tokens: 1000

criterion:
  _name: model
  log_keys:
    - accuracy_dense
    - accuracy_token
    - temp
    - code_ppl

optimization:
  max_update: 150000
  clip_norm: 5.0
  lr: [0]

optimizer:
  _name: composite
  groups:
    generator:
      lr: [0.0004]
      lr_float: null
      optimizer:
        _name: adam
        adam_betas: [0.5,0.98]
        adam_eps: 1e-06
        weight_decay: 0
        amsgrad: false
      lr_scheduler:
        _name: fixed
        warmup_updates: 0
    discriminator:
      lr: [ 0.0005 ]
      lr_float: null
      optimizer:
        _name: adam
        adam_betas: [0.5,0.98]
        adam_eps: 1e-06
        weight_decay: 0.0001
        amsgrad: false
      lr_scheduler:
        _name: fixed
        warmup_updates: 0

lr_scheduler: pass_through

model:
  _name: wav2vec_u

  discriminator_dim: 384
  discriminator_depth: 2
  discriminator_kernel: 6
  discriminator_linear_emb: false
  discriminator_causal: true
  discriminator_max_pool: false
  discriminator_act_after_linear: false
  discriminator_dropout: 0.0
  discriminator_weight_norm: false

  generator_stride: 1
  generator_kernel: 4
  generator_bias: false
  generator_dropout: 0.1

  smoothness_weight: 0.5
  smoothing: 0
  smoothing_one_sided: false
  gumbel: false
  hard_gumbel: false
  gradient_penalty: 1.5
  code_penalty: 4.0
  temp: [ 2,0.1,0.99995 ]
  input_dim: 512

  segmentation:
    type: JOIN
    mean_pool_join: false
    remove_zeros: false
marcinkusz commented 3 years ago

I don't know exactly what causes it, but it seems related to 9bee82e4a7b73249a88f2e2d286e991493ef13c2. I temporarily fixed it by reverting the lines added by that commit; let me know if that works for you too.
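If you want to try the same thing without hand-editing files, reverting a single commit in your working tree can be sketched like this (run inside your fairseq checkout; this assumes the commit still applies cleanly):

```shell
# Inside the fairseq checkout: undo the changes introduced by the suspect
# commit in the working tree, without recording a new commit (so it is easy
# to drop again). The guard skips the revert if the commit is not present.
if git cat-file -e 9bee82e4a7b73249a88f2e2d286e991493ef13c2 2>/dev/null; then
    git revert --no-commit 9bee82e4a7b73249a88f2e2d286e991493ef13c2
fi
# When finished testing, discard the in-progress revert with:
#   git revert --abort
```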

soroushhashemifar commented 3 years ago

That didn't solve the issue. It seems odd that the w2vu.yaml file would need to be modified manually to fix the error!

mjurkus commented 3 years ago

Facing the same issue. Any workaround for that?

mjurkus commented 3 years ago

Worked around it temporarily by moving the task and data modules from wav2vec/unsupervised into fairseq/tasks and fairseq/data.
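For anyone else trying this, a rough sketch of that workaround (the source paths are assumed from the fairseq repo layout; adjust them to your checkout):

```shell
# Hypothetical sketch: copy the example's task/data modules into the core
# packages so they are picked up at import time. Run from the fairseq root.
if [ -d examples/wav2vec/unsupervised ]; then
    cp examples/wav2vec/unsupervised/tasks/*.py fairseq/tasks/
    cp examples/wav2vec/unsupervised/data/*.py fairseq/data/
fi
# The fairseq.tasks package imports its modules at startup, which runs the
# @register_task decorator and makes unpaired_audio_text resolvable.
```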

stale[bot] commented 2 years ago

This issue has been automatically marked as stale. If this issue is still affecting you, please leave any comment (for example, "bump"), and we'll keep it open. We are sorry that we haven't been able to prioritize it yet. If you have any new additional information, please include it with your comment!

stale[bot] commented 2 years ago

Closing this issue after a prolonged period of inactivity. If this issue is still present in the latest release, please create a new issue with up-to-date information. Thank you!