TensorSpeech / TensorFlowTTS

:stuck_out_tongue_closed_eyes: TensorFlowTTS: Real-Time State-of-the-art Speech Synthesis for TensorFlow 2 (supports English, French, Korean, Chinese, and German, and is easy to adapt to other languages)
https://tensorspeech.github.io/TensorFlowTTS/
Apache License 2.0

[baker fastspeech2 MFA] MFA aligned durations gives a worse speech quality than tacotron2 extracted durations #461

Closed: ronggong closed this issue 3 years ago

ronggong commented 3 years ago

The issue

I modified the libritts example to use MFA-alignment-extracted durations to train a baker fastspeech2 model. It sounds much worse and noisier than the model trained with tacotron2-extracted durations. Audio samples: audios.zip

The transcriptions and target ids for training both models are exactly the same. When doing the MFA alignment, pauses are kept in the transcription: https://github.com/TensorSpeech/TensorFlowTTS/blob/cca12a3843dd3b4d5dc2608714f76d2d0d4c4255/tensorflow_tts/processor/baker.py#L33

I also did another MFA alignment with the pauses removed from the transcription, but the model trained on those durations doesn't sound good either. mfa_duration_nopauses.zip

It seems clear that the problem is in the durations extracted from the MFA alignment. Does anyone have experience solving this?

altangerelc commented 3 years ago

@ronggong What about your training samples: did they have silence at the beginning and end? I also tried importing alignments from HTS label files (in ljspeech format), but the FastSpeech2 model trained on them gives bad results. I suspect some accuracy issues converting MFA's alignments to durations, or that MFA and tacotron2 take pauses into account internally. I am coming to the conclusion that we do not need to explicitly define pauses, which I will confirm after my test trainings finish. Also, of these three outcomes, which one most resembles the initial silences of the original training samples? As you can see, mfa_duration_nopauses has a shorter silence duration at the beginning.

ronggong commented 3 years ago

@altangerelc

did they have silence at the beginning and end?

The training wavs have their beginnings trimmed by the preprocessing tool. However, I just noticed that the trailing silence is not trimmed, because the transcription ends with both "sil" and "eos", and only the "eos" is trimmed. In the tacotron2-extracted-durations case, trimming is done by librosa, which detects silence from the audio itself; in the MFA-durations case, trimming is driven by the alignment.
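As a rough illustration of the difference: energy-based trimming (the idea behind librosa.effects.trim) looks only at the waveform, not at any alignment. A minimal numpy sketch of that idea; the threshold and frame sizes here are illustrative defaults, not the repo's actual settings:

```python
import numpy as np

def trim_silence(audio, top_db=30.0, frame_length=1024, hop_length=256):
    """Drop leading/trailing frames whose energy is more than top_db
    below the signal peak (the idea behind librosa.effects.trim)."""
    # Frame-wise RMS energy.
    n_frames = max(1, 1 + (len(audio) - frame_length) // hop_length)
    rms = np.array([
        np.sqrt(np.mean(audio[i * hop_length:i * hop_length + frame_length] ** 2))
        for i in range(n_frames)
    ])
    # Energy of each frame in dB relative to the loudest frame.
    db = 20.0 * np.log10(np.maximum(rms, 1e-10) / max(rms.max(), 1e-10))
    keep = np.where(db > -top_db)[0]
    if len(keep) == 0:
        return audio[:0], 0, 0
    start = keep[0] * hop_length
    end = min(len(audio), (keep[-1] + 1) * hop_length + frame_length)
    return audio[start:end], start, end
```

Since this never consults the alignment, the audio ends and the token sequence ends can disagree after trimming, which is exactly where the two duration sources diverge.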

We can hack this part of the preprocessing to trim all the leading and trailing silence, e.g. by putting the code below in a while loop: https://github.com/TensorSpeech/TensorFlowTTS/blob/cca12a3843dd3b4d5dc2608714f76d2d0d4c4255/tensorflow_tts/bin/preprocess.py#L145-L170
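Such a loop might look roughly like this sketch. The names `sil_ids` and `eos_id` are assumptions for illustration, not the repo's actual variables, and it assumes the eos token, if present, comes last with zero duration:

```python
def trim_silence_tokens(ids, durations, mel, sil_ids, eos_id):
    """Strip every leading and trailing silence token, together with the
    mel frames its duration covers, instead of trimming only once."""
    ids, durations = list(ids), list(durations)
    # Leading silences: drop the token and its frames from the front.
    while ids and ids[0] in sil_ids:
        mel = mel[durations.pop(0):]
        ids.pop(0)
    # Temporarily remove a final eos token (assumed zero-duration).
    has_eos = bool(ids) and ids[-1] == eos_id
    if has_eos:
        eos_dur = durations.pop()
        ids.pop()
    # Trailing silences: drop the token and its frames from the back.
    while ids and ids[-1] in sil_ids:
        mel = mel[:len(mel) - durations.pop()]
        ids.pop()
    if has_eos:
        ids.append(eos_id)
        durations.append(eos_dur)
    return ids, durations, mel
```

Popping the eos token first is what lets the loop reach a silence sitting directly in front of it.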

I suspect some accuracy issues converting MFA's alignments to durations

I compared two duration examples, one extracted by MFA and one by tacotron2. They have the same total length, but the individual phone lengths differ.
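One place such per-phone differences can arise is the seconds-to-frames conversion. A sketch, with illustrative sample rate and hop size (not necessarily the baker recipe's values): rounding each interval's length independently can drift by a frame per phone, while rounding the cumulative boundaries keeps the total frame count consistent:

```python
def intervals_to_durations(intervals, sample_rate=24000, hop_size=300):
    """Convert MFA-style (start_sec, end_sec) phone intervals into
    mel-frame durations by rounding the cumulative boundaries, so
    per-phone rounding errors cannot accumulate."""
    frames_per_sec = sample_rate / hop_size
    durations, prev = [], 0
    for _, end in intervals:
        boundary = round(end * frames_per_sec)
        durations.append(boundary - prev)
        prev = boundary
    return durations
```

Tacotron2 durations come from attention counts and are integers by construction, so any float-to-frame rounding scheme on the MFA side is a potential source of mismatch.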

MFA and tacotron2 take pauses into account internally.

MFA does take pauses into account. It marks them as "sil", "sp", or "spn" phones. Since MFA is built on Kaldi, I guess these correspond to states in its FST graph.

I am coming to the conclusion that we do not need to explicitly define pauses, which I will confirm after my test trainings finish.

It's possible.

Also, of these three outcomes, which one most resembles the initial silences of the original training samples? As you can see, mfa_duration_nopauses has a shorter silence duration at the beginning.

The mfa_duration_nopauses sample is generated by a model trained with the alignment that doesn't use pause tokens. The initial silence comes from the tokens used at inference.

ronggong commented 3 years ago

@altangerelc Problem solved. The key really is to trim all the leading and trailing silence. Some MFA alignments end with "SIL ... SIL EOS", and the current preprocessing script cannot trim the sil before the eos.

altangerelc commented 3 years ago

@ronggong Thanks for the update. I got the clue from your posts for my issue #456. Yes, the issue was trimming the silences. Thanks again.