wenet-e2e / wetts

Production First and Production Ready End-to-End Text-to-Speech Toolkit
Apache License 2.0

[fix] Update train.py #217

Closed lsrami closed 4 months ago

lsrami commented 4 months ago

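During training, torch.stft raises an error on some devices because they do not support the half (float16) type. Casting the input to float fixes the problem.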

  File "/public/home/user/wetts/examples/french-v1/vits/train.py", line 323, in main
    train_and_evaluate(
  File "/public/home/user/wetts/examples/french-v1/vits/train.py", line 429, in train_and_evaluate
    y_hat_mel = mel_spectrogram_torch(
  File "/public/home/user/wetts/wetts/vits/utils/mel_processing.py", line 167, in mel_spectrogram_torch
    spec = torch.stft(
  File "/public/home/user/miniconda3/envs/wetts/lib/python3.10/site-packages/torch/functional.py", line 632, in stft
    return _VF.stft(input, n_fft, hop_length, win_length, window,  # type: ignore[attr-defined]
RuntimeError: hipFFT doesn't support transforms of type: Half
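A minimal sketch of the cast described above, assuming the fix upcasts the signal to float32 right before torch.stft. The helper name `stft_fp32` and its parameters are hypothetical; they are not taken from the PR diff, which is not shown here.

```python
import torch


def stft_fp32(y: torch.Tensor, n_fft: int, hop_size: int, win_size: int) -> torch.Tensor:
    """Hypothetical helper: run torch.stft in float32 even if y is float16.

    FFT backends such as hipFFT (ROCm) reject Half inputs, so the signal
    is upcast before the transform.
    """
    # Upcast fp16 -> fp32; .float() returns y itself if it is already float32.
    y = y.float()
    # Build the window in the same dtype/device as the (now float32) signal.
    window = torch.hann_window(win_size, dtype=y.dtype, device=y.device)
    spec = torch.stft(
        y,
        n_fft,
        hop_length=hop_size,
        win_length=win_size,
        window=window,
        center=True,
        return_complex=True,
    )
    # Magnitude spectrogram: (batch, n_fft // 2 + 1, frames), float32.
    return spec.abs()
```

Upcasting only the STFT input leaves the rest of the mixed-precision forward pass unchanged.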
lsrami commented 4 months ago

The error only occurs when "fp16_run": true is set in the config file, that is, when training with fp16.
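To check whether a given config will trigger this path, a small sketch like the following reads the flag. It assumes a VITS-style layout where `fp16_run` sits under a `"train"` section; that nesting and the helper name `uses_fp16` are assumptions, so adjust the key path if your config differs.

```python
import json


def uses_fp16(config_text: str) -> bool:
    """Hypothetical helper: True if the config enables fp16 training.

    Assumes the flag lives at config["train"]["fp16_run"] (VITS-style layout).
    A missing key is treated as fp16 disabled.
    """
    config = json.loads(config_text)
    return bool(config.get("train", {}).get("fp16_run", False))
```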