yl4579 / StyleTTS2

StyleTTS 2: Towards Human-Level Text-to-Speech through Style Diffusion and Adversarial Training with Large Speech Language Models

After training 1 epoch, train_first.py crashes: RuntimeError: Expected 2D (unbatched) or 3D (batched) input to conv1d, but got input of size: [1, 1, 1, 800] #273

Open · fungus75 opened this issue 3 months ago

fungus75 commented 3 months ago

I tried training with the default settings on a 3090 (LJ-Speech dataset from https://keithito.com/LJ-Speech-Dataset/, everything as shown in the README).

I only had to adjust config.yml to fit into my 3090: batch_size: 16 and max_len: 150.

Everything else is at its default.

After 1 epoch the script crashes (maybe during validation?):

Traceback (most recent call last):
  File "train_first.py", line 445, in <module>
    main()
  File "/home/ai/miniconda3/envs/StyleTTS/lib/python3.8/site-packages/click/core.py", line 1157, in __call__
    return self.main(*args, **kwargs)
  File "/home/ai/miniconda3/envs/StyleTTS/lib/python3.8/site-packages/click/core.py", line 1078, in main
    rv = self.invoke(ctx)
  File "/home/ai/miniconda3/envs/StyleTTS/lib/python3.8/site-packages/click/core.py", line 1434, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/home/ai/miniconda3/envs/StyleTTS/lib/python3.8/site-packages/click/core.py", line 783, in invoke
    return __callback(*args, **kwargs)
  File "train_first.py", line 407, in main
    y_rec = model.decoder(en, F0_real, real_norm, s)
  File "/home/ai/miniconda3/envs/StyleTTS/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ai/miniconda3/envs/StyleTTS/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/ai/StyleTTS2/Modules/istftnet.py", line 511, in forward
    F0 = self.F0_conv(F0_curve.unsqueeze(1))
  File "/home/ai/miniconda3/envs/StyleTTS/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ai/miniconda3/envs/StyleTTS/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1582, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/home/ai/miniconda3/envs/StyleTTS/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 310, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/home/ai/miniconda3/envs/StyleTTS/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 306, in _conv_forward
    return F.conv1d(input, weight, bias, self.stride,
RuntimeError: Expected 2D (unbatched) or 3D (batched) input to conv1d, but got input of size: [1, 1, 1, 800]
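
For reference, a minimal snippet that reproduces the same message. The error means F0_curve was already [1, 1, 800] when the decoder called F0_curve.unsqueeze(1) in istftnet.py, turning it into a 4D tensor. The squeeze guard at the end is just my guess at a workaround, not a verified fix:

import torch

conv = torch.nn.Conv1d(1, 1, kernel_size=7, padding=3)
F0_curve = torch.randn(1, 1, 800)            # already [B, 1, T] instead of the expected [B, T]
try:
    conv(F0_curve.unsqueeze(1))              # [1, 1, 1, 800] -> 4D input, crashes
except RuntimeError as e:
    print(e)                                 # same "Expected 2D (unbatched) or 3D (batched) input to conv1d" error

# guessed workaround: drop the extra channel dim before the unsqueeze
if F0_curve.dim() == 3:
    F0_curve = F0_curve.squeeze(1)           # back to [1, 800]
print(conv(F0_curve.unsqueeze(1)).shape)     # torch.Size([1, 1, 800]) -> works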
fungus75 commented 3 months ago

My config.yml:

log_dir: "Models/LJSpeech"
first_stage_path: "first_stage.pth"
save_freq: 2
log_interval: 10
device: "cuda"
epochs_1st: 200 # number of epochs for first stage training (pre-training)
epochs_2nd: 100 # number of epochs for second stage training (joint training)
batch_size: 16
max_len: 150 # maximum number of frames
pretrained_model: ""
second_stage_load_pretrained: true # set to true if the pre-trained model is for 2nd stage
load_only_params: false # set to true if you do not want to load epoch numbers and optimizer parameters

F0_path: "Utils/JDC/bst.t7"
ASR_config: "Utils/ASR/config.yml"
ASR_path: "Utils/ASR/epoch_00080.pth"
PLBERT_dir: 'Utils/PLBERT/'

data_params:
  train_data: "Data/train_list.txt"
  val_data: "Data/val_list.txt"
  root_path: "Data/LJSpeech-1.1/wavs"
  OOD_data: "Data/OOD_texts.txt"
  min_length: 50 # resample until an OOD text of at least this length is obtained

preprocess_params:
  sr: 24000
  spect_params:
    n_fft: 2048
    win_length: 1200
    hop_length: 300

model_params:
  multispeaker: false

  dim_in: 64 
  hidden_dim: 512
  max_conv_dim: 512
  n_layer: 3
  n_mels: 80

  n_token: 178 # number of phoneme tokens
  max_dur: 50 # maximum duration of a single phoneme
  style_dim: 128 # style vector size

  dropout: 0.2

  # config for decoder
  decoder: 
      type: 'istftnet' # either hifigan or istftnet
      resblock_kernel_sizes: [3,7,11]
      upsample_rates: [10, 6]
      upsample_initial_channel: 512
      resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]]
      upsample_kernel_sizes: [20, 12]
      gen_istft_n_fft: 20
      gen_istft_hop_size: 5

  # speech language model config
  slm:
      model: 'microsoft/wavlm-base-plus'
      sr: 16000 # sampling rate of SLM
      hidden: 768 # hidden size of SLM
      nlayers: 13 # number of layers of SLM
      initial_channel: 64 # initial channels of SLM discriminator head

  # style diffusion model config
  diffusion:
    embedding_mask_proba: 0.1
    # transformer config
    transformer:
      num_layers: 3
      num_heads: 8
      head_features: 64
      multiplier: 2

    # diffusion distribution config
    dist:
      sigma_data: 0.2 # placeholder; only used if estimate_sigma_data is false
      estimate_sigma_data: true # estimate sigma_data from the current batch if set to true
      mean: -3.0
      std: 1.0

loss_params:
    lambda_mel: 5. # mel reconstruction loss
    lambda_gen: 1. # generator loss
    lambda_slm: 1. # slm feature matching loss

    lambda_mono: 1. # monotonic alignment loss (1st stage, TMA)
    lambda_s2s: 1. # sequence-to-sequence loss (1st stage, TMA)
    TMA_epoch: 50 # TMA starting epoch (1st stage)

    lambda_F0: 1. # F0 reconstruction loss (2nd stage)
    lambda_norm: 1. # norm reconstruction loss (2nd stage)
    lambda_dur: 1. # duration loss (2nd stage)
    lambda_ce: 20. # duration predictor probability output CE loss (2nd stage)
    lambda_sty: 1. # style reconstruction loss (2nd stage)
    lambda_diff: 1. # score matching loss (2nd stage)

    diff_epoch: 20 # style diffusion starting epoch (2nd stage)
    joint_epoch: 50 # joint training starting epoch (2nd stage)

optimizer_params:
  lr: 0.0001 # general learning rate
  bert_lr: 0.00001 # learning rate for PLBERT
  ft_lr: 0.00001 # learning rate for acoustic modules

slmadv_params:
  min_len: 400 # minimum length of samples
  max_len: 500 # maximum length of samples
  batch_percentage: 0.5 # to prevent out of memory, only use half of the original batch size
  iter: 10 # update the discriminator once every this many generator updates
  thresh: 5 # gradient norm above which the gradient is scaled
  scale: 0.01 # gradient scaling factor for predictors from SLM discriminators
  sig: 1.5 # sigma for differentiable duration modeling
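
If it helps narrow this down, here is a debugging sketch for the call site from the traceback (train_first.py line 407). The shape prints and the squeeze are my own additions, not code from the repo:

# hypothetical instrumentation around train_first.py line 407 (names taken from the traceback)
print("en:", en.shape, "F0_real:", F0_real.shape, "real_norm:", real_norm.shape, "s:", s.shape)

# the decoder does F0_curve.unsqueeze(1) and feeds it to a Conv1d, so F0_real
# should arrive as [B, T]; if it already has a channel dim, drop it (my assumption)
if F0_real.dim() == 3:
    F0_real = F0_real.squeeze(1)

y_rec = model.decoder(en, F0_real, real_norm, s)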