JarodMica / ai-voice-cloning

GNU General Public License v3.0

Correct way to use the pretrained models without the API. #49

Open 0xrushi opened 9 months ago

0xrushi commented 9 months ago

Hello, I'm writing a script that loads the models directly into the TTS system instead of going through an API server. Is the approach below appropriate?

import torch
import sys 
sys.path.append('./src')
from tortoise.api import TextToSpeech as TorToise_TTS, MODELS, get_model_path, pad_or_truncate

autoregressive_model_path = './training/tina/finetune/models/4020_gpt.pth'
diffusion_model_path = './models/tortoise/diffusion_decoder.pth'
vocoder_model_path = 'bigvgan_24khz_100band.pth'  # Ensure this path is correct
tokenizer_json_path = './modules/tortoise-tts/tortoise/data/tokenizer.json'

tts = TorToise_TTS(
    autoregressive_model_path=autoregressive_model_path,
    diffusion_model_path=diffusion_model_path,
    vocoder_model=vocoder_model_path,
    tokenizer_json=tokenizer_json_path
)

input_text = "Hello, this is a test of the text-to-speech system."
audio_output = tts.tts(text=input_text,
                       num_autoregressive_samples=16,
                       temperature=0.2,
                       length_penalty=1,
                       repetition_penalty=2.0,
                       top_p=0.8,
                       max_mel_tokens=500,
                       cond_free=True,
                       cond_free_k=2,
                       diffusion_temperature=1.0,
                       diffusion_sampler="DDIM",
                       half_p=False)
# torchaudio.save expects a 2-D (channels, samples) tensor, so drop the batch dimension if present
if audio_output.ndim == 3:
    audio_output = audio_output.squeeze(0)

import torchaudio
torchaudio.save("output_audio.wav", audio_output.cpu(), sample_rate=24000)  # Adjust the sample rate if necessary

I had to pass strict=False in src\tortoise\api.py, i.e. self.autoregressive.load_state_dict(torch.load(self.autoregressive_model_path), strict=False), because loading otherwise fails with the error shown below:

Loading autoregressive model: ./training/me/finetune/models/4020_gpt.pth
Traceback (most recent call last):
  File "D:\Personal\Workspace\ai-voice-cloning\hg.py", line 21, in <module>
    tts = TorToise_TTS(
  File "D:\Personal\Workspace\ai-voice-cloning\./modules/tortoise-tts\tortoise\api.py", line 308, in __init__
    self.load_autoregressive_model(autoregressive_model_path)
  File "D:\Personal\Workspace\ai-voice-cloning\./modules/tortoise-tts\tortoise\api.py", line 391, in load_autoregressive_model
    self.autoregressive.load_state_dict(torch.load(self.autoregressive_model_path))
  File "C:\Users\administrator\anaconda3\envs\voice\lib\site-packages\torch\nn\modules\module.py", line 2153, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for UnifiedVoice:
        Unexpected key(s) in state_dict: "gpt.h.0.attn.bias", "gpt.h.0.attn.masked_bias", "gpt.h.1.attn.bias", "gpt.h.1.attn.masked_bias", "gpt.h.2.attn.bias", "gpt.h.2.attn.masked_bias", "gpt.h.3.attn.bias", "gpt.h.3.attn.masked_bias", "gpt.h.4.attn.bias", "gpt.h.4.attn.masked_bias", "gpt.h.5.attn.bias", "gpt.h.5.attn.masked_bias", "gpt.h.6.attn.bias", "gpt.h.6.attn.masked_bias", "gpt.h.7.attn.bias", "gpt.h.7.attn.masked_bias", "gpt.h.8.attn.bias", "gpt.h.8.attn.masked_bias", "gpt.h.9.attn.bias", "gpt.h.9.attn.masked_bias", "gpt.h.10.attn.bias", "gpt.h.10.attn.masked_bias", "gpt.h.11.attn.bias", "gpt.h.11.attn.masked_bias", "gpt.h.12.attn.bias", "gpt.h.12.attn.masked_bias", "gpt.h.13.attn.bias", "gpt.h.13.attn.masked_bias", "gpt.h.14.attn.bias", "gpt.h.14.attn.masked_bias", "gpt.h.15.attn.bias", "gpt.h.15.attn.masked_bias", "gpt.h.16.attn.bias", "gpt.h.16.attn.masked_bias", "gpt.h.17.attn.bias", "gpt.h.17.attn.masked_bias", "gpt.h.18.attn.bias", "gpt.h.18.attn.masked_bias", "gpt.h.19.attn.bias", "gpt.h.19.attn.masked_bias", "gpt.h.20.attn.bias", "gpt.h.20.attn.masked_bias", "gpt.h.21.attn.bias", "gpt.h.21.attn.masked_bias", "gpt.h.22.attn.bias", "gpt.h.22.attn.masked_bias", "gpt.h.23.attn.bias", "gpt.h.23.attn.masked_bias", "gpt.h.24.attn.bias", "gpt.h.24.attn.masked_bias", "gpt.h.25.attn.bias", "gpt.h.25.attn.masked_bias", "gpt.h.26.attn.bias", "gpt.h.26.attn.masked_bias", "gpt.h.27.attn.bias", "gpt.h.27.attn.masked_bias", "gpt.h.28.attn.bias", "gpt.h.28.attn.masked_bias", "gpt.h.29.attn.bias", "gpt.h.29.attn.masked_bias".
(voice) PS D:\Personal\Workspace\ai-voice-cloning>

Is there any way to load the model without strict=False?
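
One possible way to keep strict loading is sketched below. It assumes the only mismatched keys are the gpt.h.N.attn.bias and gpt.h.N.attn.masked_bias entries listed in the traceback. These look like the causal-mask buffers of the GPT-2 attention layers, which may have been saved by a different transformers version than the one installed, but that is an assumption, not something the log confirms. The helper name drop_mask_buffers is hypothetical, and the example uses a plain dict in place of a real checkpoint so that it runs without torch.

```python
import re
from collections import OrderedDict

# Matches only keys such as "gpt.h.0.attn.bias" and "gpt.h.29.attn.masked_bias".
# Real layer parameters like "gpt.h.0.attn.c_attn.bias" do not match.
_MASK_BUFFER_KEY = re.compile(r"^gpt\.h\.\d+\.attn\.(bias|masked_bias)$")


def drop_mask_buffers(state_dict):
    """Return a copy of state_dict without the attention mask buffer entries."""
    return OrderedDict(
        (key, value)
        for key, value in state_dict.items()
        if not _MASK_BUFFER_KEY.match(key)
    )


# Stand-in checkpoint: plain strings instead of tensors, for illustration only.
example = OrderedDict([
    ("gpt.h.0.attn.bias", "mask buffer"),
    ("gpt.h.0.attn.masked_bias", "mask buffer"),
    ("gpt.h.0.attn.c_attn.bias", "parameter"),
    ("gpt.h.0.attn.c_attn.weight", "parameter"),
])
filtered = drop_mask_buffers(example)
print(list(filtered))

# Hypothetical use in load_autoregressive_model (not executed here):
#   state_dict = torch.load(self.autoregressive_model_path)
#   self.autoregressive.load_state_dict(drop_mask_buffers(state_dict))
```

With the filter in place, load_state_dict stays strict, so any other missing or unexpected key would still raise an error instead of being silently ignored.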