TensorSpeech / TensorFlowTTS

TensorFlowTTS: Real-Time State-of-the-Art Speech Synthesis for TensorFlow 2 (supports English, French, Korean, Chinese and German, and is easy to adapt to other languages)
https://tensorspeech.github.io/TensorFlowTTS/
Apache License 2.0

Losses are not correct #688

Closed: iamanigeeit closed this issue 2 years ago

iamanigeeit commented 3 years ago

There is a problem with calculate_3d_loss in tensorflow_tts.utils.strategy.

The mel ground truth (y_gt) should not be truncated when the mel prediction (y_pred) is shorter; a shorter prediction should be penalized instead. One way to do this is to pad the prediction to the ground-truth length.
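
For illustration, a minimal sketch of the padding idea (this is a hypothetical helper, not the repository's calculate_3d_loss; tensor names are assumptions):

import tensorflow as tf

def pad_pred_to_gt(y_gt, y_pred):
    # Make y_pred match y_gt along the time axis so that missing frames
    # are compared against the ground truth instead of being dropped.
    gt_len = tf.shape(y_gt)[1]
    pred_len = tf.shape(y_pred)[1]
    y_pred = y_pred[:, :gt_len, :]  # trim if the prediction is longer
    pad_len = tf.maximum(gt_len - pred_len, 0)
    return tf.pad(y_pred, [[0, 0], [0, pad_len], [0, 0]])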

In practice, a shorter prediction rarely happens, because stop_token_loss is wrongly set up and usually causes the model to produce output longer than the ground truth. This is also due to the truncated y_pred in calculate_2d_loss. Consider the following, where max_mel_length = 3:

stop_token_predictions = [-20, -20, -20, -20, -20, -20, 5, 5, 5, 5]
stop_gts = [0, 0, 0]

Truncating stop_token_predictions will make the loss close to 0, even though the stop token prediction is totally wrong (it should stop after 6 mel slices, not 3). To make it right, stop_gts should be padded with 1s.
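
To make this concrete, here is a minimal sketch using the numbers above, assuming the stop-token loss is sigmoid cross-entropy on logits (a typical Tacotron 2 formulation; the exact loss used in the repo may differ):

import tensorflow as tf

stop_token_predictions = tf.constant(
    [-20.0, -20.0, -20.0, -20.0, -20.0, -20.0, 5.0, 5.0, 5.0, 5.0])
stop_gts = tf.constant([0.0, 0.0, 0.0])  # max_mel_length = 3

# Current behaviour: truncate the predictions to the ground-truth length.
truncated_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    labels=stop_gts, logits=stop_token_predictions[:3]))
# ~2e-9: the wrong stop position (frame 6 instead of 3) is never penalized.

# Proposed behaviour: pad stop_gts with 1s up to the prediction length.
padded_gts = tf.concat([stop_gts, tf.ones([7])], axis=0)
padded_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    labels=padded_gts, logits=stop_token_predictions))
# ~6.0: the model is now penalized for not stopping after 3 frames.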

There also needs to be masking in the loss functions, and the training can be a lot faster if we use bucket_by_sequence_length for batching the dataset. I'm currently implementing these.
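
For context, a minimal sketch of bucketed batching with tf.data.experimental.bucket_by_sequence_length; the (mel, mel_length) element layout, the bucket boundaries and the 80-bin mel shape are assumptions for illustration, not the repository's actual data loader:

import tensorflow as tf

def bucket_mel_dataset(dataset, batch_size=32):
    # dataset is assumed to yield (mel, mel_length) pairs:
    # mel is [time, 80] float32, mel_length is a scalar int32.
    return dataset.apply(tf.data.experimental.bucket_by_sequence_length(
        element_length_func=lambda mel, mel_length: mel_length,
        bucket_boundaries=[200, 400, 600, 800],   # frame-count cut-offs
        bucket_batch_sizes=[batch_size] * 5,      # len(boundaries) + 1 entries
        padded_shapes=([None, 80], []),           # pad mels along time only
        drop_remainder=True,
    ))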

dathudeptrai commented 3 years ago

@iamanigeeit the lengths of y_gt and y_pred are always equal. Please read this (https://github.com/TensorSpeech/TensorFlowTTS/pull/455). len(y_gt) > len(y_pred) only happens when you are using multi-GPU.

iamanigeeit commented 3 years ago

> @iamanigeeit the lengths of y_gt and y_pred are always equal. Please read this (#455). len(y_gt) > len(y_pred) only happens when you are using multi-GPU.

Thanks for the quick reply! But I am using Tacotron 2 and there is no duration anywhere.

dathudeptrai commented 3 years ago

@iamanigeeit hi, Tacotron 2 uses teacher forcing, so the lengths of y_gt and y_pred are also equal :D

iamanigeeit commented 3 years ago

@dathudeptrai Since Tacotron 2 uses teacher forcing, the model does not seem to learn the correct duration... I'm getting increasing stop_token_loss in eval, and it's worse if I batch with bucket_by_sequence_length, because a shorter sequence length means there are fewer 1s to correct the stop token position. I'll update if I find a solution.

For the mel loss, maybe you can consider this?

import tensorflow as tf

def calc_mel_loss(mel_gts, mel_outputs, mel_gts_length):
    num_mels = mel_gts.shape[-1]
    max_gts_length = mel_gts.shape[1]
    max_output_length = mel_outputs.shape[1]

    if max_output_length is not None:
        # Force mel_outputs to max_gts_length along the time axis.
        if max_output_length > max_gts_length:
            # Prediction is longer: trim the extra frames.
            mel_outputs = tf.slice(mel_outputs, [0, 0, 0], [-1, max_gts_length, -1])

        elif max_gts_length > max_output_length:
            # Prediction is shorter: zero-pad so the missing frames are penalized.
            pad_length = max_gts_length - max_output_length
            mel_outputs = tf.pad(mel_outputs, paddings=[[0, 0], [0, pad_length], [0, 0]])

    # Mask out the padded ground-truth frames, then take the mean L1 loss
    # per utterance, normalized by the true length and the number of mel bins.
    mask = tf.sequence_mask(mel_gts_length, maxlen=max_gts_length, dtype=tf.float32)
    loss = tf.abs(mel_gts - mel_outputs) * tf.expand_dims(mask, axis=-1)
    loss = tf.reduce_sum(loss, axis=[1, 2]) / tf.cast(
        mel_gts_length, dtype=tf.float32) / tf.cast(
        num_mels, dtype=tf.float32)

    return loss
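
For what it's worth, a quick sanity check of the function above on dummy tensors (shapes here are illustrative only):

# assumes `import tensorflow as tf` and calc_mel_loss from above
mel_gts = tf.random.normal([2, 10, 80])       # [batch, max_gts_length, num_mels]
mel_outputs = tf.random.normal([2, 8, 80])    # prediction is 2 frames short
mel_gts_length = tf.constant([10, 7])         # true frame counts per utterance
loss = calc_mel_loss(mel_gts, mel_outputs, mel_gts_length)  # shape [2], per-utterance L1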

iamanigeeit commented 3 years ago

@dathudeptrai @ZDisket I have added preprocessors for the Emotion Speech Dataset and the VCTK Corpus, as well as modified Tacotron 2 to accept speaker embeddings plus prosody/emotion embeddings according to this paper, and alternative GMM attention mechanisms from here.

I can't solve the problem of the stop token loss increasing in the baseline multi-speaker Tacotron 2. At inference time, this causes the RNN to not terminate about half the time. [screenshot attached]

Some things I've tried:

Hence, I am moving to Coqui-TTS instead. If you would like to review the work I've done, I can submit a pull request. Thanks!

stale[bot] commented 2 years ago

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs.