facebookresearch / fairseq

Facebook AI Research Sequence-to-Sequence Toolkit written in Python.
MIT License

Really slow training in the wav2vec2 pretraining phase while GPU memory is fully consumed #4574

Open Abdullah955 opened 2 years ago

Abdullah955 commented 2 years ago

I have 3× A100 40GB GPUs and I'm trying to pretrain the wav2vec2 model. The GPU memory is fully consumed, but the utilization is really low.

I've tried increasing max_tokens until I hit an out-of-memory error, with no luck, and changing num_workers and normalize has no effect.
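(For reference, these settings can also be swept from the command line as Hydra overrides without editing the YAML; the values below are only illustrative, not recommendations.)

fairseq-hydra-train \
  task.data=/path/to/data \
  dataset.max_tokens=1400000 \
  dataset.num_workers=12 \
  --config-dir examples/wav2vec/config/pretraining \
  --config-name wav2vec2_base_librispeech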

[Screenshot: Screen Shot 2022-07-13 at 11 52 45 PM]

The problem is that GPU utilization with 1 GPU is much higher than with 3.
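A simple way to watch this is to sample utilization and memory once per second with nvidia-smi while training runs (plain nvidia-smi query flags, nothing fairseq-specific):

nvidia-smi --query-gpu=index,utilization.gpu,memory.used --format=csv -l 1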

config file:

common:
  fp16: true
  log_format: json
  log_interval: 200

checkpoint:
  save_interval_updates: 25000
  keep_interval_updates: 1
  no_epoch_checkpoints: true

task:
  _name: audio_pretraining
  data: ???
  max_sample_size: 250000
  min_sample_size: 32000
  normalize: false

dataset:
  num_workers: 6
  max_tokens: 3000000
  skip_invalid_size_inputs_valid_test: true

distributed_training:
  distributed_world_size: 3
  ddp_backend: legacy_ddp

criterion:
  _name: wav2vec
  infonce: true
  log_keys: ["prob_perplexity","code_perplexity","temp"]
  loss_weights: [0.1, 10]

optimization:
  max_update: 400000
  lr: [0.0005]

optimizer:
  _name: adam
  adam_betas: (0.9,0.98)
  adam_eps: 1e-06
  weight_decay: 0.01

lr_scheduler:
  _name: polynomial_decay
  warmup_updates: 32000

model:
  _name: wav2vec2
  quantize_targets: true
  final_dim: 256
  encoder_layerdrop: 0.05
  dropout_input: 0.1
  dropout_features: 0.1
  feature_grad_mult: 0.1
  encoder_embed_dim: 768

Command used:

fairseq-hydra-train task.data=/path/to/data common.tensorboard_logdir=tslog/ --config-dir examples/wav2vec/config/pretraining --config-name wav2vec2_base_librispeech
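For reference, the wav2vec README notes that the base config is tuned for 64 GPUs and that k GPUs can simulate that setup by accumulating gradients with +optimization.update_freq='[64/k]'. A rough 3-GPU sketch of the command above (64/3 rounded to 21; unverified here) would be:

fairseq-hydra-train \
  task.data=/path/to/data \
  common.tensorboard_logdir=tslog/ \
  distributed_training.distributed_world_size=3 \
  +optimization.update_freq='[21]' \
  --config-dir examples/wav2vec/config/pretraining \
  --config-name wav2vec2_base_librispeech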

The strange thing is that using the same configuration with data2vec results in 100% GPU utilization.

What's your environment?

gmryu commented 2 years ago

Sorry, I don't know the cause.

If you are not tied to it, would you mind changing distributed_training.ddp_backend from legacy_ddp to something else? If it happens to help, please also share the finding here. Sincere regards.
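For example, the override can be passed on the command line; c10d is one of the other ddp_backend choices fairseq accepts (only a sketch of the syntax, not a tested fix):

fairseq-hydra-train \
  task.data=/path/to/data \
  distributed_training.ddp_backend=c10d \
  --config-dir examples/wav2vec/config/pretraining \
  --config-name wav2vec2_base_librispeech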