NVIDIA / DeepLearningExamples

State-of-the-Art Deep Learning scripts organized by models - easy to train and deploy with reproducible accuracy and performance on enterprise-grade infrastructure.

[Fastpitch] Multi-speaker model changes output speaker identity for different texts #1084

Open adrianastan opened 2 years ago

adrianastan commented 2 years ago

Hi,

We are trying to train a multi-speaker model on LibriTTS data using the latest FastPitch commit. We selected the 50 speakers with the most utterances in the dataset and removed the single-word utterances, which left around 8,400 samples in the training set. The model was trained for 1500 epochs and the output quality is quite good.

However, when we synthesise multiple texts with the same speaker ID, the output speaker identity differs slightly between them. Has anyone else encountered this problem? Is there anything we can tweak in the model to prevent it, or what else can we do to enforce a consistent speaker identity across synthesised utterances?

Thanks, Adriana

alancucki commented 2 years ago

Hi @adrianastan

Sorry for the late reply. I don't have much experience with that many speakers, but I would try adding capacity to the model and checking for class imbalance: if the speakers are indeed imbalanced, you can weight them somehow, or take the easy route and repeat the under-represented speakers' utterances in the filelist.
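A minimal sketch of the "easy route", oversampling under-represented speakers in the filelist until every speaker has as many entries as the largest one. It assumes pipe-separated filelist lines whose last field is the speaker ID; check that against your own filelists. `balance_filelist` is a hypothetical helper, not part of the repository.

```python
import random
from collections import defaultdict


def balance_filelist(lines, seed=0):
    """Oversample utterances so every speaker has as many entries as the
    most frequent speaker. ASSUMPTION: each line is pipe-separated and its
    last field is the speaker ID."""
    by_speaker = defaultdict(list)
    for line in lines:
        speaker = line.rstrip("\n").split("|")[-1]
        by_speaker[speaker].append(line)

    target = max(len(utts) for utts in by_speaker.values())
    rng = random.Random(seed)
    balanced = []
    for utts in by_speaker.values():
        balanced.extend(utts)
        # Top up this speaker with randomly chosen repeats of its own utterances.
        balanced.extend(rng.choices(utts, k=target - len(utts)))
    rng.shuffle(balanced)
    return balanced


filelist = [
    "wavs/a.wav|pitch/a.pt|hello there|0",
    "wavs/b.wav|pitch/b.pt|good morning|0",
    "wavs/c.wav|pitch/c.pt|hi|1",
]
balanced = balance_filelist(filelist)
```

Oversampling lengthens each epoch; sampling with per-speaker weights (for example via `torch.utils.data.WeightedRandomSampler`) is one way to do the weighting instead.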

adrianastan commented 2 years ago

Hi, thank you for your reply.

We also trained a model on 37 different speakers with exactly the same number of utterances and the same texts for each, and the results are the same. I assume that the simple summation of the speaker embedding with the text encoding is not strong enough to preserve the speaker identity (https://github.com/NVIDIA/DeepLearningExamples/blob/de507d9fecfbdd50ad001bdb15e89f8eae46871e/PyTorch/SpeechSynthesis/FastPitch/fastpitch/transformer.py#L207)
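To make the conditioning under discussion concrete, here is a simplified sketch (not the repository code) of additive conditioning: one speaker vector of shape (B, 1, d) is broadcast over time and added to the symbol and positional embeddings. Shapes and sizes are illustrative. One consequence: switching speakers shifts every position by the same vector, and the speaker information has to share the same d channels as the text.

```python
import torch
from torch import nn


def sum_conditioning(inp, pos_emb, spk_emb):
    """Additive conditioning, simplified from the encoder input step.
    inp:     (B, T, d) symbol embeddings
    pos_emb: (1, T, d) positional embeddings
    spk_emb: (B, 1, d) one speaker vector per utterance
    The speaker vector broadcasts over all T positions."""
    return inp + pos_emb + spk_emb


torch.manual_seed(0)
B, T, d = 2, 5, 8
speaker_table = nn.Embedding(4, d)  # 4 speakers, learned vectors
inp = torch.randn(B, T, d)
pos_emb = torch.randn(1, T, d)

spk_a = speaker_table(torch.tensor([0, 0])).unsqueeze(1)  # (B, 1, d)
spk_b = speaker_table(torch.tensor([3, 3])).unsqueeze(1)
out_a = sum_conditioning(inp, pos_emb, spk_a)
out_b = sum_conditioning(inp, pos_emb, spk_b)
# Switching speakers shifts every position by the same vector.
shift = (out_b - out_a).detach()
```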

So I am wondering if anybody else tried the multispeaker model and found alternative ways of performing this conditioning.

Thanks!

alancucki commented 2 years ago

> I assume that the simple summation of the speaker embedding to the text encoding is not strong enough to preserve the speaker identity

That might be the case. Positional embeddings are sometimes also added between the layers to keep the positional information from fading out. Doing the same with the speaker embedding might be worth a shot.
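A toy sketch of that suggestion, assuming the speaker vector has the model width d and padding is marked by a (B, T, 1) mask: the speaker embedding is re-added after every layer instead of only once at the input. The Linear+ReLU layers stand in for FastPitch's FFT blocks, `PerLayerConditionedStack` is a hypothetical name, and re-applying the mask so padded positions stay zero is a choice made for this sketch.

```python
import torch
from torch import nn


class PerLayerConditionedStack(nn.Module):
    """Toy layer stack that re-injects the speaker vector after every layer.
    The Linear+ReLU blocks stand in for real FFT (attention + conv) blocks."""

    def __init__(self, d_model, n_layers):
        super().__init__()
        self.layers = nn.ModuleList(
            [nn.Sequential(nn.Linear(d_model, d_model), nn.ReLU())
             for _ in range(n_layers)])

    def forward(self, x, mask, conditioning):
        # x: (B, T, d); mask: (B, T, 1), 1.0 for symbols, 0.0 for padding
        # conditioning: (B, 1, d), broadcast over the T positions
        out = (x + conditioning) * mask
        for layer in self.layers:
            # Out-of-place add, then re-mask so padding stays at zero.
            out = (layer(out) + conditioning) * mask
        return out


torch.manual_seed(0)
stack = PerLayerConditionedStack(d_model=8, n_layers=3)
x = torch.randn(2, 5, 8)
mask = torch.ones(2, 5, 1)
mask[1, 3:] = 0.0  # second utterance has only 3 real symbols
speaker = torch.randn(2, 1, 8)
y = stack(x, mask, speaker)
```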

How many utterances do you roughly have per speaker? Are there both male and female speakers?

adrianastan commented 2 years ago

I have now also added the embedding to condition the decoder, here: https://github.com/NVIDIA/DeepLearningExamples/blob/de507d9fecfbdd50ad001bdb15e89f8eae46871e/PyTorch/SpeechSynthesis/FastPitch/fastpitch/model.py#L314. But the results aren't any better.

We use both male and female speakers, with 137 to 321 utterances per speaker. In the model trained on the same texts with an equal amount of data per speaker, there were 638 utterances from each of the 37 speakers.

dsplog commented 2 years ago

> I assume that the simple summation of the speaker embedding to the text encoding is not strong enough to preserve the speaker identity
>
> That might be the case. Positional embeddings are sometimes added also in between the layers to keep the positional information from fading out. Doing the same with speaker embedding might be worth a shot.

@alancucki: I was trying out your suggestion of adding the speaker embedding at every layer. I did it as below:

--- a/fastpitch/transformer.py
+++ b/fastpitch/transformer.py
@@ -211,6 +211,7 @@ class FFTransformer(nn.Module):

         for layer in self.layers:
             out = layer(out, mask=mask)
+            out += conditioning

         # out = self.drop(out)
         return out, mask

Is that what you had in mind?

alancucki commented 2 years ago

Exactly! Did the quality improve?

dsplog commented 2 years ago

Broadly, yes. For a few speakers it gives a noticeable improvement.

However, some speakers are still not captured well. This may be due to insufficient data for those speakers; still checking. :-)

rygopu commented 2 years ago

Hi,

I tried training multi-speaker FastPitch for 2200 epochs (2 speakers, one male and one female, with 15,000 sentences each, about 50 hours of audio in total). But the output is just noise and the model isn't learning, whereas I was able to train a single-speaker model from scratch and its output was quite good.

Could anyone help me out here? Is there something important to configure when training multi-speaker FastPitch that I missed?

Thank you

adrianastan commented 2 years ago

We also used external speaker embeddings (derived from SpeechBrain's model: https://speechbrain.github.io/) instead of having FastPitch learn them. This helped a bit, but the model still fails to produce short utterances, and sometimes even longer ones, in the desired speaker's identity.
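For readers wanting to try the same idea, one way to feed external embeddings is to keep one precomputed vector per speaker in a fixed table and learn only a projection to the model width. This is a sketch under assumptions: the thread does not say how the SpeechBrain embeddings were aggregated or plugged in, no SpeechBrain API is used here, and the 192-d external size and 384-d model width are illustrative.

```python
import torch
from torch import nn


class ExternalSpeakerConditioning(nn.Module):
    """Speaker conditioning from fixed, precomputed vectors.
    speaker_vectors: (n_speakers, d_ext) tensor, e.g. one averaged
    speaker-encoder embedding per speaker (ASSUMPTION: computed offline)."""

    def __init__(self, speaker_vectors, d_model):
        super().__init__()
        # A buffer is saved with the model but is not a trainable parameter.
        self.register_buffer("table", speaker_vectors.clone())
        # Trainable projection from the external size to the model width.
        self.proj = nn.Linear(speaker_vectors.shape[1], d_model)

    def forward(self, speaker_ids):
        # (B,) integer IDs -> (B, 1, d_model), ready to broadcast over time
        return self.proj(self.table[speaker_ids]).unsqueeze(1)


n_speakers, d_ext, d_model = 50, 192, 384
cond = ExternalSpeakerConditioning(torch.randn(n_speakers, d_ext), d_model)
spk = cond(torch.tensor([0, 7, 49]))
```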

rygopu commented 2 years ago

Thank you! I'll check it out. Just to clarify: did you use external speaker embeddings because the output from FastPitch was pure noise, or to improve pronunciation?

adrianastan commented 2 years ago

Just to improve the speaker control.

alancucki commented 2 years ago

@rygopu The model should not fail completely; even after ~100 epochs you should be able to synthesize a fairly intelligible output. I'd suggest a bit of debugging before drawing conclusions.

In particular please check the data pre-processing pipeline, and make sure that your speaker IDs do get loaded and used by the model. Also, the alignment module might fail to converge for whatever reason.
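A quick check in the spirit of that advice, assuming pipe-separated filelists with the speaker ID as the last field (verify against your data loader). A line with no speaker field at all, as in a single-speaker filelist, shows up as a non-integer speaker field. This only checks the data; whether the model actually uses the IDs still needs a look at the forward pass.

```python
from collections import Counter


def check_speaker_ids(lines, n_speakers):
    """Check the speaker field of a pipe-separated filelist.
    ASSUMPTION: the speaker ID is the last field and must be an integer
    in [0, n_speakers). Returns (per-speaker counts, list of problems)."""
    counts = Counter()
    problems = []
    for lineno, line in enumerate(lines, start=1):
        field = line.rstrip("\n").split("|")[-1].strip()
        if not field.isdigit():
            problems.append(f"line {lineno}: speaker field {field!r} "
                            "is not a non-negative integer")
            continue
        speaker = int(field)
        if speaker >= n_speakers:
            problems.append(f"line {lineno}: speaker {speaker} >= "
                            f"n_speakers ({n_speakers})")
        counts[speaker] += 1
    missing = sorted(set(range(n_speakers)) - set(counts))
    if missing:
        problems.append(f"speakers with no utterances: {missing}")
    return counts, problems


lines = [
    "wavs/a.wav|pitch/a.pt|hello|0",
    "wavs/b.wav|pitch/b.pt|world|1",
    "wavs/c.wav|pitch/c.pt|no speaker field",
]
counts, problems = check_speaker_ids(lines, n_speakers=3)
```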

rygopu commented 2 years ago

Thanks @alancucki & @adrianastan. I retrained FastPitch (2 speakers, 13,000 sentences per speaker, 4500 epochs) and the output is still pure noise. (Debugging: I loaded different checkpoints, and the right speaker embeddings do seem to be loaded by the model; I could also see the embeddings being learnt as training progressed. Still, the output is the same as in previous runs.) Do you have any leads I could try?

adrianastan commented 2 years ago

Is it noise-noise or speech-like noise? Is the symbol list you use at training the same as the one used at inference? Is the transcription correct and aligned with the audio?
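The symbol-list question can be checked mechanically. A minimal sketch, assuming IDs are assigned by position in the symbol list (as is common in Tacotron-style text front ends) and that every symbol is a single character, which does not hold for multi-character phoneme symbols. A reordered list maps the same text to different IDs even when the set of symbols is identical.

```python
def check_symbols(train_symbols, infer_symbols, texts):
    """Compare training- and inference-time symbol lists.
    ASSUMPTIONS: symbol IDs are list positions, and each symbol is a single
    character. Returns (ids_consistent, unknown_chars)."""
    train_ids = {s: i for i, s in enumerate(train_symbols)}
    infer_ids = {s: i for i, s in enumerate(infer_symbols)}
    ids_consistent = train_ids == infer_ids
    # Characters in the inference texts that the training symbols do not cover.
    unknown = sorted({ch for text in texts for ch in text if ch not in train_ids})
    return ids_consistent, unknown


same = check_symbols(list("_ abc"), list("_ abc"), ["ab c"])
reordered = check_symbols(list("_ abc"), list("_ bac"), ["abd"])
```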

rygopu commented 2 years ago
adrianastan commented 2 years ago

I am afraid I cannot share the dataset, but I did downsample the audio, trimmed the silence and normalised the volume.
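The comment above does not give the exact pipeline. As a crude illustration only: amplitude-threshold trimming plus peak normalisation on a list of float samples in [-1, 1]. Real pipelines usually trim on frame energy in dB and resample with a dedicated tool; the threshold and peak values here are arbitrary assumptions.

```python
def trim_and_normalize(samples, threshold=0.01, peak=0.95):
    """Trim leading/trailing samples quieter than `threshold` (absolute
    amplitude) and scale so the loudest remaining sample equals `peak`.
    Returns [] if no sample reaches the threshold."""
    loud = [i for i, s in enumerate(samples) if abs(s) >= threshold]
    if not loud:
        return []
    trimmed = samples[loud[0]:loud[-1] + 1]
    scale = peak / max(abs(s) for s in trimmed)
    return [s * scale for s in trimmed]


out = trim_and_normalize([0.0, 0.001, 0.2, -0.5, 0.1, 0.0])
```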

Maybe you can try using one of the single-speaker models as a starting point for your multi-speaker one.

rygopu commented 2 years ago

Hi @adrianastan, thank you. I trained multi-speaker FastPitch on the same dataset, and it works now: the issue was in pre-processing, and downsampling with ffmpeg instead of librosa solved it. Also, when I run inference on multiple texts for the same speaker, the output quality and identity seem fine. (My training set has 15,000 sentences per speaker for 3 speakers in total.)
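The exact ffmpeg invocation is not given in the thread. A minimal sketch of resampling through ffmpeg from Python, using commonly used flags that are not necessarily the ones used here; 22050 Hz is assumed only because it is the LJSpeech rate, so adjust it to your configuration.

```python
import subprocess


def ffmpeg_resample_cmd(src, dst, sample_rate=22050):
    """Build an ffmpeg command that converts `src` to a mono, 16-bit PCM WAV
    at `sample_rate` Hz. The 22050 Hz default is an assumption."""
    return ["ffmpeg", "-y", "-i", src,
            "-ac", "1",                 # mono
            "-ar", str(sample_rate),    # output sample rate
            "-acodec", "pcm_s16le",     # 16-bit little-endian PCM
            dst]


def resample(src, dst, sample_rate=22050):
    # Needs the ffmpeg binary on PATH; raises CalledProcessError on failure.
    subprocess.run(ffmpeg_resample_cmd(src, dst, sample_rate), check=True)


cmd = ffmpeg_resample_cmd("raw/utt001.wav", "wavs/utt001.wav")
```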

adrianastan commented 2 years ago

Thanks for your reply! We are using 50 speakers with 200 utterances per speaker, and the identity still changes. We are now retraining using the ideas here: https://github.com/NVIDIA/DeepLearningExamples/issues/707#issuecomment-727021066

adrianastan commented 1 year ago

Hi, @alancucki,

So we tried all the methods mentioned so far:

But there is still no improvement: the identity still changes for different text inputs. Are there any other ideas you think are worth trying?

Thanks, Adriana

martinvk1 commented 1 year ago

@adrianastan, have you tried concatenating the speaker embeddings to the text encoding (by repeating it for each symbol)?

adrianastan commented 1 year ago

But this is what happens now in FastPitch: https://github.com/NVIDIA/DeepLearningExamples/blob/afea561ecff80b82f17316a0290f6f34c486c9a5/PyTorch/SpeechSynthesis/FastPitch/fastpitch/transformer.py#L207

martinvk1 commented 1 year ago

@adrianastan No, the speaker embedding is summed with the input and positional encoding, not concatenated. This kind of summation should be acceptable for positional encoding, but it is not suited for adding a speaker embedding. I think it would work much better to concatenate, similar to what has been done in multispeaker Tacotron in the past (for example here: https://github.com/CorentinJ/Real-Time-Voice-Cloning/blob/98d0ca4d4d140a4bb6bc7d54c84b1915a79041d5/synthesizer/models/tacotron.py#L62)

After concatenating, the dimensionality changes, so other parameters need to be adjusted accordingly. Alternatively, a Linear layer could project back down to the original dimension. I'm not sure how well that would work.
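A sketch of both options: concatenate the speaker vector to every time step, and optionally project back to the original width with a Linear layer so downstream modules keep their configured input size. `ConcatSpeakerConditioning` is a hypothetical helper and the 512/256 sizes are illustrative.

```python
import torch
from torch import nn


class ConcatSpeakerConditioning(nn.Module):
    """Concatenate a per-utterance speaker vector to every time step and,
    if project_back is True, map the result back to d_model."""

    def __init__(self, d_model, d_speaker, project_back=True):
        super().__init__()
        self.proj = nn.Linear(d_model + d_speaker, d_model) if project_back else None

    def forward(self, x, spk_emb):
        # x: (B, T, d_model); spk_emb: (B, 1, d_speaker)
        spk = spk_emb.expand(-1, x.size(1), -1)   # view over time, no copy yet
        out = torch.cat([x, spk], dim=2)          # (B, T, d_model + d_speaker)
        return self.proj(out) if self.proj is not None else out


torch.manual_seed(0)
x = torch.randn(2, 7, 512)     # encoder output, (B, T, d_model)
spk = torch.randn(2, 1, 256)   # speaker embedding, (B, 1, d_speaker)
widened = ConcatSpeakerConditioning(512, 256, project_back=False)(x, spk)
projected = ConcatSpeakerConditioning(512, 256, project_back=True)(x, spk)
```

Without the projection, every consumer of `widened` must be built for 768 input channels; with it, existing module sizes can stay as they are.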

adrianastan commented 1 year ago

OK, got it. Still, this only means that instead of using equal weights in the summation, the network learns separate weights for each term.
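The comparison can be made precise for the first layer that sees the combined input (notation introduced here, not from the thread): let $x_t$ be the symbol encoding at position $t$, $s$ the speaker vector, and split that layer's weight matrix into column blocks $W = [W_x \; W_s]$. Then

$$
W \begin{bmatrix} x_t \\ s \end{bmatrix} + b = W_x x_t + W_s s + b,
\qquad
W (x_t + s) + b = W x_t + W s + b .
$$

So, at that layer, concatenation amounts to summation with separately learned matrices for the text and speaker terms rather than one shared matrix. The same decomposition holds for a 1-D convolution, which is linear in its input channels. Any practical difference comes from the extra width and how later layers use it, which the algebra alone does not settle.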

So did you try this and get better results? Thanks!

martinvk1 commented 1 year ago

@adrianastan In my experiment I increased the dimensionality of the encoder to fit the embedded symbols with the speaker information concatenated to them. That way, the encoder receives intact, clearly separated features and can decide how to deal with them. Depending on the goal, it may be better to add the speaker information after symbol encoding or even later. So far I haven't had any problems with speaker similarity, but I ran into other challenges with multi-speaker FastPitch, such as needing pitch mean/std values for every speaker in the dataset, and making sure that each speaker has short, medium and long utterances. I find that transformers generalize badly to input lengths they have not explicitly been trained on for a given speaker.
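A minimal sketch of per-speaker pitch statistics in plain Python: the mean and (population) standard deviation of voiced F0 per speaker, plus a normalisation that keeps unvoiced frames at zero, a common convention assumed here. The input dictionaries are hypothetical stand-ins for however your pitch files are stored.

```python
import math
from collections import defaultdict


def per_speaker_pitch_stats(pitch_by_utt, speaker_of):
    """Mean and population std of voiced F0 for each speaker.
    pitch_by_utt: {utt_id: [F0 in Hz per frame, 0.0 for unvoiced frames]}
    speaker_of:   {utt_id: speaker_id}"""
    voiced = defaultdict(list)
    for utt, f0 in pitch_by_utt.items():
        voiced[speaker_of[utt]].extend(v for v in f0 if v > 0)
    stats = {}
    for spk, vals in voiced.items():
        if not vals:  # speaker with no voiced frames: nothing to estimate
            continue
        mean = sum(vals) / len(vals)
        std = math.sqrt(sum((v - mean) ** 2 for v in vals) / len(vals))
        stats[spk] = (mean, std)
    return stats


def normalize_pitch(f0, mean, std):
    # Standardize voiced frames with the speaker's stats (std must be > 0);
    # unvoiced frames stay at 0.0.
    return [(v - mean) / std if v > 0 else 0.0 for v in f0]


pitch_by_utt = {"u1": [100.0, 0.0, 200.0], "u2": [150.0], "u3": [300.0, 0.0, 260.0]}
speaker_of = {"u1": 0, "u2": 0, "u3": 1}
stats = per_speaker_pitch_stats(pitch_by_utt, speaker_of)
mean0, std0 = stats[0]
normalized = normalize_pitch([100.0, 0.0, 200.0], mean0, std0)
```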

adrianastan commented 1 year ago

Ok, great, I will give it a try. Thanks!

martinvk1 commented 1 year ago

@adrianastan Hope it works out for you. I just wanted to add that I got great results by concatenating the speaker embedding directly to the inputs of 1) the pitch predictor, 2) the duration predictor, 3) the energy predictor and 4) the decoder, as below. I am using a speaker_embedding_dim of 256 and a symbol_embedding_dim of 512, so these layers now take 768-dimensional inputs.


--- a/PyTorch/SpeechSynthesis/FastPitch/fastpitch/model.py
+++ b/PyTorch/SpeechSynthesis/FastPitch/fastpitch/model.py
@@ -262,14 +262,17 @@ class FastPitch(nn.Module):
             spk_emb.mul_(self.speaker_emb_weight)

         # Input FFT
-        enc_out, enc_mask = self.encoder(inputs, conditioning=spk_emb)
+        enc_out, enc_mask = self.encoder(inputs, conditioning=0) # Do not condition here
+
+        spk_emb_repeated = spk_emb.repeat(1, enc_out.shape[1], 1)
+        enc_out_spk = torch.cat([enc_out, spk_emb_repeated], dim=2)

         # Predict durations
-        log_dur_pred = self.duration_predictor(enc_out, enc_mask).squeeze(-1)
+        log_dur_pred = self.duration_predictor(enc_out_spk, enc_mask).squeeze(-1) # Condition here
         dur_pred = torch.clamp(torch.exp(log_dur_pred) - 1, 0, max_duration)

         # Predict pitch
-        pitch_pred = self.pitch_predictor(enc_out, enc_mask).permute(0, 2, 1)
+        pitch_pred = self.pitch_predictor(enc_out_spk, enc_mask).permute(0, 2, 1) # Condition here

         # Alignment
         text_emb = self.encoder.word_emb(inputs)
@@ -301,7 +304,7 @@ class FastPitch(nn.Module):

         # Predict energy
         if self.energy_conditioning:
-            energy_pred = self.energy_predictor(enc_out, enc_mask).squeeze(-1)
+            energy_pred = self.energy_predictor(enc_out_spk, enc_mask).squeeze(-1) # Condition here

             # Average energy over characters
             energy_tgt = average_pitch(energy_dense.unsqueeze(1), dur_tgt)
@@ -317,8 +320,11 @@ class FastPitch(nn.Module):
         len_regulated, dec_lens = regulate_len(
             dur_tgt, enc_out, pace, mel_max_len)

+        spk_emb_repeated = spk_emb.repeat(1, len_regulated.shape[1], 1)
+        len_regulated_spk = torch.cat([len_regulated, spk_emb_repeated], dim=2)
+
         # Output FFT
-        dec_out, dec_mask = self.decoder(len_regulated, dec_lens)
+        dec_out, dec_mask = self.decoder(len_regulated_spk, dec_lens) # Condition here
         mel_out = self.proj(dec_out)
         return (mel_out, dec_mask, dur_pred, log_dur_pred, pitch_pred,
                 pitch_tgt, energy_pred, energy_tgt, attn_soft, attn_hard,
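The diff relies on `spk_emb` already having shape (B, 1, d_s), so that `.repeat(1, T, 1)` tiles it along time. A small hypothetical helper (not from the repository) that makes this assumption explicit; with the sizes given above, the concatenated input is 512 + 256 = 768 wide, and every module consuming it must be built with that input size.

```python
import torch


def concat_speaker(x, spk_emb):
    """Concatenate a per-utterance speaker embedding to every time step.
    x: (B, T, d_x); spk_emb: (B, 1, d_s) -> returns (B, T, d_x + d_s)."""
    if spk_emb.dim() != 3 or spk_emb.size(1) != 1:
        # A 2-D (B, d_s) tensor would make .repeat(1, T, 1) return
        # (1, B*T, d_s), and torch.cat would then fail with a size error.
        raise ValueError(
            f"expected a (B, 1, d) speaker embedding, got {tuple(spk_emb.shape)}")
    return torch.cat([x, spk_emb.repeat(1, x.size(1), 1)], dim=2)


enc_out = torch.zeros(2, 7, 512)   # symbol_embedding_dim = 512
spk_emb = torch.randn(2, 1, 256)   # speaker_embedding_dim = 256
enc_out_spk = concat_speaker(enc_out, spk_emb)
```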
dsplog commented 1 year ago

Ah, thanks for sharing the details. I will try this out.

Slava715 commented 1 year ago

Hi @martinvk1, I'm trying to change the model as you suggested, but I ran into size mismatches. I set --in-fft-output-size to twice its original value, but I still get a size error at: https://github.com/NVIDIA/DeepLearningExamples/blob/afea561ecff80b82f17316a0290f6f34c486c9a5/PyTorch/SpeechSynthesis/FastPitch/fastpitch/transformer.py#L207 Could you describe in more detail how you changed the encoder?