RVC-Boss / GPT-SoVITS

1 min voice data can also be used to train a good TTS model! (few shot voice cloning)
MIT License

SSL feature extraction runs out of memory #1803

Open kemingc-cmu-F24 opened 1 week ago

kemingc-cmu-F24 commented 1 week ago

Fix: in D:\GPT-SoVITS-v2-240821\GPT_SoVITS\prepare_datasets\2-get-hubert-wav32k.py, replace the name2go function with the following:

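import gc  # needed for gc.collect() below; add to the script's imports if it is not already there

def name2go(wav_name, wav_path):
    hubert_path = f"{hubert_dir}/{wav_name}.pt"
    if os.path.exists(hubert_path):
        return  # features already extracted for this file
    tmp_audio = load_audio(wav_path, 32000)
    tmp_max = np.abs(tmp_audio).max()
    if tmp_max > 2.2:
        print(f"{wav_name}-filtered, {tmp_max}")
        return
    # Two rescaled copies: int16 range for the saved 32 kHz wav,
    # and the 1145.14-scaled signal that is resampled to 16 kHz for HuBERT
    tmp_audio32 = (tmp_audio / tmp_max * (maxx * alpha * 32768)) + ((1 - alpha) * 32768) * tmp_audio
    tmp_audio32b = (tmp_audio / tmp_max * (maxx * alpha * 1145.14)) + ((1 - alpha) * 1145.14) * tmp_audio
    tmp_audio = librosa.resample(tmp_audio32b, orig_sr=32000, target_sr=16000)

    tensor_wav16 = torch.from_numpy(tmp_audio).to(device)
    if is_half:
        tensor_wav16 = tensor_wav16.half()

    ssl = None  # pre-bind so "del ssl" in finally cannot raise UnboundLocalError if the model call fails
    try:
        with torch.no_grad():  # inference only: do not build an autograd graph
            ssl = model.model(tensor_wav16.unsqueeze(0))["last_hidden_state"].transpose(1, 2).cpu()
        if torch.isnan(ssl).any():
            nan_fails.append((wav_name, wav_path))
            print(f"nan filtered: {wav_name}")
            return  # finally still runs, so the NaN path is cleaned up too
        wavfile.write(f"{wav32dir}/{wav_name}", 32000, tmp_audio32.astype("int16"))
        my_save(ssl, hubert_path)
    except Exception as e:
        print(f"Error processing {wav_name}: {e}")
    finally:
        # Drop the tensor references, release cached CUDA blocks, then run the garbage collector
        del tensor_wav16, ssl
        torch.cuda.empty_cache()
        gc.collect()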
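One pitfall with deleting locals in a finally block: if the exception fires before a name is bound (here, if model.model(...) raises before ssl is assigned), "del ssl" itself raises UnboundLocalError and replaces the handled error. The stdlib-only sketch below reproduces that and shows the pre-binding fix; cleanup_unsafe and cleanup_safe are hypothetical names standing in for name2go, with a RuntimeError standing in for a failed model call.

```python
def cleanup_unsafe(fail: bool) -> str:
    try:
        if fail:
            raise RuntimeError("model call failed")  # stands in for model.model(...) raising
        result = "features"
        return result
    except RuntimeError:
        return "handled"
    finally:
        del result  # UnboundLocalError when the exception fired before result was bound


def cleanup_safe(fail: bool) -> str:
    result = None  # pre-bind so the finally block can always delete it
    try:
        if fail:
            raise RuntimeError("model call failed")
        result = "features"
        return result
    except RuntimeError:
        return "handled"
    finally:
        del result  # always safe: the name is bound on every path


if __name__ == "__main__":
    print(cleanup_safe(True))
    try:
        cleanup_unsafe(True)
    except UnboundLocalError as e:
        print("unsafe version raised:", type(e).__name__)
```

The same reasoning applies to any name deleted in finally: bind it to None before the try so cleanup never depends on how far the try body got.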
RVC-Boss commented 6 days ago

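We can't call torch.cuda.empty_cache() after every HuBERT feature extraction: inside the for loop, this function drags down overall speed. It looks like the leak happens when a sample hits the NaN check and returns early, so there should be a check before torch.cuda.empty_cache() is called.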
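The maintainer's suggestion can be sketched as a loop that pays for the cache flush only on the NaN path (and, optionally, every N files) instead of after every file. This is a minimal sketch under stated assumptions, not GPT-SoVITS code: extract_all, extract, save, clear_cache and every_n are hypothetical names, and features are plain lists of floats so the example runs without torch. In the real script the NaN test would be torch.isnan(ssl).any() and clear_cache would be torch.cuda.empty_cache.

```python
import gc
import math
from typing import Callable, Iterable, List, Sequence, Tuple


def extract_all(
    items: Iterable[Tuple[str, str]],
    extract: Callable[[str], Sequence[float]],
    save: Callable[[str, Sequence[float]], None],
    clear_cache: Callable[[], None],
    every_n: int = 0,
) -> List[str]:
    """Hypothetical driver loop: returns the names rejected for NaN features.

    clear_cache is called when a result is rejected for NaNs, and also after
    every `every_n` items when every_n > 0; never unconditionally per item.
    """
    nan_fails: List[str] = []
    for i, (name, path) in enumerate(items, start=1):
        feats = extract(path)
        if any(math.isnan(x) for x in feats):
            nan_fails.append(name)
            del feats      # drop our reference before flushing
            gc.collect()
            clear_cache()  # pay the flush cost only on the rare failure path
            continue
        save(name, feats)
        if every_n and i % every_n == 0:
            clear_cache()  # optional periodic flush
    return nan_fails
```

With every_n=0 the flush runs only for rejected files; a small every_n trades some speed for a lower memory ceiling.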