I changed the `GAN` class as follows, introducing the R1 regularization term `loss_r1`. I also tried enlarging the model, but only saw `loss_dr` decrease much faster than `loss_df` and `loss_g`, so I don't think R1 regularization alone will help.
```python
import torch
from torch.nn.functional import softplus


class GAN(GenerativeModel):
    def __init__(
        self, generator: Generator, discriminator: T2TModule, gamma: float
    ) -> None:
        super().__init__(generator)
        self.discriminator = discriminator
        self.gamma = gamma  # R1 regularization strength

    def step(self, x: torch.Tensor) -> Losses:
        generator = self.generator
        discriminator = self.discriminator

        # Real samples must require gradients so the R1 penalty can
        # differentiate the discriminator's output w.r.t. its input.
        x.requires_grad_()
        n = len(x)

        # Discriminator losses on real and (detached) fake samples.
        discriminator.zero_grad()
        logits_dr = discriminator(x)
        loss_dr = softplus(-logits_dr).mean()
        fake = generator.sample(n)
        logits_df = discriminator(fake.detach())
        loss_df = softplus(logits_df).mean()

        # R1 penalty: squared gradient norm of the real logits w.r.t. x,
        # averaged over the batch. create_graph=True keeps the graph so the
        # penalty can be backpropagated into the discriminator's parameters.
        x_grad_real = torch.autograd.grad(logits_dr.sum(), x, create_graph=True)[0]
        loss_r1 = x_grad_real.square().sum() / n

        loss_d = loss_dr + loss_df + (self.gamma / 2) * loss_r1
        loss_d.backward()

        # Non-saturating generator loss on the non-detached fakes.
        generator.zero_grad()
        logits_dgf = discriminator(fake)
        loss_g = softplus(-logits_dgf).mean()
        loss_g.backward()

        losses = {
            "loss_dr": loss_dr.item(),
            "loss_df": loss_df.item(),
            "loss_r1": loss_r1.item(),
            "loss_g": loss_g.item(),
        }
        return losses
```
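
As a side note, the `x.requires_grad_()` call and `create_graph=True` are what make the R1 penalty differentiable with respect to the discriminator's parameters. Here is a minimal, self-contained toy example of the same pattern; the linear layer and random batch are made up for illustration and are not part of this repo:

```python
import torch

# Toy "discriminator": a single linear layer on 2-D inputs (hypothetical).
d = torch.nn.Linear(2, 1)
x = torch.randn(8, 2).requires_grad_()  # real batch must require grad for R1

logits = d(x)
# Gradient of the summed logits w.r.t. the inputs; create_graph=True keeps the
# graph so the penalty itself can be backpropagated into d's parameters.
(grad_x,) = torch.autograd.grad(logits.sum(), x, create_graph=True)
r1 = grad_x.square().sum() / len(x)  # batch mean of squared gradient norms
r1.backward()                        # d.weight.grad is now populated
print(r1.item(), d.weight.grad)
```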
The current GAN performance looks catastrophic. I tried some bigger models, but that didn't help. Perhaps you want to add a regularization method such as R1 regularization, which I find easy to implement and worth trying.
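
For reference, R1 regularization (Mescheder et al. 2018, "Which Training Methods for GANs do actually Converge?") penalizes the discriminator's gradient on real data only; with discriminator logits $D_\psi$, discriminator parameters $\psi$, real-data distribution $p_{\mathcal{D}}$, and regularization strength $\gamma$ (the `gamma` argument), the term added to the discriminator loss is

$$
R_1(\psi) = \frac{\gamma}{2}\,\mathbb{E}_{x \sim p_{\mathcal{D}}}\!\left[\lVert \nabla_x D_\psi(x) \rVert^2\right],
$$

which is the quantity that `(self.gamma / 2) * loss_r1` implements above, with the expectation replaced by a batch mean.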