jaywalnut310 / vits

VITS: Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech
https://jaywalnut310.github.io/vits-demo/index.html
MIT License
6.84k stars 1.26k forks source link

Training on GTX2080 #39

Open liuhaogeng opened 2 years ago

liuhaogeng commented 2 years ago

Process 2 terminated with the following error: Traceback (most recent call last): File "/data3/liuhaogeng/anaconda3/envs/vits/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 20, in _wrap fn(i, *args) File "/data3/liuhaogeng/test/vits-main/train.py", line 120, in run train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler, [train_loader, None], None, None) File "/data3/liuhaogeng/test/vits-main/train.py", line 138, in train_and_evaluate for batch_idx, (x, x_lengths, spec, spec_lengths, y, y_lengths) in enumerate(train_loader): File "/data3/liuhaogeng/anaconda3/envs/vits/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 363, in next data = self._next_data() File "/data3/liuhaogeng/anaconda3/envs/vits/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 989, in _next_data return self._process_data(data) File "/data3/liuhaogeng/anaconda3/envs/vits/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1014, in _process_data data.reraise() File "/data3/liuhaogeng/anaconda3/envs/vits/lib/python3.8/site-packages/torch/_utils.py", line 395, in reraise raise self.exc_type(msg) RuntimeError: Caught RuntimeError in DataLoader worker process 2. 
Original Traceback (most recent call last): File "/data3/liuhaogeng/anaconda3/envs/vits/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 185, in _worker_loop data = fetcher.fetch(index) File "/data3/liuhaogeng/anaconda3/envs/vits/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch data = [self.dataset[idx] for idx in possibly_batched_index] File "/data3/liuhaogeng/anaconda3/envs/vits/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 44, in data = [self.dataset[idx] for idx in possibly_batched_index] File "/data3/liuhaogeng/test/vits-main/data_utils.py", line 94, in getitem return self.get_audio_text_pair(self.audiopaths_and_text[index]) File "/data3/liuhaogeng/test/vits-main/data_utils.py", line 62, in get_audio_text_pair spec, wav = self.get_audio(audiopath) File "/data3/liuhaogeng/test/vits-main/data_utils.py", line 74, in get_audio spec = torch.load(spec_filename) File "/data3/liuhaogeng/anaconda3/envs/vits/lib/python3.8/site-packages/torch/serialization.py", line 577, in load with _open_zipfile_reader(opened_file) as opened_zipfile: File "/data3/liuhaogeng/anaconda3/envs/vits/lib/python3.8/site-packages/torch/serialization.py", line 241, in init super(_open_zipfile_reader, self).init(torch._C.PyTorchFileReader(name_or_buffer)) RuntimeError: [enforce fail at inline_container.cc:144] . PytorchStreamReader failed reading zip archive: failed finding central directory

I set the batch size to 16

lpdink commented 2 years ago

I met the same error. It occurs because the function TextAudioLoader.get_audio saved an empty spectrogram file (.spec.pt), possibly due to multi-processing. Torch throws this error when it tries to load an empty or truncated file. You can compute and save the spectrograms before training. The following spectrogram-computation code is extracted from VITS; change `base` to the path of your wavs folder.

from scipy.io.wavfile import read
import torch
import numpy as np
import os
from multiprocessing import Pool
from tqdm import tqdm

# Change here
base="Your wavs's folder path"

hann_window = {}
def load_wav_to_torch(full_path):
    sampling_rate, data = read(full_path)
    # data, sampling_rate = librosa.load(full_path)
    return torch.FloatTensor(data.astype(np.float32)), sampling_rate

def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
    if torch.min(y) < -1.:
        print('min value is ', torch.min(y))
    if torch.max(y) > 1.:
        print('max value is ', torch.max(y))

    global hann_window
    dtype_device = str(y.dtype) + '_' + str(y.device)
    wnsize_dtype_device = str(win_size) + '_' + dtype_device
    if wnsize_dtype_device not in hann_window:
        hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)

    y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
    y = y.squeeze(1)

    spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
                      center=center, pad_mode='reflect', normalized=False, onesided=True)

    spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
    return spec

def get_audio(filename):
    max_wave_length = 32768.0
    filter_length = 1024
    hop_length = 256
    win_length = 1024
    audio, sampling_rate = load_wav_to_torch(filename)
    audio_norm = audio / max_wave_length
    audio_norm = audio_norm.unsqueeze(0)
    spec_filename = filename.replace(".wav", ".spec.pt")
    spec = spectrogram_torch(audio_norm, filter_length,
        sampling_rate, hop_length, win_length,
        center=False)
    spec = torch.squeeze(spec, 0)
    torch.save(spec, spec_filename)

if __name__=="__main__":
    waves = []
    for wav_name in os.listdir(base):
        wav_path = os.path.join(base, wav_name)
        if wav_path.endswith(".wav"):
            waves.append(wav_path)
    with Pool(16) as p:
        print(list((tqdm(p.imap(get_audio,waves),total=len(waves)))))
JoanisTriandafilidi commented 2 years ago

@lpdink, Thanks, I had this error too, and your advice helped me a lot!