JarodMica / ai-voice-cloning


Better handling of hifigan .pths and non hifigan .pths #118

Open StillTravelling opened 6 months ago

StillTravelling commented 6 months ago

Sorry, I'm hopeless with git, but I think the changes below will help when switching between hifigan and non-hifigan .pths.
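For reference, here's the filename convention the changes below assume (sketching it out; the hash suffix is tts.autoregressive_model_hash[:8] and everything lives under get_voice_dir()):

<voice>/cond_latents_<hash8>.pth          # regular vocoder latents
<voice>/hifigan_cond_latents_<hash8>.pth  # hifigan vocoder latents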

Modules\Tortoise-tts\tortoise\utils\audio.py

def load_voice(voice, extra_voice_dirs=[], load_latents=True, sample_rate=22050, device='cpu', model_hash=None, use_hifigan=False):
    if voice == 'random':
        return None, None
    print(f"hifigan = {use_hifigan}, voice={voice}")
    voices = _get_voices(dirs=[get_voice_dir()] + extra_voice_dirs, load_latents=load_latents)

    paths = voices[voice]
    mtime = 0

    latent = None
    voices = []

    for path in paths:
        filename = os.path.basename(path)
        if filename.endswith(".pth"):
            # Keep only the latent file matching the current vocoder; any
            # other .pth (the other vocoder's latents, or a stale hash) is
            # skipped so it never gets treated as an audio sample.
            prefix = "hifigan_cond_latents" if use_hifigan else "cond_latents"
            if not model_hash and filename == f"{prefix}.pth":
                latent = path
            elif model_hash and filename == f"{prefix}_{model_hash[:8]}.pth":
                latent = path
        else:
            voices.append(path)
            mtime = max(mtime, os.path.getmtime(path))

    if load_latents and latent is not None:
        print(f"Reading from latent: {latent}")
        return None, torch.load(latent, map_location=device)

    samples = []
    for path in voices:
        c = load_audio(path, sample_rate)
        samples.append(c)
    return samples, None
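Quick usage sketch for the new flag (the voice name is hypothetical, everything else follows the signature above): the same voice resolves to a different cached latent file depending on use_hifigan, so no manual file shuffling is needed when switching vocoders.

samples, latents = load_voice("myvoice", model_hash=tts.autoregressive_model_hash, use_hifigan=True)
if latents is not None:
    print("using cached hifigan latents")  # samples is None here
else:
    print(f"no matching latent file; loaded {len(samples)} clips")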

src\utils.py

def fetch_voice( voice ):
        cache_key = f'{voice}:{tts.autoregressive_model_hash[:8]}'
        if cache_key in voice_cache:
            return voice_cache[cache_key]

        print(f"Loading voice: {voice} with model {tts.autoregressive_model_hash[:8]}")
        sample_voice = None
        if voice == "microphone":
            if parameters['mic_audio'] is None:
                raise Exception("Please provide audio from mic when choosing `microphone` as a voice input")
            voice_samples, conditioning_latents = [load_audio(parameters['mic_audio'], tts.input_sample_rate)], None
        elif voice == "random":
            voice_samples, conditioning_latents = None, tts.get_random_conditioning_latents()
        else:
            if progress is not None:
                notify_progress(f"Loading voice: {voice}", progress=progress)

            voice_samples, conditioning_latents = load_voice(voice, model_hash=tts.autoregressive_model_hash, use_hifigan=args.use_hifigan)

        if voice_samples and len(voice_samples) > 0:
            if conditioning_latents is None:
                conditioning_latents = compute_latents(voice=voice, voice_samples=voice_samples, voice_latents_chunks=parameters['voice_latents_chunks'])

            sample_voice = torch.cat(voice_samples, dim=-1).squeeze().cpu()
            voice_samples = None

        voice_cache[cache_key] = (voice_samples, conditioning_latents, sample_voice)
        return voice_cache[cache_key]

###     
    def get_info( voice, settings = None, latents = True ):
        info = {}
        info.update(parameters)

        info['time'] = time.time()-full_start_time
        info['datetime'] = datetime.now().isoformat()

        info['model'] = tts.autoregressive_model_path
        info['model_hash'] = tts.autoregressive_model_hash 

        info['progress'] = None
        del info['progress']

        if info['delimiter'] == "\n":
            info['delimiter'] = "\\n"

        if settings is not None:
            for k in settings:
                if k in info:
                    info[k] = settings[k]

            if 'half_p' in settings and 'cond_free' in settings:
                info['experimentals'] = []
                if settings['half_p']:
                    info['experimentals'].append("Half Precision")
                if settings['cond_free']:
                    info['experimentals'].append("Conditioning-Free")

        if latents and "latents" not in info:
            voice = info['voice']
            model_hash = settings["model_hash"][:8] if settings is not None and "model_hash" in settings else tts.autoregressive_model_hash[:8]

            dir = f'{get_voice_dir()}/{voice}/'
        # Match the naming compute_latents() uses: hifigan latents carry
        # their own prefix so the two flavors never collide.
        if args.use_hifigan:
            latents_path = f'{dir}/hifigan_cond_latents_{model_hash}.pth'
        else:
            latents_path = f'{dir}/cond_latents_{model_hash}.pth'

        if voice == "random" or voice == "microphone":
            if latents and settings is not None:
                # hifigan latents are a tensor, so torch.any() is needed;
                # the regular path stores a tuple, where truthiness suffices.
                if args.use_hifigan:
                    has_latents = bool(torch.any(settings['conditioning_latents']))
                else:
                    has_latents = bool(settings['conditioning_latents'])
                if has_latents:
                    os.makedirs(dir, exist_ok=True)
                    torch.save(conditioning_latents, latents_path)

            if latents_path and os.path.exists(latents_path):
                try:
                    with open(latents_path, 'rb') as f:
                        info['latents'] = base64.b64encode(f.read()).decode("ascii")
                except Exception as e:
                    pass

        return info     
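Since get_info() stores the latents as base64 of the raw .pth bytes, they can be recovered from a saved info dict like so (a sketch, using only the standard library plus torch):

import base64, io

raw = base64.b64decode(info['latents'])
conditioning_latents = torch.load(io.BytesIO(raw), map_location='cpu')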

###
        settings = get_settings( override=override )
        #print(settings) #This line changed to comment out
        try:
            if args.use_hifigan:
                gen = tts.tts(cut_text, **settings)
            else:
                gen, additionals = tts.tts(cut_text, **settings )
                parameters['seed'] = additionals[0]
        except Exception as e:
            raise RuntimeError(f'Possible latent mismatch: click the "(Re)Compute Voice Latents" button and then try again. Error: {e}')
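Because the two paths return different shapes (hifigan's tts() gives just the audio, while the regular path also returns extras whose first element is the seed), a small wrapper could keep callers uniform. This helper is hypothetical, not part of the repo:

def tts_generate(cut_text, use_hifigan, **settings):
    # Normalizes both return shapes to (gen, seed); seed is None
    # on the hifigan path, mirroring the branch above.
    if use_hifigan:
        return tts.tts(cut_text, **settings), None
    gen, additionals = tts.tts(cut_text, **settings)
    return gen, additionals[0]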

###
def compute_latents(voice=None, voice_samples=None, voice_latents_chunks=0, original_ar=False, original_diffusion=False):
    global tts
    global args

    unload_whisper()
    unload_voicefixer()

    if not tts:
        if tts_loading:
            raise Exception("TTS is still initializing...")
        load_tts()

    if hasattr(tts, "loading") and tts.loading:
        raise Exception("TTS is still initializing...")

    if args.tts_backend == "bark":
        tts.create_voice( voice )
        return

    if args.autoregressive_model == "auto":
        tts.load_autoregressive_model(deduce_autoregressive_model(voice))

    if voice:
        load_from_dataset = voice_latents_chunks == 0

        if load_from_dataset:
            dataset_path = f'./training/{voice}/train.txt'
            if not os.path.exists(dataset_path):
                load_from_dataset = False
            else:
                with open(dataset_path, 'r', encoding="utf-8") as f:
                    lines = f.readlines()

                print("Leveraging dataset for computing latents")

                voice_samples = []
                max_length = 0
                for line in lines:
                    filename = f'./training/{voice}/{line.split("|")[0]}'

                    waveform = load_audio(filename, 22050)
                    max_length = max(max_length, waveform.shape[-1])
                    voice_samples.append(waveform)

                for i in range(len(voice_samples)):
                    voice_samples[i] = pad_or_truncate(voice_samples[i], max_length)

                voice_latents_chunks = len(voice_samples)
                if voice_latents_chunks == 0:
                    print("Dataset is empty!")
                    # Fall back to loading from the voice folder below
                    # instead of continuing with zero samples.
                    load_from_dataset = False
        if not load_from_dataset:
            voice_samples, _ = load_voice(voice, load_latents=False, use_hifigan=args.use_hifigan) #This line changed

    if voice_samples is None:
        return

    conditioning_latents = tts.get_conditioning_latents(voice_samples, return_mels=not args.latents_lean_and_mean, slices=voice_latents_chunks, force_cpu=args.force_cpu_for_conditioning_latents, original_ar=original_ar, original_diffusion=original_diffusion)

    if len(conditioning_latents) == 4:
        conditioning_latents = (conditioning_latents[0], conditioning_latents[1], conditioning_latents[2], None)
    if args.use_hifigan: #newsection
        outfile = f'{get_voice_dir()}/{voice}/hifigan_cond_latents_{tts.autoregressive_model_hash[:8]}.pth'
    else:
        outfile = f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth' #end newsection
    torch.save(conditioning_latents, outfile)
    print(f'Saved voice latents: {outfile}')

    return conditioning_latents
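One sanity check worth doing: the filename compute_latents() writes has to be exactly what load_voice() later looks for, per flag. A throwaway assertion (hypothetical, reusing the names above) makes any mismatch obvious:

prefix = 'hifigan_cond_latents' if args.use_hifigan else 'cond_latents'
expected = f'{get_voice_dir()}/{voice}/{prefix}_{tts.autoregressive_model_hash[:8]}.pth'
assert os.path.exists(expected), f'latents were not written: {expected}'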

###
def reload_tts():
    unload_tts()
    load_tts()

def change_hifigan(newvalue=True): #newsection
    args.use_hifigan = newvalue
    save_args_settings()
    do_gc()
    reload_tts()
    return args.use_hifigan

def get_hifigan():
    return args.use_hifigan  #endnewsection

src\webui.py

                EXEC_SETTINGS['autoregressive_model'].change(
                    fn=update_autoregressive_model,
                    inputs=EXEC_SETTINGS['autoregressive_model'],
                    outputs=None,
                    api_name="set_autoregressive_model"
                )

                EXEC_SETTINGS['use_hifigan'].change( #newsection
                    fn=change_hifigan,
                    inputs=EXEC_SETTINGS['use_hifigan'],
                    outputs=EXEC_SETTINGS['use_hifigan'],
                    api_name="use_hifigan"
                )

                EXEC_SETTINGS['use_hifigan'].select(
                    fn=get_hifigan,
                    outputs=EXEC_SETTINGS['use_hifigan'],
                    api_name="get_hifigan"
                ) #endnewsection
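Since both handlers register api_name endpoints, the flag can also be flipped remotely, e.g. with gradio_client (a sketch, assuming the webui is up on the default local port):

from gradio_client import Client

client = Client("http://127.0.0.1:7860")
client.predict(True, api_name="/use_hifigan")  # sets the flag, saves settings, reloads TTS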