yl4579 / StyleTTS

Official Implementation of StyleTTS
MIT License

Need help with training #50

Closed nhanhttrong closed 1 year ago

nhanhttrong commented 1 year ago

I'm pretraining your model on the vivo dataset (Vietnamese) but the results are not what I expected. Here is the original audio: https://drive.google.com/file/d/12mZdg8yVhgQj35Vt3thoxIK44_jWWaCJ/view?usp=sharing and here is the result: https://drive.google.com/file/d/1UOuUHrxiR5DvF1MrpccMZ6bdwKyfO2AE/view?usp=sharing

Here are the losses during stage 1 and stage 2 training:

[loss curve screenshots]

P.S.: I used the ASR from the original paper to train on Vietnamese. I wonder if that causes any problems, because during stage 1 training there were quite a lot of KeyError exceptions. Thank you very much.

yl4579 commented 1 year ago

I believe the problem lies in the duration loss; it somehow fluctuates between 1 and 0.6. I think the text aligner is probably fine. Could you check whether the first stage of the model sounds good?

nhanhttrong commented 1 year ago

Can you provide code to test stage 1? I also have another question: while preprocessing the data I use phonemizer to convert the text to IPA, but most of the IPA symbols it produces don't exist in your code. So during training I get the following output, where it prints the text lines that mismatch, yet the loss keeps working, and I wonder how that is possible: [screenshot of the training log]

Also, does the KeyError raised in this code affect training? [screenshot of the text cleaner code]
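For context, here is roughly how I check which phonemized characters fall outside the symbol dictionary. This is only a sketch: it assumes the phonemizer package with the espeak backend is installed, and word_index_dictionary is a stand-in name for whatever char-to-index map the text cleaner uses.

# Hedged sketch (not from the repo): list phonemized characters missing from the symbol dictionary.
from phonemizer import phonemize

line = "xin chào"  # any Vietnamese line from your dataset
ipa = phonemize(line, language='vi', backend='espeak', strip=True)

missing = sorted({ch for ch in ipa if ch not in word_index_dictionary})
print("phonemized:", ipa)
print("characters missing from the dictionary:", missing)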

Thanks for all your replies, hope you have a good day.

yl4579 commented 1 year ago

You can use the following code for testing (run it inside the inference notebook); it is adapted from the validation loss computation code:

# This assumes you are inside the inference notebook, so that model, config, device,
# batch_size, the vocoder `generator`, the TMA_CEloss flag, and the helpers
# get_data_path_list, length_to_mask, mask_from_lens, maximum_path, and log_norm
# are already defined.
import numpy as np
import torch
import torch.nn.functional as F

from meldataset import build_dataloader
train_path = config.get('train_data', None)
val_path = config.get('val_data', None)
train_list, val_list = get_data_path_list(train_path, val_path)
train_dataloader = build_dataloader(train_list,
                                    batch_size=batch_size,
                                    num_workers=8,
                                    dataset_config={},
                                    device=device)

val_dataloader = build_dataloader(val_list,
                                  batch_size=batch_size,
                                  validation=True,
                                  num_workers=2,
                                  device=device,
                                  dataset_config={})

_, batch = next(enumerate(train_dataloader)) # can also be val_dataloader 
batch = [b.to(device) for b in batch]
texts, input_lengths, mels, mel_input_length = batch

with torch.no_grad():
    # run the pretrained text aligner (ASR) to get the text-to-mel attention
    mask = length_to_mask(mel_input_length // (2 ** model.text_aligner.n_down)).to('cuda')
    m = length_to_mask(input_lengths)
    ppgs, s2s_pred, s2s_attn_feat = model.text_aligner(mels, mask, texts)

    s2s_attn_feat = s2s_attn_feat.transpose(-1, -2)
    s2s_attn_feat = s2s_attn_feat[..., 1:]
    s2s_attn_feat = s2s_attn_feat.transpose(-1, -2)

    # mask out padded positions along both the mel and text axes
    text_mask = length_to_mask(input_lengths).to(texts.device)
    attn_mask = (~mask).unsqueeze(-1).expand(mask.shape[0], mask.shape[1], text_mask.shape[-1]).float().transpose(-1, -2)
    attn_mask = attn_mask.float() * (~text_mask).unsqueeze(-1).expand(text_mask.shape[0], text_mask.shape[1], mask.shape[-1]).float()
    attn_mask = (attn_mask < 1)

    s2s_attn_feat.masked_fill_(attn_mask, -float("inf"))

    if TMA_CEloss:
        s2s_attn = F.softmax(s2s_attn_feat, dim=1) # along the mel dimension
    else:
        s2s_attn = F.softmax(s2s_attn_feat, dim=-1) # along the text dimension

    # get monotonic version of the attention
    mask_ST = mask_from_lens(s2s_attn, input_lengths, mel_input_length // (2 ** model.text_aligner.n_down))
    s2s_attn_mono = maximum_path(s2s_attn, mask_ST)

    s2s_attn = torch.nan_to_num(s2s_attn)

# encode the text and align it to the mel frames
t_en = model.text_encoder(texts, input_lengths, m)
asr = (t_en @ s2s_attn_mono)

# get random clips of equal length from each sample in the batch
mel_len = int(mel_input_length.min().item() / 2 - 1)
en = []
gt = []
for bib in range(len(mel_input_length)):
    mel_length = int(mel_input_length[bib].item() / 2)

    random_start = np.random.randint(0, mel_length - mel_len)
    en.append(asr[bib, :, random_start:random_start+mel_len])
    gt.append(mels[bib, :, (random_start * 2):((random_start+mel_len) * 2)])
en = torch.stack(en)
gt = torch.stack(gt).detach()

# extract F0 from the ground-truth clips
with torch.no_grad():
    F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))
    F0 = F0.reshape(F0.shape[0], F0.shape[1] * 2, F0.shape[2], 1).squeeze()

# reconstruct the mel spectrogram from aligned text, pitch, energy, and style
s = model.style_encoder(gt.unsqueeze(1))
real_norm = log_norm(gt.unsqueeze(1)).squeeze(1)
mel_rec = model.decoder(en, F0_real, real_norm, s)
mel_rec = mel_rec[..., :gt.shape[-1]]

# synthesize waveforms from the reconstructed mels with the vocoder
synthesized = []
for idx in range(mel_rec.size(0)):
    with torch.no_grad():
        c = mel_rec[idx].squeeze()
        y_g_hat = generator(c.unsqueeze(0))
        y_out = y_g_hat.squeeze().cpu().numpy()
    synthesized.append(y_out)

import IPython.display as ipd
for wave in synthesized:
    display(ipd.Audio(wave, rate=24000))
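If you are running this outside a notebook, you could write the clips to disk instead of displaying them, for example with the soundfile package (assuming it is installed; the file names are only illustrative):

# Optional: save the reconstructed clips as wav files instead of using IPython display.
import soundfile as sf
for i, wave in enumerate(synthesized):
    sf.write(f"recon_{i}.wav", wave, 24000)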

As for the IPA error, it seems like only a few characters are not in the dictionary, and they are automatically ignored by the text cleaner. You can do print(char) instead of print(index) to see which specific characters are missing.
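For reference, the relevant logic in the text cleaner looks roughly like this (a paraphrase of the lookup loop in meldataset.py, with illustrative names, not the exact code): a character missing from the dictionary raises a KeyError, which is caught, printed, and skipped, so training continues with that symbol simply dropped.

# Rough paraphrase of the text cleaner's lookup loop (names illustrative):
def clean(text, word_index_dictionary):
    indexes = []
    for char in text:
        try:
            indexes.append(word_index_dictionary[char])
        except KeyError:
            print(char)  # print the offending character itself instead of an index
    return indexes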

christopherohit commented 1 year ago

So does the IPA error affect training? If so, do I need to retrain the text aligner with a more extensive IPA set?

Thanks for your reply.

yl4579 commented 1 year ago

@christopherohit If these characters are not essential for reconstruction then it doesn't matter, because the text aligner is eventually finetuned on your new dataset, and characters unseen by the pretrained model will be learned during TMA training. But if these characters are important for reconstruction (like pauses, breaths, laughs, etc.), then you need to retrain.