p0p4k / vits2_pytorch

unofficial vits2-TTS implementation in pytorch
https://arxiv.org/abs/2307.16430
MIT License
465 stars, 81 forks

about duration-discriminator training objective #59

Closed choiHkk closed 9 months ago

choiHkk commented 9 months ago

Thank you for your hard work. I have a question while attempting to train with your code.

While training the duration predictor, I noticed that "loss_dur" fluctuates significantly compared to previous work. Upon investigation, I found that "grad_norm_dur_disc" spikes very high. I suspect this is because the adversarial loss is calculated over a single batch and ends up much larger than the weights can absorb, especially since the discriminator has only a few convolution layers.

As far as I know, in HiFiGAN, the discriminator is composed of several sub-discriminators. Therefore, I understand that there is a for loop inside the "discriminator_loss" and "generator_loss" functions to calculate the loss for each sub-discriminator.

Since the "DurationDiscriminator" you implemented does not consist of sub-discriminator layers, the loss in the "discriminator_loss" and "generator_loss" functions is computed over the items of a single batch and then summed without any scaling.
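To make the scaling concrete, here is a minimal, dependency-free sketch. It assumes the loss functions follow the HiFiGAN-style pattern of looping over their inputs and adding one least-squares term per element; nested lists stand in for tensors of shape [batch, time], and the helper names `flatten` and `mean_sq_error` are hypothetical, not taken from the repository.

```python
# Sketch only: nested lists emulate tensors of shape [batch, time].
# The loss form is an assumed HiFiGAN-style least-squares GAN loss.

def flatten(values):
    """Yield every scalar in an arbitrarily nested list."""
    for v in values:
        if isinstance(v, list):
            yield from flatten(v)
        else:
            yield v


def mean_sq_error(values, target):
    """Mean of (target - v)^2 over all scalars in `values`."""
    flat = list(flatten(values))
    return sum((target - v) ** 2 for v in flat) / len(flat)


def discriminator_loss(real_outputs, fake_outputs):
    """Add one loss term per element obtained by iterating over the outputs."""
    loss = 0.0
    for d_real, d_fake in zip(real_outputs, fake_outputs):
        loss += mean_sq_error(d_real, 1.0) + mean_sq_error(d_fake, 0.0)
    return loss


batch_size, frames = 8, 5
# An undecided discriminator that outputs 0.5 everywhere.
real = [[0.5] * frames for _ in range(batch_size)]
fake = [[0.5] * frames for _ in range(batch_size)]

# Passing the batch directly: the loop runs once per batch item and sums.
loss_unwrapped = discriminator_loss(real, fake)

# Wrapping each output in a one-element list: the loop runs once,
# so the loss is averaged over the whole batch.
loss_wrapped = discriminator_loss([real], [fake])

print(loss_unwrapped, loss_wrapped)
```

Under these assumptions the unwrapped loss grows linearly with the batch size, while the wrapped loss stays at the per-batch mean, which is the behavior a discriminator made of sub-discriminators would see for each of its outputs.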

In my opinion, this might make the training of the "DurationDiscriminator", which has very few parameters, unstable. I'm curious whether it was intentional to compute the loss over a single batch without scaling. If not, I'm also wondering whether it would be acceptable to wrap the output in a list inside the append() call in the discriminator's forward pass. Currently, I'm training the model with the list form and with the relu non-linearity and layernorm that you wrote but commented out. If I get good results, I will share them in this issue.

import torch
from torch import nn

import modules  # repo-local module that provides LayerNorm


class DurationDiscriminator(nn.Module):  # vits2
    # TODO : not using "spk conditioning" for now according to the paper.
    # Can be a better discriminator if we use it.
    def __init__(
        self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
    ):
        super().__init__()

        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.gin_channels = gin_channels

        self.conv_1 = nn.Conv1d(
            in_channels, filter_channels, kernel_size, padding=kernel_size // 2
        )
        self.norm_1 = modules.LayerNorm(filter_channels)
        self.conv_2 = nn.Conv1d(
            filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
        )
        self.norm_2 = modules.LayerNorm(filter_channels)
        self.dur_proj = nn.Conv1d(1, filter_channels, 1)

        self.pre_out_conv_1 = nn.Conv1d(
            2 * filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
        )
        self.pre_out_norm_1 = modules.LayerNorm(filter_channels)
        self.pre_out_conv_2 = nn.Conv1d(
            filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
        )
        self.pre_out_norm_2 = modules.LayerNorm(filter_channels)

        # if gin_channels != 0:
        #   self.cond = nn.Conv1d(gin_channels, in_channels, 1)

        self.output_layer = nn.Sequential(nn.Linear(filter_channels, 1), nn.Sigmoid())

    def forward_probability(self, x, x_mask, dur, g=None):
        dur = self.dur_proj(dur)
        x = torch.cat([x, dur], dim=1)
        x = self.pre_out_conv_1(x * x_mask)
        x = torch.relu(x)
        x = self.pre_out_norm_1(x)
        x = self.pre_out_conv_2(x * x_mask)
        x = torch.relu(x)
        x = self.pre_out_norm_2(x)
        x = x * x_mask
        x = x.transpose(1, 2)
        output_prob = self.output_layer(x)
        return output_prob

    def forward(self, x, x_mask, dur_r, dur_hat, g=None):
        x = torch.detach(x)
        # if g is not None:
        #   g = torch.detach(g)
        #   x = x + self.cond(g)
        x = self.conv_1(x * x_mask)
        x = torch.relu(x)
        x = self.norm_1(x)
        x = self.conv_2(x * x_mask)
        x = torch.relu(x)
        x = self.norm_2(x)

        output_probs = []
        for dur in [dur_r, dur_hat]:
            output_prob = self.forward_probability(x, x_mask, dur, g)
            # Wrap in a list so the loss functions iterate once per output
            # instead of once per batch item.
            output_probs.append([output_prob])

        return output_probs
choiHkk commented 9 months ago

I think this method is working. After applying audio preprocessing, modifying the residual coupling layer, and adjusting the discriminator as I mentioned before, the training seems to be producing meaningful results.

I will continue the training and if I find a useful checkpoint, I will share it with you.

p0p4k commented 9 months ago

Very nice insight; I should have been more careful earlier. I think one more thing to fix here is that, when using SDP, we need to send the noise input of the SDP to the discriminator as well. Can you send a PR with your discriminator so I can merge it? Thanks a lot!

choiHkk commented 9 months ago

@p0p4k Of course. But I will make sure to modify it to be compatible with the existing functionality since there might be conflicts due to the various changes. After that, I will send a pull request.

According to the paper, no separate noise is fed directly to the discriminator. Did you mean that you want to experiment with a noise setup different from the one in the paper?

p0p4k commented 9 months ago

Ah, my bad. I was thinking about something else. Ignore my previous comment.

p0p4k commented 9 months ago

Also, to avoid code-breaking changes, just add it as dur_disc_2.

choiHkk commented 9 months ago

@p0p4k it's ok kkk.

Could I make the necessary adjustments after adding it to the config so that it can be reflected in the training process? I will check for conflicts based on the most recent branch and send a pull request as soon as possible.

p0p4k commented 9 months ago

Yes, do what you think is best. Thank you for your efforts.

choiHkk commented 9 months ago

@p0p4k I just sent a pull request. I have verified that both training and inference run correctly. One concern is that I did not include any changes to the requirements. However, if you need them, I will post the changes here.