myshell-ai / OpenVoice

Instant voice cloning by MIT and MyShell.
https://research.myshell.ai/open-voice
MIT License

How do I use the weights trained on MeloTTS with the converter? #310

Open kaushal-gawri9899 opened 2 months ago

kaushal-gawri9899 commented 2 months ago

Hey, I've fine-tuned MeloTTS for an Indian accent and a few Indian languages. I wanted to use the resulting weights with the tone color converter, but realized that voice_conversion expects an averaged tensor as the source speaker embedding. How should I post-process my weights to get that average? I've also tried using the extractor to compute the embedding instead of using MeloTTS; the tone color changes, but the voice is not cloned.

dosubot[bot] commented 2 months ago

Hey @kaushal-gawri9899! I'm here to help you with any bugs, questions, or contributions you have. Let's tackle this together!

To use a MeloTTS model with the tone color converter, and to post-process embeddings into an averaged source embedding (src_se), follow these steps:

  1. Initialize the ToneColorConverter and Load Checkpoints:

    import os
    import torch
    from openvoice.api import ToneColorConverter
    
    # OpenVoice V2 converter, the version paired with MeloTTS base speakers
    ckpt_converter = 'checkpoints_v2/converter'
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    output_dir = 'outputs'
    
    tone_color_converter = ToneColorConverter(f'{ckpt_converter}/config.json', device=device)
    tone_color_converter.load_ckpt(f'{ckpt_converter}/checkpoint.pth')
    
    os.makedirs(output_dir, exist_ok=True)
  2. Obtain Tone Color Embedding for the Target Speaker:

    from openvoice import se_extractor
    
    reference_speaker = 'resources/example_reference.mp3'  # This is the voice you want to clone
    target_se, audio_name = se_extractor.get_se(reference_speaker, tone_color_converter, vad=True)
  3. Use MeloTTS as the Base Speaker and Generate Speech:

    from melo.api import TTS
    
    # One sentence per stock MeloTTS language (stock MeloTTS supports EN, ES, FR, ZH, JP and KR)
    texts = {
        'EN': "MyShell is a decentralized and comprehensive platform for discovering, creating, and staking AI-native apps.",
        'ES': "MyShell es una plataforma descentralizada y completa para descubrir, crear y apostar por aplicaciones nativas de IA.",
        'FR': "MyShell est une plateforme décentralisée et complète pour découvrir, créer et miser sur des applications natives d'IA.",
        'ZH': "MyShell是一个去中心化且全面的平台,用于发现、创建和投资AI原生应用程序。",
        'JP': "MyShellは、AIネイティブアプリの発見、作成、およびステーキングのための分散型かつ包括的なプラットフォームです。",
    }
    
    src_path = f'{output_dir}/tmp.wav'
    speed = 1.0
    
    for language, text in texts.items():
        # Stock model for this language. For a fine-tuned model, also pass
        # config_path=... and ckpt_path=...; its speakers need their own src_se.
        model = TTS(language=language, device=device)
        speaker_ids = model.hps.data.spk2id
    
        for speaker_key in speaker_ids.keys():
            speaker_id = speaker_ids[speaker_key]
            # Stock embeddings use lower-case, hyphenated names, e.g. EN_INDIA -> en-india.pth
            speaker_key = speaker_key.lower().replace('_', '-')
            source_se = torch.load(f'checkpoints_v2/base_speakers/ses/{speaker_key}.pth', map_location=device)
    
            model.tts_to_file(text, speaker_id, src_path, speed=speed)
            save_path = f'{output_dir}/output_v2_{speaker_key}.wav'
    
            # Run the tone color converter
            encode_message = "@MyShell"
            tone_color_converter.convert(
                audio_src_path=src_path,
                src_se=source_se,
                tgt_se=target_se,
                output_path=save_path,
                message=encode_message)
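  4. Post-Process Weights to Get Averaged Tensor Values for Source Embeddings:

    To average several saved source embeddings into one, you can use the following code snippet. Note that averaging the embeddings of different stock speakers gives a blend of those voices, not an embedding of your fine-tuned speaker; for that, average embeddings extracted from audio generated by your own model.

    import os
    import torch
    
    def average_tensors(tensor_list):
        # Element-wise mean; torch.stack raises if the embeddings differ in shape
        return torch.stack(tensor_list).mean(dim=0)
    
    # Stock MeloTTS source embeddings shipped with the OpenVoice V2 checkpoints
    source_embeddings_dir = 'checkpoints_v2/base_speakers/ses'
    source_embeddings = []
    
    for file_name in sorted(os.listdir(source_embeddings_dir)):
        if file_name.endswith('.pth'):
            tensor = torch.load(os.path.join(source_embeddings_dir, file_name), map_location=device)
            source_embeddings.append(tensor)
    
    averaged_tensor = average_tensors(source_embeddings)
    torch.save(averaged_tensor, 'checkpoints_v2/averaged_source_embedding.pth')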

This code will load all the source embeddings, compute their average, and save the averaged tensor [1][2][3].
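For a fine-tuned MeloTTS voice (for example, an Indian-accent speaker), there is no stock .pth file to load, so the source embedding has to come from audio that model produces. The sketch below assumes ToneColorConverter.extract_se(ref_wav_list, se_save_path=None) behaves as in current openvoice/api.py, embedding each clip and returning their mean, and that the MeloTTS model exposes tts_to_file(text, speaker_id, output_path, speed=...). build_source_se and its parameters are hypothetical names for illustration; the model and converter are passed in, so the helper itself only needs the standard library.

```python
import os


def build_source_se(tone_color_converter, melo_model, speaker_id, texts, out_dir,
                    se_save_path=None, speed=1.0):
    """Hypothetical helper: synthesize `texts` with a (fine-tuned) MeloTTS model and
    return the converter's averaged embedding of those clips, for use as `src_se`.

    Assumes the MeloTTS `tts_to_file` and OpenVoice `extract_se` signatures
    described above; neither library is imported here.
    """
    if not texts:
        raise ValueError("need at least one sentence to build a source embedding")
    os.makedirs(out_dir, exist_ok=True)
    wav_paths = []
    for i, text in enumerate(texts):
        wav_path = os.path.join(out_dir, f'src_{i}.wav')
        # Generate one clip in the base (source) voice of the fine-tuned model
        melo_model.tts_to_file(text, speaker_id, wav_path, speed=speed)
        wav_paths.append(wav_path)
    # extract_se embeds each clip and averages the per-clip embeddings
    return tone_color_converter.extract_se(wav_paths, se_save_path=se_save_path)
```

To use it, load your checkpoint with MeloTTS's TTS(language=..., device=device, config_path=..., ckpt_path=...), look up your speaker in model.hps.data.spk2id, pass a varied set of sentences, and use the returned tensor as src_se in tone_color_converter.convert(...), keeping tgt_se from se_extractor.get_se. The converter is given src_se to describe the voice in audio_src_path, so a src_se that does not match the audio actually fed in is one plausible cause of the "tone color changes but is not cloned" result.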
