AlexiaJM / RelativisticGAN

Code for replication of the paper "The relativistic discriminator: a key element missing from standard GAN"

add Relativism to cycleGAN #16

Closed Auth0rM0rgan closed 5 years ago

Auth0rM0rgan commented 5 years ago

Hey @AlexiaJM,

Great work on your paper! I am trying to add relativism to CycleGAN, but I'm a little confused about how to do it, since CycleGAN has 2 generators and 2 discriminators.

The generator loss in CycleGAN is calculated as follows:

        # GAN loss
        fake_B = G_AB(real_A)
        loss_GAN_AB = criterion_GAN(D_B(fake_B), valid)
        fake_A = G_BA(real_B)
        loss_GAN_BA = criterion_GAN(D_A(fake_A), valid)

        loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2  

Discriminator loss:

        #  Train Discriminator A
        optimizer_D_A.zero_grad()
        # Real loss
        loss_real = criterion_GAN(D_A(real_A), valid)
        # Fake loss (on batch of previously generated samples)
        fake_A_ = fake_A_buffer.push_and_pop(fake_A)
        loss_fake = criterion_GAN(D_A(fake_A_.detach()), fake)
        # Total loss
        loss_D_A = (loss_real + loss_fake) / 2
        loss_D_A.backward()
        optimizer_D_A.step()

        #  Train Discriminator B
        optimizer_D_B.zero_grad()
        # Real loss
        loss_real = criterion_GAN(D_B(real_B), valid)
        # Fake loss (on batch of previously generated samples)
        fake_B_ = fake_B_buffer.push_and_pop(fake_B)
        loss_fake = criterion_GAN(D_B(fake_B_.detach()), fake)
        # Total loss
        loss_D_B = (loss_real + loss_fake) / 2
        loss_D_B.backward()
        optimizer_D_B.step()

        loss_D = (loss_D_A + loss_D_B) / 2

and criterion_GAN is MSELoss().
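
As a side note, since criterion_GAN is nn.MSELoss() and valid is a tensor of ones, each of these terms is just a mean of squared differences, which is the form the relativistic losses below are written in; a quick check (the shapes here are illustrative):

    import torch
    import torch.nn as nn

    criterion_GAN = nn.MSELoss()
    d_out = torch.randn(4, 1, 16, 16)   # illustrative discriminator output
    valid = torch.ones_like(d_out)      # label tensor of ones

    # MSELoss against ones is just mean((output - 1)^2), the form used
    # for the relativistic losses below
    assert torch.allclose(criterion_GAN(d_out, valid),
                          torch.mean((d_out - valid) ** 2))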

I modified the code to add relativism, but I am not sure I did it correctly. Generator loss:

    loss_GAN_AB = (torch.mean((D_B(real_A) - torch.mean(D_B(fake_B)) + valid) ** 2) +
                   torch.mean((D_B(fake_B) - torch.mean(D_B(real_A)) - valid) ** 2)) / 2

    loss_GAN_BA = (torch.mean((D_A(real_B) - torch.mean(D_A(fake_A)) + valid) ** 2) +
                   torch.mean((D_A(fake_A) - torch.mean(D_A(real_B)) - valid) ** 2)) / 2

    loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

Discriminator loss:

    # Train Discriminator A
    optimizer_D_A.zero_grad()
    fake_A_ = fake_A_buffer.push_and_pop(fake_A)

    errD_A = (torch.mean((D_A(real_A) - torch.mean(D_A(fake_A_.detach())) - valid) ** 2) +
              torch.mean((D_A(fake_A_.detach()) - torch.mean(D_A(real_A)) + valid) ** 2)) / 2

    errD_A.backward()
    optimizer_D_A.step()

    # Train Discriminator B
    optimizer_D_B.zero_grad()
    fake_B_ = fake_B_buffer.push_and_pop(fake_B)

    errD_B = (torch.mean((D_B(real_B) - torch.mean(D_B(fake_B_.detach())) - valid) ** 2) +
              torch.mean((D_B(fake_B_.detach()) - torch.mean(D_B(real_B)) + valid) ** 2)) / 2

    errD_B.backward()
    optimizer_D_B.step()

    loss_D = (errD_A + errD_B) / 2

I would appreciate it if you could help me figure out how to add relativism to CycleGAN.

Thanks in advance!

AlexiaJM commented 5 years ago

Hi @arminXerror,

I'm assuming fake_A is a sample from B transformed into A, and fake_B is a sample from A transformed into B. You train D_A on real_A and fake_A, and D_B on real_B and fake_B.

So the losses are:

Training D

    loss_GAN_A = (torch.mean((D_A(real_A) - torch.mean(D_A(fake_A)) - 1) ** 2) +
                  torch.mean((torch.mean(D_A(real_A)) - D_A(fake_A) - 1) ** 2)) / 2
    loss_GAN_B = (torch.mean((D_B(real_B) - torch.mean(D_B(fake_B)) - 1) ** 2) +
                  torch.mean((torch.mean(D_B(real_B)) - D_B(fake_B) - 1) ** 2)) / 2
    loss_GAN = (loss_GAN_A + loss_GAN_B) / 2

Training G (non-saturating loss)

    loss_GAN_A = (torch.mean((D_A(fake_A) - torch.mean(D_A(real_A)) - 1) ** 2) +
                  torch.mean((torch.mean(D_A(fake_A)) - D_A(real_A) - 1) ** 2)) / 2
    loss_GAN_B = (torch.mean((D_B(fake_B) - torch.mean(D_B(real_B)) - 1) ** 2) +
                  torch.mean((torch.mean(D_B(fake_B)) - D_B(real_B) - 1) ** 2)) / 2
    loss_GAN = (loss_GAN_A + loss_GAN_B) / 2

Make sure to detach the fakes when training D, add the other G losses, and all these things; this is just to show you the loss functions.
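
To make the detach point concrete, here is a minimal sketch of one D_A update with the loss above, reusing the optimizer and replay-buffer names from the training loop earlier in this thread:

    optimizer_D_A.zero_grad()
    # Detach the buffered fakes so no gradient flows back into G during the D step
    fake_A_ = fake_A_buffer.push_and_pop(fake_A).detach()
    d_real = D_A(real_A)
    d_fake = D_A(fake_A_)
    # RaLSGAN D loss: real scores should sit 1 above the average fake score,
    # fake scores 1 below the average real score
    loss_D_A = (torch.mean((d_real - torch.mean(d_fake) - 1) ** 2) +
                torch.mean((d_fake - torch.mean(d_real) + 1) ** 2)) / 2
    loss_D_A.backward()
    optimizer_D_A.step()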

The loss is written more clearly in https://arxiv.org/pdf/1901.02474. You can also see there how to easily change it to the hinge loss if you want to try that later.
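
For instance, the hinge version (RaHingeGAN) only swaps the squared-error penalty for a hinge; a sketch for D_A in the same notation, where the / 2 just matches the averaging style used above:

    import torch.nn.functional as F

    # Training D (use detached fakes so no gradient reaches G):
    d_real = D_A(real_A)
    d_fake = D_A(fake_A_.detach())
    # Real should beat the average fake score by a margin of 1, and vice versa
    errD = (torch.mean(F.relu(1.0 - (d_real - torch.mean(d_fake)))) +
            torch.mean(F.relu(1.0 + (d_fake - torch.mean(d_real))))) / 2

    # Training G (recompute scores on non-detached fakes so gradients reach G):
    d_real = D_A(real_A)
    d_fake = D_A(fake_A)
    # Roles of real and fake are swapped
    errG = (torch.mean(F.relu(1.0 - (d_fake - torch.mean(d_real)))) +
            torch.mean(F.relu(1.0 + (d_real - torch.mean(d_fake))))) / 2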

Auth0rM0rgan commented 5 years ago

Hi @AlexiaJM ,

Thank you for the quick answer. Since there is a tanh() layer as the last layer in the generators, I need to use Relativistic average LSGAN or Relativistic average HingeGAN.

Is my modification correct for Relativistic average LSGAN?

G:

    loss_GAN_AB = (torch.mean((D_A(real_A) - torch.mean(D_A(fake_A)) + valid) ** 2) +
                   torch.mean((D_A(fake_A) - torch.mean(D_A(real_B)) - valid) ** 2)) / 2

    loss_GAN_BA = (torch.mean((D_B(real_B) - torch.mean(D_B(fake_B)) + valid) ** 2) +
                   torch.mean((D_B(fake_B) - torch.mean(D_B(real_B)) - valid) ** 2)) / 2

    loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

valid is a tensor full of ones:

    valid = torch.ones((real_A.size(0), *patch), requires_grad=False).to(device)
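
Side note: since valid is all ones and broadcasts against the discriminator's patch-shaped output, "+ valid" inside the squared term behaves exactly like the scalar "+ 1" used above; a quick check with illustrative shapes:

    import torch

    patch = (1, 16, 16)                  # illustrative PatchGAN output shape
    d_out = torch.randn(4, *patch)
    valid = torch.ones((4, *patch))

    # A ones tensor acts like the scalar 1 in these losses
    assert torch.allclose(torch.mean((d_out + valid) ** 2),
                          torch.mean((d_out + 1) ** 2))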

Thanks in advance!

AlexiaJM commented 5 years ago

You can use other loss functions; the generator's last layer has no influence on this. I have tanh in my G and I have used all sorts of loss functions.

I'm not sure why you want to switch between "+" and "-", but either way, both options work. You are correct except for a small typo:

    loss_GAN_AB = (torch.mean((D_A(real_A) - torch.mean(D_A(fake_A)) + valid) ** 2) +
                   torch.mean((D_A(fake_A) - torch.mean(D_A(real_A)) - valid) ** 2)) / 2

    loss_GAN_BA = (torch.mean((D_B(real_B) - torch.mean(D_B(fake_B)) + valid) ** 2) +
                   torch.mean((D_B(fake_B) - torch.mean(D_B(real_B)) - valid) ** 2)) / 2

    loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

Auth0rM0rgan commented 5 years ago

Thanks. In the README.md you are switching between "+" and "-", and I followed the order in which you wrote the loss function in the README :)


    # Generator loss (You may want to resample again from real and fake data)
    errG = (torch.mean((y_pred - torch.mean(y_pred_fake) + y) ** 2) +
            torch.mean((y_pred_fake - torch.mean(y_pred) - y) ** 2)) / 2
    errG.backward()
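
For completeness, the matching discriminator loss in the same y_pred / y_pred_fake notation; this is a sketch following the paper's RaLSGAN formulation rather than a verbatim README quote:

    # Discriminator loss: the same two terms with the sign of y flipped
    errD = (torch.mean((y_pred - torch.mean(y_pred_fake) - y) ** 2) +
            torch.mean((y_pred_fake - torch.mean(y_pred) + y) ** 2)) / 2
    errD.backward()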