yl4579 / StyleTTS2

StyleTTS 2: Towards Human-Level Text-to-Speech through Style Diffusion and Adversarial Training with Large Speech Language Models
MIT License

stage1 training issue #175

Closed zhiqiuiyiye closed 8 months ago

zhiqiuiyiye commented 11 months ago

Hi, I'm training StyleTTS2 on a new language, Thai. At epoch 7 the losses became NaN, and g_loss seems to keep increasing during training. What could cause this problem? Here are my log and training loss. [Two WeChat screenshots attached: 20231228094528, 20231228094433]

yl4579 commented 10 months ago

Sorry for the late reply. I was quite busy recently. Have you checked #10 and #11? Did you use mixed precision as well?

zhiqiuiyiye commented 10 months ago

Thanks for your reply. I have fixed this issue; it may have been caused by too small a batch size.

akshatgarg99 commented 10 months ago

Were you able to get it working? I was trying to train but ran into some issues. Can we discuss?

RillmentGames commented 9 months ago

Same issue with batch size 2: the generator loss can reach about 100 and then goes NaN. (EDIT: Didn't work!) I have a preliminary solution that I'm still testing. Based on https://github.com/yl4579/StyleTTS2/issues/11#issuecomment-1752326746, the cause seems to be discriminator overfitting, so I am trying to force the discriminators' weight decay to a high value to prevent it, in train_first:

for module in ["mpd", "msd"]:
    for g in optimizer.optimizers[module].param_groups:
        g['weight_decay'] = 0.1

and also lowering the gain of the discriminator feature-matching losses by multiplying them by 0.5, in losses.py:

class GeneratorLoss(torch.nn.Module):

    def __init__(self, mpd, msd):
        super(GeneratorLoss, self).__init__()
        self.mpd = mpd
        self.msd = msd

    def forward(self, y, y_hat):
        y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = self.mpd(y, y_hat)
        y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = self.msd(y, y_hat)
        loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
        loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
        loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
        loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)

        loss_rel = generator_TPRLS_loss(y_df_hat_r, y_df_hat_g) + generator_TPRLS_loss(y_ds_hat_r, y_ds_hat_g)

        loss_gen_all = loss_gen_s + loss_gen_f + 0.5*loss_fm_s + 0.5*loss_fm_f + loss_rel

        return loss_gen_all.mean()

At first I tried decay = 0.01 and gains 1.0, 1.0, but that only delayed the problem. Then I tried decay = 1.0 and gains 0.1, 0.1; that seemed to prevent NaN, but the audio quality wasn't good. So now I am trying decay = 0.1 and gains 0.5, 0.5. I should be able to report back the results in a few days.
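For anyone who wants to try the weight-decay override in isolation, here is a minimal sketch. The tiny `Linear` modules and the plain dict of AdamW optimizers are hypothetical stand-ins for the real discriminators and for `optimizer.optimizers` in the snippet above; the learning rate and the initial decay are placeholder values, not the repo's settings.

```python
import torch

# Hypothetical stand-ins for model.mpd / model.msd and one non-discriminator module.
modules = {
    "mpd": torch.nn.Linear(4, 1),
    "msd": torch.nn.Linear(4, 1),
    "decoder": torch.nn.Linear(4, 4),
}

# One AdamW per module, keyed by name (placeholder lr and weight_decay).
optimizers = {
    name: torch.optim.AdamW(m.parameters(), lr=1e-4, weight_decay=1e-4)
    for name, m in modules.items()
}


def set_weight_decay(optimizers, names, value):
    """Overwrite weight_decay in every param group of the named optimizers."""
    for name in names:
        for group in optimizers[name].param_groups:
            group["weight_decay"] = value


# Only the discriminators get the stronger decay; other modules are untouched.
set_weight_decay(optimizers, ["mpd", "msd"], 0.1)
```

The optimizer reads `weight_decay` from each param group on every `step()`, so the new value applies from the next update onward without rebuilding the optimizer.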

RillmentGames commented 9 months ago

No, that didn't work :( The loss made some strange moves and eventually ended in NaN.

RillmentGames commented 9 months ago

Integrating PhaseAug and using batch_percentage=1.0 with batch size 2 fixed it for me. PhaseAug tries to address the discriminator overfitting issue by randomly rotating the phase of each frequency bin. The generator loss still creeps up, but very slowly now, and audio quality becomes quite nice after 2 epochs:

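    # needs `from phaseaug.phaseaug import PhaseAug` at the top of the training script
    # (import path assumed from the PhaseAug README)
    aug = PhaseAug()
    gl = GeneratorLoss(model.mpd, model.msd, aug).to(device)
    dl = DiscriminatorLoss(model.mpd, model.msd, aug).to(device)
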
class GeneratorLoss(torch.nn.Module):

    def __init__(self, mpd, msd, aug):
        super(GeneratorLoss, self).__init__()
        self.mpd = mpd
        self.msd = msd
        self.aug = aug  

    def forward(self, y, y_hat):
        y, y_hat = self.aug.forward_sync(y, y_hat)  # <--- augment here
        y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = self.mpd(y, y_hat)
        y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = self.msd(y, y_hat)
        loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
        loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
        loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
        loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)

        loss_rel = generator_TPRLS_loss(y_df_hat_r, y_df_hat_g) + generator_TPRLS_loss(y_ds_hat_r, y_ds_hat_g)

        loss_gen_all = loss_gen_s + loss_gen_f + 1.0*loss_fm_s + 1.0*loss_fm_f + loss_rel

        return loss_gen_all.mean()

class DiscriminatorLoss(torch.nn.Module):

    def __init__(self, mpd, msd, aug):
        super(DiscriminatorLoss, self).__init__()
        self.aug = aug
        self.mpd = mpd
        self.msd = msd

    def forward(self, y, y_hat):
        y, y_hat = self.aug.forward_sync(y, y_hat.detach())  # <--- augment here
        # MPD
        y_df_hat_r, y_df_hat_g, _, _ = self.mpd(y, y_hat)
        loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)
        # MSD
        y_ds_hat_r, y_ds_hat_g, _, _ = self.msd(y, y_hat)
        loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)

        loss_rel = discriminator_TPRLS_loss(y_df_hat_r, y_df_hat_g) + discriminator_TPRLS_loss(y_ds_hat_r, y_ds_hat_g)

        d_loss = loss_disc_s + loss_disc_f + loss_rel

        return d_loss.mean()
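To make "randomly rotating the phase of each frequency bin" concrete, here is a toy NumPy sketch. It is not PhaseAug's actual implementation: it uses a single whole-signal FFT, and `sample_phases` and `rotate_phases` are hypothetical helper names. Applying the same rotation to the real and the generated signal is what the name `forward_sync` suggests, which is an assumption here.

```python
import numpy as np


def sample_phases(n_samples, rng):
    """Draw one random rotation angle per rFFT bin of an n_samples-long signal."""
    n_bins = n_samples // 2 + 1
    phi = rng.uniform(-np.pi, np.pi, size=n_bins)
    phi[0] = 0.0  # keep the DC bin real
    if n_samples % 2 == 0:
        phi[-1] = 0.0  # keep the Nyquist bin real for even lengths
    return phi


def rotate_phases(signal, phi):
    """Rotate each frequency bin by phi; bin magnitudes are left unchanged."""
    spectrum = np.fft.rfft(signal)
    return np.fft.irfft(spectrum * np.exp(1j * phi), n=signal.size)


rng = np.random.default_rng(0)
t = np.arange(256) / 256.0
y = np.sin(2 * np.pi * 5 * t) + 0.5 * np.sin(2 * np.pi * 12 * t)  # "real" audio
y_hat = 0.9 * np.sin(2 * np.pi * 5 * t + 0.3)                     # "generated" audio

# Use the same rotation for both signals, so their relative phase is preserved.
phi = sample_phases(y.size, rng)
y_aug = rotate_phases(y, phi)
y_hat_aug = rotate_phases(y_hat, phi)
```

The waveform changes while the magnitude spectrum stays the same, so a discriminator cannot simply memorize exact waveform shapes from a small dataset.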