Sorry for the late reply. I was quite busy recently. Have you checked #10 and #11? Did you use mixed precision as well?
Thanks for your reply. I have fixed this issue; it may have been caused by too small a batch size.
Were you able to do it? I was trying to train but was facing some issues. Can we discuss?
Same issue with batch size 2: the generator loss reaches about 100 and then NaNs. (EDIT: Didn't work!) I have a preliminary solution, still testing, but based on https://github.com/yl4579/StyleTTS2/issues/11#issuecomment-1752326746 it seems to be discriminator overfitting. So I am trying to force the discriminators' weight decay to a high value to prevent overfitting, in train_first.py:
for module in ["mpd", "msd"]:
    for g in optimizer.optimizers[module].param_groups:
        g['weight_decay'] = 0.1
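To confirm the override actually took effect, a quick check like this (just reusing the same optimizer access as the snippet above) prints the effective decay per parameter group:

# sanity check: print the weight decay now set on each discriminator param group
for module in ["mpd", "msd"]:
    for i, g in enumerate(optimizer.optimizers[module].param_groups):
        print(module, i, g['weight_decay'])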
I'm also lowering the gain on the feature-matching (discriminator feature) losses by pre-multiplying them by 0.5, in losses.py:
class GeneratorLoss(torch.nn.Module):
    def __init__(self, mpd, msd):
        super(GeneratorLoss, self).__init__()
        self.mpd = mpd
        self.msd = msd

    def forward(self, y, y_hat):
        y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = self.mpd(y, y_hat)
        y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = self.msd(y, y_hat)
        loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
        loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
        loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
        loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
        loss_rel = generator_TPRLS_loss(y_df_hat_r, y_df_hat_g) + generator_TPRLS_loss(y_ds_hat_r, y_ds_hat_g)
        # feature-matching terms pre-multiplied by 0.5 (instead of 1.0) to weaken the discriminator feedback
        loss_gen_all = loss_gen_s + loss_gen_f + 0.5*loss_fm_s + 0.5*loss_fm_f + loss_rel
        return loss_gen_all.mean()
At first I tried decay = 0.01 and gains 1.0, 1.0, but that only delayed the problem. Then I tried decay = 1.0 and gains 0.1, 0.1, which seemed to prevent the NaNs, but the audio quality wasn't good. So now I am trying decay = 0.1 and gains 0.5, 0.5. I should be able to report back the results in a few days.
No, that didn't work :( the loss made some strange moves and eventually ended in NaN.
Integrating PhaseAug and using batch_percentage=1.0 with batch size 2 fixed it for me. PhaseAug tries to address the overfitting issue by randomly rotating the phase of each frequency bin (a rough sketch of that idea follows the code below). The generator loss still creeps up, but very slowly now, and the audio quality becomes quite nice after 2 epochs:
...
from phaseaug.phaseaug import PhaseAug  # pip install phaseaug

aug = PhaseAug()
gl = GeneratorLoss(model.mpd, model.msd, aug).to(device)
dl = DiscriminatorLoss(model.mpd, model.msd, aug).to(device)
...
class GeneratorLoss(torch.nn.Module):
    def __init__(self, mpd, msd, aug):
        super(GeneratorLoss, self).__init__()
        self.mpd = mpd
        self.msd = msd
        self.aug = aug

    def forward(self, y, y_hat):
        y, y_hat = self.aug.forward_sync(y, y_hat)  # <--- Augment here
        y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = self.mpd(y, y_hat)
        y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = self.msd(y, y_hat)
        loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
        loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
        loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
        loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
        loss_rel = generator_TPRLS_loss(y_df_hat_r, y_df_hat_g) + generator_TPRLS_loss(y_ds_hat_r, y_ds_hat_g)
        loss_gen_all = loss_gen_s + loss_gen_f + 1.0*loss_fm_s + 1.0*loss_fm_f + loss_rel
        return loss_gen_all.mean()
class DiscriminatorLoss(torch.nn.Module):
    def __init__(self, mpd, msd, aug):
        super(DiscriminatorLoss, self).__init__()
        self.aug = aug
        self.mpd = mpd
        self.msd = msd

    def forward(self, y, y_hat):
        y, y_hat = self.aug.forward_sync(y, y_hat.detach())  # <--- Augment here
        # MPD
        y_df_hat_r, y_df_hat_g, _, _ = self.mpd(y, y_hat)
        loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)
        # MSD
        y_ds_hat_r, y_ds_hat_g, _, _ = self.msd(y, y_hat)
        loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)
        loss_rel = discriminator_TPRLS_loss(y_df_hat_r, y_df_hat_g) + discriminator_TPRLS_loss(y_ds_hat_r, y_ds_hat_g)
        d_loss = loss_disc_s + loss_disc_f + loss_rel
        return d_loss.mean()
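For anyone curious what the augmentation is doing conceptually, here is a minimal, self-contained sketch of the phase-rotation idea. This is not the actual PhaseAug implementation (the library applies a differentiable per-bin rotation internally and should be used for training); the function name and STFT parameters below are just illustrative:

import math
import torch

def phase_rotate_pair(y, y_hat, n_fft=1024, hop=256):
    """Toy phase augmentation: rotate every STFT frequency bin by a random
    angle, applied identically to the real and generated waveforms so the
    discriminator sees a consistently transformed pair."""
    win = torch.hann_window(n_fft, device=y.device)
    Y = torch.stft(y.squeeze(1), n_fft, hop, window=win, return_complex=True)
    G = torch.stft(y_hat.squeeze(1), n_fft, hop, window=win, return_complex=True)
    # one random phase offset per frequency bin, shared across time and across the pair
    phi = 2 * math.pi * torch.rand(Y.shape[-2], 1, device=y.device)
    rot = torch.exp(1j * phi)
    y_aug = torch.istft(Y * rot, n_fft, hop, window=win, length=y.shape[-1])
    g_aug = torch.istft(G * rot, n_fft, hop, window=win, length=y.shape[-1])
    return y_aug.unsqueeze(1), g_aug.unsqueeze(1)

In the code above, aug.forward_sync(y, y_hat) plays this role, just in a better-engineered, differentiable way.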
Hi, I'm training StyleTTS2 on a new language, Thai. At epoch 7 the losses became NaN, and the g_loss seems to keep increasing during training. I want to know what could cause this problem. Here are my log and training loss.
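A quick way to narrow this down is to guard every loss term and stop at the first non-finite value, so the log shows which component diverges first. The helper below is purely illustrative (the variable names in the commented call are examples, not the repo's exact names):

import torch

def check_finite(step, **losses):
    # Raise as soon as any loss term goes NaN/Inf, naming the offender.
    for name, value in losses.items():
        if not torch.isfinite(value).all():
            raise RuntimeError(f"step {step}: loss '{name}' is non-finite: {value}")

# e.g. inside the training loop, after computing the losses:
# check_finite(step, d_loss=d_loss, gen_loss=loss_gen_all, mel_loss=loss_mel)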