kumatheworld / fakeglyph

Fancy brand new letters with generative models

Stabilize GAN training #3

Closed · kumatheworld closed this 1 year ago

kumatheworld commented 1 year ago

The current GAN's performance looks catastrophic. I tried some bigger models, but that didn't help.

(attached image: samples)

Perhaps it is worth adding a regularization method such as R1 regularization, which looks easy to implement and worth trying.
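For reference, here is a minimal standalone sketch of the R1 penalty, R1 = E_x[ ||∇_x D(x)||² ] over real samples, which is added to the discriminator loss as (gamma / 2) * R1. The helper name r1_penalty and the toy discriminator below are placeholders for illustration, not code from this repo.

import torch
import torch.nn as nn


def r1_penalty(discriminator: nn.Module, real: torch.Tensor) -> torch.Tensor:
    """Mean squared gradient norm of the discriminator logits w.r.t. real inputs."""
    real = real.detach().requires_grad_(True)
    logits = discriminator(real)
    (grad,) = torch.autograd.grad(logits.sum(), real, create_graph=True)
    return grad.square().sum(dim=tuple(range(1, grad.ndim))).mean()


if __name__ == "__main__":
    d = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 1))  # toy discriminator
    x = torch.randn(4, 1, 28, 28)  # toy batch of "real" images
    print(r1_penalty(d, x))  # scalar; add (gamma / 2) * this to the D loss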

kumatheworld commented 1 year ago

I changed the class GAN as follows, introducing the R1 regularization term loss_r1. I tried enlarging the model again, but I only saw loss_dr decreasing much faster than loss_df and loss_g. I don't think R1 regularization alone will help.

class GAN(GenerativeModel):
    def __init__(
        self, generator: Generator, discriminator: T2TModule, gamma: float
    ) -> None:
        super().__init__(generator)
        self.discriminator = discriminator
        self.gamma = gamma  # R1 regularization weight

    def step(self, x: torch.Tensor) -> Losses:
        generator = self.generator
        discriminator = self.discriminator
        x.requires_grad_()  # needed so the R1 gradient w.r.t. the real batch exists
        n = len(x)

        # Discriminator loss on real samples (non-saturating formulation:
        # softplus(-logit) == -log sigmoid(logit)).
        discriminator.zero_grad()
        logits_dr = discriminator(x)
        loss_dr = softplus(-logits_dr).mean()

        # Discriminator loss on fake samples; detach so this term does not
        # backpropagate into the generator.
        fake = generator.sample(n)
        logits_df = discriminator(fake.detach())
        loss_df = softplus(logits_df).mean()

        # R1 penalty: squared gradient norm of the real logits w.r.t. the real
        # batch, averaged over the batch. create_graph=True so the penalty
        # itself can be backpropagated through.
        x_grad_real = torch.autograd.grad(logits_dr.sum(), x, create_graph=True)[0]
        loss_r1 = x_grad_real.square().sum() / n

        loss_d = loss_dr + loss_df + (self.gamma / 2) * loss_r1
        loss_d.backward()

        # Generator loss: non-saturating, computed on a fresh discriminator pass
        # over the (non-detached) fakes. Note that this backward pass also
        # accumulates gradients into the discriminator's parameters.
        generator.zero_grad()
        logits_dgf = discriminator(fake)
        loss_g = softplus(-logits_dgf).mean()
        loss_g.backward()

        losses = {
            "loss_dr": loss_dr.item(),
            "loss_df": loss_df.item(),
            "loss_r1": loss_r1.item(),
            "loss_g": loss_g.item(),
        }
        return losses
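As a side note on the loss formulation: the softplus terms above are the standard non-saturating GAN losses written directly in terms of logits, since softplus(-z) = -log σ(z) and softplus(z) = -log(1 - σ(z)). A quick numerical check of the identities:

import torch
from torch.nn.functional import softplus

z = torch.randn(5)  # arbitrary discriminator logits
assert torch.allclose(softplus(-z), -torch.log(torch.sigmoid(z)), atol=1e-6)
assert torch.allclose(softplus(z), -torch.log(1 - torch.sigmoid(z)), atol=1e-6)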