Zhendong-Wang / Diffusion-GAN

Official PyTorch implementation for paper: Diffusion-GAN: Training GANs with Diffusion
MIT License

Use Diffusion-GAN in Other GAN Architecture #27

Open RisabBiswas opened 1 year ago

RisabBiswas commented 1 year ago

Hello @Zhendong-Wang and Team,

First of all, great work, and thank you for sharing the code! I am trying to use Diffusion-GAN in a GAN architecture for image enhancement. Could you please help me by explaining how to apply the three Simple Plug-in steps from your README to the code below?

for epoch in range(num_epochs):
    for n_batch, (blur_batch, clean_batch) in enumerate(data_loader):

        real_data = clean_batch.float().cuda()
        noised_data = blur_batch.float().cuda()

        # 1. Train Discriminator
        # Generate fake data
        fake_data = generator(noised_data)

        # Reset gradients
        d_optimizer.zero_grad()

        # 1.1 Train on Real Data
        prediction_real = discriminator(real_data, noised_data)

        # Calculate error and backpropagate
        real_data_target = torch.ones_like(prediction_real)
        loss_real = loss1(prediction_real, real_data_target)

        # 1.2 Train on Fake Data, you would need to add one more component
        prediction_fake = discriminator(fake_data, noised_data)

        # Calculate error and backpropagate
        fake_data_target = torch.zeros_like(prediction_fake)
        loss_fake = loss1(prediction_fake, fake_data_target)

        loss_d = (loss_real + loss_fake)/2
        loss_d.backward(retain_graph=True)

        # 1.3 Update weights with gradients
        d_optimizer.step()

        # 2. Train Generator
        g_optimizer.zero_grad()

        # Score the already-generated fake data with the discriminator
        prediction = discriminator(fake_data, real_data)

        # Calculate error and backpropagate
        real_data_target = torch.ones_like(prediction)
        #import pdb; pdb.set_trace();

        loss_g1 = loss1(prediction, real_data_target)
        loss_g2 = loss1(fake_data, real_data)*500
        loss_g = loss_g1 + loss_g2

        loss_g.backward()

        # Update weights with gradients
        g_optimizer.step()

        # Log error
        logger.log(loss_d, loss_g, epoch, n_batch, num_batches)

        # Display Progress
        if (n_batch) % 100 == 0:
            display.clear_output(True)
            # Display Images
            test_images = vectors_to_images(generator(test_noise())).data.cpu()
            logger.log_images(test_images, num_test_samples, epoch, n_batch, num_batches);
            # Display status Logs
            logger.display_status(
                epoch, num_epochs, n_batch, num_batches,
                loss_d, loss_g, prediction_real, prediction_fake
            )
        # Model Checkpoints
        logger.save_models(generator, discriminator, epoch)
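
To make the question more concrete, here is a rough sketch of how I currently picture the plug-in inside my loop. The simple_diffuse helper, the betas schedule, and the T update rule are my own placeholders for your diffusion module and the adaptive-T step (I am not sure of the exact class and method names in this repo), and the loss1 reconstruction term and the conditioning inputs are carried over from my code above. Please correct anything I have misunderstood:

import random
import torch

def simple_diffuse(x, t, betas):
    # Hypothetical stand-in for the repo's diffusion module: forward-diffuse x for t
    # steps of Gaussian noise, i.e. q(x_t | x_0) = N(sqrt(a_bar_t) x_0, (1 - a_bar_t) I).
    alpha_bar = torch.cumprod(1.0 - betas, dim=0)[t]
    noise = torch.randn_like(x)
    return alpha_bar.sqrt() * x + (1.0 - alpha_bar).sqrt() * noise

betas = torch.linspace(1e-4, 2e-2, 1000).cuda()  # assumed noise schedule
T = 10                                           # current max diffusion step, adapted below

for epoch in range(num_epochs):
    for n_batch, (blur_batch, clean_batch) in enumerate(data_loader):
        real_data = clean_batch.float().cuda()
        noised_data = blur_batch.float().cuda()

        # 1. Train Discriminator
        fake_data = generator(noised_data)

        # Plug-in steps 1-2 (my understanding): sample a timestep and diffuse BOTH the
        # real and the fake images with the same schedule before the discriminator.
        # (One t per batch here for simplicity; the paper samples t per image.)
        t = random.randrange(T)
        real_diffused = simple_diffuse(real_data, t, betas)
        fake_diffused = simple_diffuse(fake_data.detach(), t, betas)

        d_optimizer.zero_grad()
        prediction_real = discriminator(real_diffused, noised_data)
        prediction_fake = discriminator(fake_diffused, noised_data)
        loss_d = (loss1(prediction_real, torch.ones_like(prediction_real)) +
                  loss1(prediction_fake, torch.zeros_like(prediction_fake))) / 2
        loss_d.backward()
        d_optimizer.step()

        # 2. Train Generator: the generator output is also diffused before it reaches
        # the discriminator; the diffusion is differentiable in x, so gradients still
        # flow back to the generator.
        g_optimizer.zero_grad()
        t = random.randrange(T)
        # Conditioning on the blurry input here, as in the D step (my loop above passed real_data)
        prediction = discriminator(simple_diffuse(fake_data, t, betas), noised_data)
        loss_g = loss1(prediction, torch.ones_like(prediction)) \
                 + loss1(fake_data, real_data) * 500   # my reconstruction term, unchanged
        loss_g.backward()
        g_optimizer.step()

        # Plug-in step 3 (my understanding): every few iterations adapt T from a
        # discriminator-overfitting signal, r_d = E[sign(D(real) - 0.5)], pushing T up
        # when r_d exceeds a target and down otherwise. The 0.5 / 0.6 values assume D
        # outputs probabilities; with raw logits the threshold would be 0.
        if n_batch % 4 == 0:
            r_d = torch.sign(prediction_real.detach() - 0.5).mean().item()
            T = int(min(max(T + (2 if r_d > 0.6 else -2), 1), len(betas)))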

Thank you so much :)

someonegirl commented 1 year ago

Excuse me, have you managed to use Diffusion-GAN successfully? If so, could you share your experience?

Sarah-2021-scu commented 5 months ago

@RisabBiswas, @someonegirl, Were you able to use Diffusion-GAN in other GAN architectures? Can you please share your experience?