dorarad / gansformer

Generative Adversarial Transformers
MIT License

Two typos in pytorch_version/training/loss.py #47

Open: harnvo opened this issue 1 year ago

harnvo commented 1 year ago
```python
loss_D_real = 0
if D_main:
    if self.d_loss == "logistic":
        loss_D_real = torch.nn.functional.softplus(-real_logits) # -log(sigmoid(real_logits))
    elif self.d_loss == "hinge":
        loss_D_real = torch.clamp(1.0 - real_logits, min = 0)
    elif self.d_loss == "wgan":
        # line 142: both typos are on this line (tf.square, wgan_epsilon)
        loss_D_real = -real_logits + tf.square(real_logits) * wgan_epsilon

    training_stats.report("Loss/D/loss", loss_D_gen + loss_D_real)
```

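In line 142 of loss.py, the `wgan` branch should instead read:

`loss_D_real = -real_logits + torch.square(real_logits) * self.wgan_epsilon`

The two typos are `tf.square`, a TensorFlow call left over in the PyTorch version (the PyTorch equivalent is `torch.square`), and the bare name `wgan_epsilon`, which should be the attribute `self.wgan_epsilon`.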
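For reference, here is a minimal sketch of the real-logit branch with both fixes applied, pulled out into a standalone function so it can be run in isolation. The function name `d_loss_real`, its signature, and the `wgan_epsilon=0.001` default are hypothetical: the default is only an illustrative value, not necessarily the repository's setting. In `loss.py` this logic lives inside the loss class and reads `self.d_loss` and `self.wgan_epsilon`.

```python
import torch
import torch.nn.functional as F


def d_loss_real(real_logits: torch.Tensor, d_loss: str = "logistic",
                wgan_epsilon: float = 0.001) -> torch.Tensor:
    """Hypothetical standalone version of the discriminator's real-logit loss.

    wgan_epsilon=0.001 is an assumed, illustrative default; in loss.py the
    value comes from self.wgan_epsilon.
    """
    if d_loss == "logistic":
        # softplus(-x) == -log(sigmoid(x))
        return F.softplus(-real_logits)
    if d_loss == "hinge":
        # Penalize real logits that fall below the margin of 1
        return torch.clamp(1.0 - real_logits, min=0)
    if d_loss == "wgan":
        # Corrected line: torch.square (not tf.square) and an explicit epsilon
        # weight on the squared-logit drift term
        return -real_logits + torch.square(real_logits) * wgan_epsilon
    raise ValueError(f"unknown d_loss: {d_loss!r}")
```

For example, `d_loss_real(torch.tensor([2.0, -1.0, 0.0]), "wgan")` evaluates `-x + 0.001 * x**2` elementwise, giving values close to `[-1.996, 1.001, 0.0]`.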