coqui-ai / Trainer

🐸 - A general purpose model trainer, as flexible as it gets

[Bug] Outputs referenced before assignment error while using YourTTS recipe #96

Closed: Ca-ressemble-a-du-fake closed this issue 1 year ago

Ca-ressemble-a-du-fake commented 1 year ago

Describe the bug

Hi,

When running the YourTTS recipe on my own LJSpeech-format dataset, I get the following error during the first evaluation:

> DataLoader initialization
| > Tokenizer:
        | > add_blank: True
        | > use_eos_bos: False
        | > use_phonemes: False
| > Number of instances : 20
 | > Preprocessing samples
 | > Max text length: 119
 | > Min text length: 21
 | > Avg text length: 45.2
 |
 | > Max audio length: 119952.0
 | > Min audio length: 18103.0
 | > Avg audio length: 48773.45
 | > Num. instances discarded samples: 0
 | > Batch group size: 0.
 > Using weighted sampler for attribute 'speaker_name' with alpha '1.0'
None
 > Attribute weights for '['ljspeech']'
 | > [0.22360679774997894]

 > EVALUATION 

 ! Run is kept in /home/caraduf/Models/YourTTS_ME_22k-February-07-2023_06+31AM-0000000
Traceback (most recent call last):
  File "/home/caraduf/CoquiTTS/Trainer/trainer/trainer.py", line 1659, in fit
    self._fit()
  File "/home/caraduf/CoquiTTS/Trainer/trainer/trainer.py", line 1614, in _fit
    self.eval_epoch()
  File "/home/caraduf/CoquiTTS/Trainer/trainer/trainer.py", line 1501, in eval_epoch
    outputs,
UnboundLocalError: local variable 'outputs' referenced before assignment

I updated the Trainer to the latest version following the GitHub instructions, but the issue still occurs.

Note also that training a VITS model on the same dataset (with the same maximum audio length of 10 seconds, i.e. 10 x 22050 samples) works. The failure only happens with the YourTTS recipe. I will try with debug mode on and see whether it reveals anything useful.
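The lengths in the DataLoader log are given in samples, which makes them hard to compare with the 10-second cap. A small sketch converting them to seconds at 22050 Hz (values copied from the log and the recipe; `to_seconds` is just an illustrative helper) shows that every evaluation clip is well under the cap, consistent with the log's "Num. instances discarded samples: 0".

```python
SAMPLE_RATE = 22050  # Hz, as set in the recipe
MAX_AUDIO_LEN_IN_SECONDS = 10  # as set in the recipe

# The recipe passes the cap to VitsConfig(max_audio_len=...) in samples, not seconds.
max_audio_len = SAMPLE_RATE * MAX_AUDIO_LEN_IN_SECONDS

# Lengths in samples, copied from the evaluation DataLoader log.
logged_lengths = {"max": 119952.0, "min": 18103.0, "avg": 48773.45}


def to_seconds(num_samples: float, sample_rate: int = SAMPLE_RATE) -> float:
    """Convert a length in samples to seconds (illustrative helper)."""
    return num_samples / sample_rate


for name, samples in logged_lengths.items():
    print(f"{name}: {samples:.0f} samples = {to_seconds(samples):.2f} s")
print(f"cap: {max_audio_len} samples = {to_seconds(max_audio_len):.0f} s")

# No logged length reaches the cap, which is consistent with the log
# reporting that no samples were discarded.
assert all(samples < max_audio_len for samples in logged_lengths.values())
```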

Here is the adapted recipe:

import os

import torch
from trainer import Trainer, TrainerArgs

from TTS.bin.compute_embeddings import compute_embeddings
from TTS.bin.resample import resample_files
from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.vits import CharactersConfig, Vits, VitsArgs, VitsAudioConfig
from TTS.utils.downloaders import download_vctk

torch.set_num_threads(24)

# pylint: disable=W0105
"""
    This recipe replicates the first experiment proposed in the YourTTS paper (https://arxiv.org/abs/2112.02418).
    YourTTS model is based on the VITS model however it uses external speaker embeddings extracted from a pre-trained speaker encoder and has small architecture changes.
    In addition, YourTTS can be trained on multilingual data; however, this recipe replicates the single-language training using the VCTK dataset.
    If you are interested in multilingual training, we have left commented-out parameters in the VitsArgs class instance that should be enabled for multilingual training.
    In addition, you will need to add the extra datasets, following the VCTK config as an example.
"""
# Path where you want to save the model outputs (configs, checkpoints and tensorboard logs)
OUT_PATH = os.path.dirname(os.path.abspath(__file__))  # "/raid/coqui/Checkpoints/original-YourTTS/"

# Name of the run for the Trainer
RUN_NAME = "YourTTS_ME_22k"

Me_Rec_1_config = BaseDatasetConfig(
    formatter="ljspeech", dataset_name="ME_Rec1", meta_file_train="metadata.csv", path="/home/caraduf/Datasets/ME_22kHz/Rec_1_LARGE_V2_22.05kHz_dataset", language="fr-fr"
)

Me_Rec_2_config = BaseDatasetConfig(
    formatter="ljspeech", dataset_name="ME_Rec2", meta_file_train="metadata.csv", path="/home/caraduf/Datasets/ME_22kHz/Rec_2_LARGE_V2_22.05kHz_dataset", language="fr-fr"
)

# Add all dataset configs here; in our case we only want to train on the VCTK dataset, so we add just VCTK. Note: if you want to add new datasets, just add them here and the speaker embeddings (d-vectors) will be computed automatically for each new dataset :)
DATASETS_CONFIG_LIST = [
    Me_Rec_1_config,
    Me_Rec_2_config
]

# If you want to do transfer learning and speed up training, you can set the path to the original YourTTS model here
RESTORE_PATH = None  # "/root/.local/share/tts/tts_models--multilingual--multi-dataset--your_tts/model_file.pth"

# This parameter is useful for debugging: it skips the training epochs and only runs the evaluation and produces the test sentences
SKIP_TRAIN_EPOCH = False

# Set here the batch size to be used in training and evaluation
BATCH_SIZE = 32

# Training sampling rate, and the target sampling rate for resampling the downloaded dataset (note: if you change this, you might need to re-download the dataset!)
# Note: if you add new datasets, make sure their sampling rate matches this parameter; otherwise, resample your audio
SAMPLE_RATE = 22050

# Max audio length in seconds to be used in training (any clip longer than this is ignored)
MAX_AUDIO_LEN_IN_SECONDS = 10

# # Define the number of threads used during the audio resampling
# NUM_RESAMPLE_THREADS = 10
# # Check if VCTK dataset is not already downloaded, if not download it
# if not os.path.exists(VCTK_DOWNLOAD_PATH):
#     print(">>> Downloading VCTK dataset:")
#     download_vctk(VCTK_DOWNLOAD_PATH)
#     resample_files(VCTK_DOWNLOAD_PATH, SAMPLE_RATE, file_ext="flac", n_jobs=NUM_RESAMPLE_THREADS)

# # init configs
# vctk_config = BaseDatasetConfig(
#     formatter="vctk",
#     dataset_name="vctk",
#     meta_file_train="",
#     meta_file_val="",
#     path=VCTK_DOWNLOAD_PATH,
#     language="en",
#     ignored_speakers=[
#         "p261",
#         "p225",
#         "p294",
#         "p347",
#         "p238",
#         "p234",
#         "p248",
#         "p335",
#         "p245",
#         "p326",
#         "p302",
#     ],  # Ignore the test speakers to full replicate the paper experiment
# )

### Extract speaker embeddings
SPEAKER_ENCODER_CHECKPOINT_PATH = (
    "https://github.com/coqui-ai/TTS/releases/download/speaker_encoder_model/model_se.pth.tar"
)
SPEAKER_ENCODER_CONFIG_PATH = "https://github.com/coqui-ai/TTS/releases/download/speaker_encoder_model/config_se.json"

D_VECTOR_FILES = []  # List of speaker embeddings/d-vectors to be used during the training

# Iterate over all dataset configs, checking whether the speaker embeddings have already been computed; if not, compute them
for dataset_conf in DATASETS_CONFIG_LIST:
    # Compute the embeddings only if they have not been computed yet
    embeddings_file = os.path.join(dataset_conf.path, "speakers.pth")
    if not os.path.isfile(embeddings_file):
        print(f">>> Computing the speaker embeddings for the {dataset_conf.dataset_name} dataset")
        compute_embeddings(
            SPEAKER_ENCODER_CHECKPOINT_PATH,
            SPEAKER_ENCODER_CONFIG_PATH,
            embeddings_file,
            old_spakers_file=None,
            config_dataset_path=None,
            formatter_name=dataset_conf.formatter,
            dataset_name=dataset_conf.dataset_name,
            dataset_path=dataset_conf.path,
            meta_file_train=dataset_conf.meta_file_train,
            meta_file_val=dataset_conf.meta_file_val,
            disable_cuda=False,
            no_eval=False,
        )
    D_VECTOR_FILES.append(embeddings_file)

# Audio config used in training.
audio_config = VitsAudioConfig(
    sample_rate=SAMPLE_RATE,
    hop_length=256,
    win_length=1024,
    fft_size=1024,
    mel_fmin=0.0,
    mel_fmax=None,
    num_mels=80,
)

# Init VitsArgs, setting the arguments needed for the YourTTS model
model_args = VitsArgs(
    d_vector_file=D_VECTOR_FILES,
    use_d_vector_file=True,
    d_vector_dim=512,
    num_layers_text_encoder=10,
    speaker_encoder_model_path=SPEAKER_ENCODER_CHECKPOINT_PATH,
    speaker_encoder_config_path=SPEAKER_ENCODER_CONFIG_PATH,
    resblock_type_decoder="1",  # In the paper, we accidentally trained YourTTS using ResNet blocks type 2; if you like, you can use ResNet blocks type 1 like the VITS model
    # Useful parameter to enable the Speaker Consistency Loss (SCL) described in the paper
    # use_speaker_encoder_as_loss=True,
    # Useful parameters to enable multilingual training
    # use_language_embedding=True,
    # embedded_language_dim=4,
)

# General training config; here you can change the batch size and other useful parameters
config = VitsConfig(
    output_path=OUT_PATH,
    model_args=model_args,
    run_name=RUN_NAME,
    project_name="YourTTS",
    run_description="""
            - Original YourTTS trained using shorter extracts made with a new method
        """,
    dashboard_logger="tensorboard",
    logger_uri=None,
    audio=audio_config,
    batch_size=BATCH_SIZE,
    batch_group_size=48,
    eval_batch_size=BATCH_SIZE,
    num_loader_workers=4,
    eval_split_max_size=256,
    print_step=50,
    plot_step=100,
    log_model_step=1000,
    save_step=5000,
    save_n_checkpoints=10,
    save_checkpoints=True,
    target_loss="loss_1",
    print_eval=False,
    use_phonemes=False,
    phonemizer="espeak",
    phoneme_language="fr-fr",
    compute_input_seq_cache=True,
    add_blank=True,
    text_cleaner="multilingual_cleaners",
    characters=CharactersConfig(
        characters_class="TTS.tts.models.vits.VitsCharacters",
        pad="_",
        eos="&",
        bos="*",
        blank=None,
        #characters="ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz\u00af\u00b7\u00df\u00e0\u00e1\u00e2\u00e3\u00e4\u00e6\u00e7\u00e8\u00e9\u00ea\u00eb\u00ec\u00ed\u00ee\u00ef\u00f1\u00f2\u00f3\u00f4\u00f5\u00f6\u00f9\u00fa\u00fb\u00fc\u00ff\u0101\u0105\u0107\u0113\u0119\u011b\u012b\u0131\u0142\u0144\u014d\u0151\u0153\u015b\u016b\u0171\u017a\u017c\u01ce\u01d0\u01d2\u01d4\u0430\u0431\u0432\u0433\u0434\u0435\u0436\u0437\u0438\u0439\u043a\u043b\u043c\u043d\u043e\u043f\u0440\u0441\u0442\u0443\u0444\u0445\u0446\u0447\u0448\u0449\u044a\u044b\u044c\u044d\u044e\u044f\u0451\u0454\u0456\u0457\u0491\u2013!'(),-.:;? ",

        characters="!',-.:?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz «»ÀÇÉÊàâçèéêëîïôùûœ–’…",
        punctuations="!'(),-.:;? ",
        phonemes="",
        is_unique=True,
        is_sorted=True,
    ),
    phoneme_cache_path=None,
    precompute_num_workers=12,
    start_by_longest=True,
    datasets=DATASETS_CONFIG_LIST,
    cudnn_benchmark=False,
    max_audio_len=SAMPLE_RATE * MAX_AUDIO_LEN_IN_SECONDS,
    mixed_precision=True,
    test_sentences=[
        [
            "Il m'a fallu du temps pour créer cette voix alors ma bouche ne restera pas fermée.",
            # "ME",
            None,
            "fr_FR",
        ],
        [
            "Il m'a fallu beaucoup de temps pour développer une voix, et maintenant que je l'ai, je ne vais pas me taire.",
            # "ME",
            None,
            "fr_FR",
        ],
        [
            "Mais son âge rendait cette dernière qualité plus saillante.",
            # "ME",
            None,
            "fr_FR",
        ],
        # [
        #     "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
        #     "VCTK_p277",
        #     None,
        #     "en",
        # ],
        # [
        #     "Be a voice, not an echo.",
        #     "VCTK_p239",
        #     None,
        #     "en",
        # ],
        # [
        #     "I'm sorry Dave. I'm afraid I can't do that.",
        #     "VCTK_p258",
        #     None,
        #     "en",
        # ],
        # [
        #     "This cake is great. It's so delicious and moist.",
        #     "VCTK_p244",
        #     None,
        #     "en",
        # ],
        # [
        #     "Prior to November 22, 1963.",
        #     "VCTK_p305",
        #     None,
        #     "en",
        # ],
    ],
    # Enable the weighted sampler
    use_weighted_sampler=True,
    # Ensures that all speakers are seen in the training batch equally no matter how many samples each speaker has
    weighted_sampler_attrs={"speaker_name": 1.0},
    weighted_sampler_multipliers={},
    # Set the Speaker Consistency Loss (SCL) α to 9, as in the paper
    speaker_encoder_loss_alpha=9.0,
)

# Load all dataset samples and split them into training and evaluation sets
train_samples, eval_samples = load_tts_samples(
    config.datasets,
    eval_split=True,
    eval_split_max_size=config.eval_split_max_size,
    eval_split_size=config.eval_split_size,
)

# Init the model
model = Vits.init_from_config(config)

# Init the trainer and 🚀
trainer = Trainer(
    TrainerArgs(restore_path=RESTORE_PATH, skip_train_epoch=SKIP_TRAIN_EPOCH),
    config,
    output_path=OUT_PATH,
    model=model,
    train_samples=train_samples,
    eval_samples=eval_samples,
)
trainer.fit()

To Reproduce

Create a French dataset in LJSpeech format (22.05 kHz audio). Adapt the dataset config and sample rate in the provided recipe, then launch it:

python3 YourTTS_recipe.py

Expected behavior

The training should go on.

Logs

No response

Environment

{
    "CUDA": {
        "GPU": [
            "NVIDIA GeForce RTX 3090"
        ],
        "available": true,
        "version": "11.7"
    },
    "Packages": {
        "PyTorch_debug": false,
        "PyTorch_version": "1.13.1+cu117",
        "Trainer": "v0.0.22",
        "numpy": "1.22.4"
    },
    "System": {
        "OS": "Linux",
        "architecture": [
            "64bit",
            "ELF"
        ],
        "processor": "x86_64",
        "python": "3.10.6",
        "version": "#64-Ubuntu SMP Thu Jan 5 11:43:13 UTC 2023"
    }
}

Additional context

No response

Ca-ressemble-a-du-fake commented 1 year ago

I was able to fix this by doing the following:

Now the training is working, so I believe this recipe must be run on a multi-speaker dataset. Single-speaker datasets may not be supported.

It may also work with 22 kHz audio, but I have not tested it, since my multi-speaker datasets are all 16 kHz and my only 22 kHz dataset is single-speaker.

chankl3579 commented 1 year ago

I tried to fine-tune the YourTTS model on my own small dataset and hit the same error as in this issue. My dataset contains 256 audio clips in LJSpeech format. Thanks for the suggested solutions above, but since I would rather not modify the source code, I looked into the problem a bit more.

To get straight to the point: I don't think this error is about multi-speaker versus single-speaker data; it occurs when the dataset is relatively small. I trained from scratch on the full LJSpeech-1.1 dataset and the error did not occur, so the single-speaker format is not the problem.

I then made a subset of LJSpeech containing only the first 1024 clips and trained from scratch again, and the error was reproduced. The Python traceback shows that the error occurs during the evaluation stage of training. By default, the evaluation split proportion is 0.01, so here the evaluation set has 1024 * 0.01 = 10.24, i.e. 10 samples, which is smaller than the default batch size of 32. Explicitly setting eval_split_size=32 solves the problem. Note also that the same problem appears whenever samples are discarded because of MAX_AUDIO_LEN_IN_SECONDS and the remaining evaluation set is smaller than the batch size.
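The split arithmetic described in this comment can be sketched as follows. The helpers `eval_set_size` and `full_batches` are hypothetical; they assume, as the comment describes, that a float eval_split_size below 1 is a proportion of the dataset, that an integer is an absolute sample count, that eval_split_max_size caps the result, and that an incomplete final batch is dropped. That last point is an assumption consistent with the reported symptom, not something confirmed in this thread.

```python
from typing import Optional, Union


def eval_set_size(
    num_items: int,
    eval_split_size: Union[float, int],
    eval_split_max_size: Optional[int] = None,
) -> int:
    """Hypothetical helper: number of samples held out for evaluation.

    Assumes a float below 1 is a proportion of the dataset and anything else
    is an absolute count, optionally capped by eval_split_max_size.
    """
    if isinstance(eval_split_size, float) and eval_split_size < 1:
        size = int(num_items * eval_split_size)  # truncates, e.g. 10.24 -> 10
    else:
        size = int(eval_split_size)
    if eval_split_max_size is not None:
        size = min(size, eval_split_max_size)
    return size


def full_batches(num_samples: int, batch_size: int) -> int:
    """Number of batches when an incomplete final batch is dropped."""
    return num_samples // batch_size


BATCH_SIZE = 32
for label, split in [("default 0.01", 0.01), ("eval_split_size=32", 32)]:
    n_eval = eval_set_size(1024, split, eval_split_max_size=256)
    print(f"{label}: {n_eval} eval samples -> {full_batches(n_eval, BATCH_SIZE)} full batches")

# The original report's log shows 20 evaluation instances with BATCH_SIZE = 32.
print(f"original report: {full_batches(20, BATCH_SIZE)} full batches")
```

Under these assumptions, both the 1024-clip subset and the original report (20 evaluation instances against BATCH_SIZE = 32) end up with zero full evaluation batches, while eval_split_size=32 yields exactly one.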

To conclude, this bug occurs when the actual evaluation set is smaller than one batch. The train/eval split proportion, discarded samples, and inappropriate hyperparameters (such as an inconsistency between BATCH_SIZE and eval_split_max_size) can all cause it.
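Why a too-small evaluation set surfaces as an UnboundLocalError, rather than a clearer message, can be illustrated with a simplified loop. `eval_epoch_sketch` is a hypothetical stand-in, not the Trainer's actual eval_epoch: it only shows that a variable assigned inside a for loop is never bound when the loader yields zero batches. `check_eval_batches` is likewise a hypothetical pre-flight guard, not part of Trainer or TTS.

```python
from typing import Iterable, List


def eval_epoch_sketch(batches: Iterable[List[int]]) -> List[int]:
    """Hypothetical, simplified eval loop (not the Trainer's real code).

    `outputs` is assigned only inside the loop body, so if the loader yields
    no batches, the return statement raises UnboundLocalError.
    """
    for batch in batches:
        outputs = [x * 2 for x in batch]  # stand-in for a model forward pass
    return outputs


def check_eval_batches(num_eval_samples: int, eval_batch_size: int) -> None:
    """Hypothetical guard: fail early with an explicit message instead."""
    if num_eval_samples < eval_batch_size:
        raise ValueError(
            f"evaluation set has {num_eval_samples} samples but eval_batch_size is "
            f"{eval_batch_size}; raise eval_split_size or lower eval_batch_size"
        )


print(eval_epoch_sketch([[1, 2], [3]]))  # returns the last batch's outputs

try:
    eval_epoch_sketch([])  # zero batches, like 20 samples with batch size 32
except UnboundLocalError as err:
    print("reproduced:", err)

try:
    check_eval_batches(num_eval_samples=20, eval_batch_size=32)
except ValueError as err:
    print("guard:", err)
```

In the recipe, such a guard could be called as check_eval_batches(len(eval_samples), config.eval_batch_size) just before trainer.fit(). It only sees the samples returned by load_tts_samples, before any length filtering, so a passing check does not rule out the failure when clips are discarded later.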