junyanz / pytorch-CycleGAN-and-pix2pix

Image-to-Image Translation in PyTorch

Blurred results using cycleGAN on depthmaps #1313

Open ghost opened 3 years ago

ghost commented 3 years ago

Hello everybody,

I want to use CycleGAN on depthmaps for domain adaptation. We are currently training the CycleGAN with two datasets of 2000 images each, but our results are blurred.

[attached image: results]

Training configuration:

image_size = 128
input_channels = 1
output_channels = 1
norm = instance_norm
lr = 0.0002
beta_1 = 0.5
cycle_loss_weight = 10
batch_size = 1
res_blocks = 9
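For reference, this should roughly correspond to the official repo's factory functions. A hedged sketch (argument names taken from models/networks.py; worth verifying against your checkout):

# Hypothetical equivalent built with the official repo's helpers, not our code.
from models import networks

netG = networks.define_G(input_nc=1, output_nc=1, ngf=64, netG="resnet_9blocks",
                         norm="instance", use_dropout=False)
netD = networks.define_D(input_nc=1, ndf=64, netD="basic", norm="instance")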

Model

import torch
import torch.nn as nn
import torch.nn.functional as F


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.main = nn.Sequential(
            nn.Conv2d(1, 64, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(128, 256, 4, stride=2, padding=1),
            nn.InstanceNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(256, 512, 4, padding=1),
            nn.InstanceNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(512, 1, 4, padding=1),
        )

    def forward(self, x):
        x = self.main(x)
        x = F.avg_pool2d(x, x.size()[2:])
        x = torch.flatten(x, 1)
        return x
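One difference from the official repo worth flagging: the original 70x70 PatchGAN feeds the full patch map to the loss, while this discriminator averages it into a single scalar per image, which weakens the per-patch realism signal. A quick shape check (a diagnostic sketch, not part of the model file) makes the difference visible:

# Diagnostic sketch: inspect the patch map before it is averaged away.
if __name__ == "__main__":
    netD = Discriminator()
    x = torch.randn(1, 1, 128, 128)   # one-channel 128x128 depth map
    patch_map = netD.main(x)          # per-patch logits
    print(patch_map.shape)            # torch.Size([1, 1, 14, 14])
    print(netD(x).shape)              # torch.Size([1, 1]) after avg_pool2d + flatten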

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            # Initial convolution block
            nn.ReflectionPad2d(3),
            nn.Conv2d(1, 64, 7),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True),

            # Downsampling
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 256, 3, stride=2, padding=1),
            nn.InstanceNorm2d(256),
            nn.ReLU(inplace=True),

            # Residual blocks
            ResidualBlock(256),
            ResidualBlock(256),
            ResidualBlock(256),
            ResidualBlock(256),
            ResidualBlock(256),
            ResidualBlock(256),
            ResidualBlock(256),
            ResidualBlock(256),
            ResidualBlock(256),

            # Upsampling
            nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(128),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True),

            # Output layer
            nn.ReflectionPad2d(3),
            nn.Conv2d(64, 1, 7),
            nn.Tanh()
        )

    def forward(self, x):
        return self.main(x)

class ResidualBlock(nn.Module):
    def __init__(self, in_channels):
        super(ResidualBlock, self).__init__()

        self.res = nn.Sequential(nn.ReflectionPad2d(1),
                                 nn.Conv2d(in_channels, in_channels, 3),
                                 nn.InstanceNorm2d(in_channels),
                                 nn.ReLU(inplace=True),
                                 nn.ReflectionPad2d(1),
                                 nn.Conv2d(in_channels, in_channels, 3),
                                 nn.InstanceNorm2d(in_channels))

    def forward(self, x):
        return x + self.res(x)
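With all three modules defined, a minimal smoke test (hypothetical, for debugging only) confirms the generator preserves the input shape and that its Tanh output lies in [-1, 1], while ToTensor() inputs are in [0, 1]:

if __name__ == "__main__":
    netG = Generator()
    x = torch.rand(1, 1, 128, 128)         # ToTensor()-style input in [0, 1]
    with torch.no_grad():
        y = netG(x)
    print(y.shape)                         # torch.Size([1, 1, 128, 128])
    print(y.min().item(), y.max().item())  # bounded by [-1, 1] because of Tanh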

Train

import argparse
import itertools
import os
import random

import torch.backends.cudnn as cudnn
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.utils as vutils
from PIL import Image
from tqdm import tqdm
import matplotlib.pylab as plt

from cyclegan_pytorch import DecayLR
from cyclegan_pytorch import Discriminator
from cyclegan_pytorch import Generator
from cyclegan_pytorch import ImageDataset
from cyclegan_pytorch import ReplayBuffer
from cyclegan_pytorch import weights_init

parser = argparse.ArgumentParser(
    description="PyTorch implements Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks")
parser.add_argument("--dataroot", type=str, default="./data",
                    help="path to datasets. (default:./data)")
parser.add_argument("--dataset", type=str, default="depthmaps_white",
                    help="dataset name. (default:horse2zebra) "
                         "Option: [apple2orange, summer2winter_yosemite, horse2zebra, monet2photo, "
                         "cezanne2photo, ukiyoe2photo, vangogh2photo, maps, facades, selfie2anime, "
                         "iphone2dslr_flower, ae_photos]")
parser.add_argument("--epochs", default=200, type=int, metavar="N",
                    help="number of total epochs to run")
parser.add_argument("--decay_epochs", type=int, default=100,
                    help="epoch to start linearly decaying the learning rate to 0. (default:100)")
parser.add_argument("-b", "--batch-size", default=1, type=int, metavar="N",
                    help="mini-batch size (default: 1), this is the total "
                         "batch size of all GPUs on the current node when "
                         "using Data Parallel or Distributed Data Parallel")
parser.add_argument("--lr", type=float, default=0.0002,
                    help="learning rate. (default:0.0002)")
parser.add_argument("-p", "--print-freq", default=100, type=int, metavar="N",
                    help="print frequency. (default:100)")
parser.add_argument("--cuda", action="store_true", help="Enables cuda")
parser.add_argument("--netG_A2B", default="", help="path to netG_A2B (to continue training)")
parser.add_argument("--netG_B2A", default="", help="path to netG_B2A (to continue training)")
parser.add_argument("--netD_A", default="", help="path to netD_A (to continue training)")
parser.add_argument("--netD_B", default="", help="path to netD_B (to continue training)")
parser.add_argument("--image-size", type=int, default=128,
                    help="size of the data crop (squared assumed). (default:128)")
parser.add_argument("--outf", default="./outputs",
                    help="folder to output images. (default:./outputs)")
parser.add_argument("--manualSeed", type=int,
                    help="Seed for initializing training. (default:none)")

args = parser.parse_args()
print(args)

try:
    os.makedirs(args.outf)
except OSError:
    pass

try:
    os.makedirs("weights")
except OSError:
    pass

if args.manualSeed is None:
    args.manualSeed = random.randint(1, 10000)
print("Random Seed: ", args.manualSeed)
random.seed(args.manualSeed)
torch.manual_seed(args.manualSeed)

cudnn.benchmark = True

if torch.cuda.is_available() and not args.cuda:
    print("WARNING: You have a CUDA device, so you should probably run with --cuda")

Dataset

dataset = ImageDataset(root=os.path.join(args.dataroot, args.dataset),
                       transform=transforms.Compose([
                           transforms.Resize(int(args.image_size * 1.12), Image.BICUBIC),
                           transforms.RandomCrop(args.image_size),
                           transforms.RandomHorizontalFlip(),
                           transforms.ToTensor()]),
                           # transforms.Grayscale(1),
                           # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
                           # transforms.Normalize((0.5,), (0.5,))]),
                       unaligned=True)

dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True, pin_memory=True)
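Since Normalize is commented out above, real images stay in [0, 1] while both generators emit [-1, 1] through Tanh; that range mismatch alone can wash out results. A hedged sketch of a transform that keeps the ranges consistent for single-channel depth maps (assuming they load as 8-bit grayscale PIL images):

# Possible transform for 1-channel depth maps matching the Tanh range (a sketch;
# adjust if your depth maps are 16-bit or already single-channel tensors).
depth_transform = transforms.Compose([
    transforms.Grayscale(1),                                   # force one channel
    transforms.Resize(int(args.image_size * 1.12), Image.BICUBIC),
    transforms.RandomCrop(args.image_size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),                                     # [0, 1]
    transforms.Normalize((0.5,), (0.5,)),                      # [0, 1] -> [-1, 1], matching Tanh
])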

try:
    os.makedirs(os.path.join(args.outf, args.dataset, "Sim"))
    os.makedirs(os.path.join(args.outf, args.dataset, "Real"))
except OSError:
    pass

try:
    os.makedirs(os.path.join("weights", args.dataset))
except OSError:
    pass

device = torch.device("cuda:0" if args.cuda else "cpu")

Create model

netG_A2B = Generator().to(device)
netG_B2A = Generator().to(device)
netD_A = Discriminator().to(device)
netD_B = Discriminator().to(device)

netG_A2B.apply(weights_init)
netG_B2A.apply(weights_init)
netD_A.apply(weights_init)
netD_B.apply(weights_init)

if args.netG_A2B != "":
    netG_A2B.load_state_dict(torch.load(args.netG_A2B))
if args.netG_B2A != "":
    netG_B2A.load_state_dict(torch.load(args.netG_B2A))
if args.netD_A != "":
    netD_A.load_state_dict(torch.load(args.netD_A))
if args.netD_B != "":
    netD_B.load_state_dict(torch.load(args.netD_B))

Define loss functions and optimizers

cycle_loss = torch.nn.L1Loss().to(device)
identity_loss = torch.nn.L1Loss().to(device)
adversarial_loss = torch.nn.MSELoss().to(device)
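These criteria implement the least-squares CycleGAN objective with L1 cycle consistency (lambda = 10 in the loop below; the L1 identity term is computed but left out of the generator loss):

L_G = L_GAN(G_A2B, D_B) + L_GAN(G_B2A, D_A)
      + lambda * ( || G_B2A(G_A2B(a)) - a ||_1 + || G_A2B(G_B2A(b)) - b ||_1 )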

Optimizers

optimizer_G = torch.optim.Adam(itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()),
                               lr=args.lr, betas=(0.5, 0.999))
optimizer_D_A = torch.optim.Adam(netD_A.parameters(), lr=args.lr, betas=(0.5, 0.999))
optimizer_D_B = torch.optim.Adam(netD_B.parameters(), lr=args.lr, betas=(0.5, 0.999))

lr_lambda = DecayLR(args.epochs, 0, args.decay_epochs).step
lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer_G, lr_lambda=lr_lambda)
lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(optimizer_D_A, lr_lambda=lr_lambda)
lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(optimizer_D_B, lr_lambda=lr_lambda)
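DecayLR is imported from cyclegan_pytorch above; presumably it implements the standard CycleGAN schedule (constant LR, then linear decay to zero after decay_epochs). A hedged sketch of what its step is expected to compute:

# Hypothetical reimplementation of DecayLR for illustration only;
# the real class comes from cyclegan_pytorch.
class DecayLRSketch:
    def __init__(self, epochs, offset, decay_epochs):
        self.epochs = epochs
        self.offset = offset
        self.decay_epochs = decay_epochs

    def step(self, epoch):
        # multiplier stays 1.0 until decay_epochs, then falls linearly to 0 at `epochs`
        return 1.0 - max(0, epoch + self.offset - self.decay_epochs) / (self.epochs - self.decay_epochs)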

g_losses = []
d_losses = []

identity_losses = []
gan_losses = []
cycle_losses = []

fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()

for epoch in range(0, args.epochs):
    progress_bar = tqdm(enumerate(dataloader), total=len(dataloader))
    for i, data in progress_bar:
        # get batch data
        real_image_A = data["Sim"].to(device)
        real_image_B = data["Real"].to(device)
        batch_size = real_image_A.size(0)

        # real data label is 1, fake data label is 0.
        real_label = torch.full((batch_size, 1), 1, device=device, dtype=torch.float32)
        fake_label = torch.full((batch_size, 1), 0, device=device, dtype=torch.float32)

        ##############################################
        # (1) Update G network: Generators A2B and B2A
        ##############################################

        # Set G_A and G_B's gradients to zero
        optimizer_G.zero_grad()

        # Identity loss
        # G_B2A(A) should equal A if real A is fed
        identity_image_A = netG_B2A(real_image_A)
        loss_identity_A = identity_loss(identity_image_A, real_image_A) * 1.0
        # G_A2B(B) should equal B if real B is fed
        identity_image_B = netG_A2B(real_image_B)
        loss_identity_B = identity_loss(identity_image_B, real_image_B) * 1.0

        # GAN loss
        # GAN loss D_A(G_A(A))
        fake_image_A = netG_B2A(real_image_B)
        fake_output_A = netD_A(fake_image_A)
        loss_GAN_B2A = adversarial_loss(fake_output_A, real_label)
        # GAN loss D_B(G_B(B))
        fake_image_B = netG_A2B(real_image_A)
        fake_output_B = netD_B(fake_image_B)
        loss_GAN_A2B = adversarial_loss(fake_output_B, real_label)

        # Cycle loss
        recovered_image_A = netG_B2A(fake_image_B)
        loss_cycle_ABA = cycle_loss(recovered_image_A, real_image_A) * 10.0

        recovered_image_B = netG_A2B(fake_image_A)
        loss_cycle_BAB = cycle_loss(recovered_image_B, real_image_B) * 10.0

        # Combined loss and calculate gradients
        errG = loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB
        # errG = loss_identity_A + loss_identity_B + loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB

        # Calculate gradients for G_A and G_B
        errG.backward()
        # Update G_A and G_B's weights
        optimizer_G.step()

        ##############################################
        # (2) Update D network: Discriminator A
        ##############################################

        # Set D_A gradients to zero
        optimizer_D_A.zero_grad()

        # Real A image loss
        real_output_A = netD_A(real_image_A)
        errD_real_A = adversarial_loss(real_output_A, real_label)

        # Fake A image loss
        fake_image_A = fake_A_buffer.push_and_pop(fake_image_A)
        fake_output_A = netD_A(fake_image_A.detach())
        errD_fake_A = adversarial_loss(fake_output_A, fake_label)

        # Combined loss and calculate gradients
        errD_A = (errD_real_A + errD_fake_A) / 2

        # Calculate gradients for D_A
        errD_A.backward()
        # Update D_A weights
        optimizer_D_A.step()

        ##############################################
        # (3) Update D network: Discriminator B
        ##############################################

        # Set D_B gradients to zero
        optimizer_D_B.zero_grad()

        # Real B image loss
        real_output_B = netD_B(real_image_B)
        errD_real_B = adversarial_loss(real_output_B, real_label)

        # Fake B image loss
        fake_image_B = fake_B_buffer.push_and_pop(fake_image_B)
        fake_output_B = netD_B(fake_image_B.detach())
        errD_fake_B = adversarial_loss(fake_output_B, fake_label)

        # Combined loss and calculate gradients
        errD_B = (errD_real_B + errD_fake_B) / 2

        # Calculate gradients for D_B
        errD_B.backward()
        # Update D_B weights
        optimizer_D_B.step()

        progress_bar.set_description(
            f"[{epoch}/{args.epochs - 1}][{i}/{len(dataloader) - 1}] "
            f"Loss_D: {(errD_A + errD_B).item():.4f} "
            f"Loss_G: {errG.item():.4f} "
            f"Loss_G_identity: {(loss_identity_A + loss_identity_B).item():.4f} "
            f"loss_G_GAN: {(loss_GAN_A2B + loss_GAN_B2A).item():.4f} "
            f"loss_G_cycle: {(loss_cycle_ABA + loss_cycle_BAB).item():.4f}")

        if i % args.print_freq == 0:
            vutils.save_image(real_image_A,
                              f"{args.outf}/{args.dataset}/Sim/real_samples.png",
                              normalize=True)
            vutils.save_image(real_image_B,
                              f"{args.outf}/{args.dataset}/Real/real_samples.png",
                              normalize=True)
            # print('Real Image A Shape:', real_image_A.shape)
            # print('Real Image B Shape:', real_image_B.shape)

            fake_image_A = 0.5 * (netG_B2A(real_image_B).data + 1.0)
            fake_image_B = 0.5 * (netG_A2B(real_image_A).data + 1.0)

            # print('Fake Image A Shape:', fake_image_A.shape)
            # print('Fake Image B Shape:', fake_image_B.shape)

            vutils.save_image(fake_image_A.detach(),
                              f"{args.outf}/{args.dataset}/Sim/fake_samples.png",
                              normalize=True)
            vutils.save_image(fake_image_B.detach(),
                              f"{args.outf}/{args.dataset}/Real/fake_samples.png",
                              normalize=True)

            transformed_image_A = 0.5 * (netG_B2A(fake_image_B).data + 1.0)
            transformed_image_B = 0.5 * (netG_A2B(fake_image_A).data + 1.0)

            # print('Transformed Image A Shape:', transformed_image_A.shape)
            # print('Transformed Image B Shape:', transformed_image_B.shape)

            vutils.save_image(transformed_image_A.detach(),
                              f"{args.outf}/{args.dataset}/Sim/transformed_samples.png",
                              normalize=True)
            vutils.save_image(transformed_image_B.detach(),
                              f"{args.outf}/{args.dataset}/Real/transformed_samples.png",
                              normalize=True)

            # merge results
            new_img = Image.new('RGB', (args.image_size * 3, args.image_size * 2), 'white')
            real_sample_A = Image.open(f"{args.outf}/{args.dataset}/Sim/real_samples.png")
            fake_sample_B = Image.open(f"{args.outf}/{args.dataset}/Real/fake_samples.png")
            transformed_image_A = Image.open(f"{args.outf}/{args.dataset}/Sim/transformed_samples.png")
            transformed_image_B = Image.open(f"{args.outf}/{args.dataset}/Real/transformed_samples.png")
            real_sample_B = Image.open(f"{args.outf}/{args.dataset}/Real/real_samples.png")
            fake_sample_A = Image.open(f"{args.outf}/{args.dataset}/Sim/fake_samples.png")
            new_img.paste(real_sample_A, (0, 0, args.image_size, args.image_size))
            new_img.paste(fake_sample_B, (args.image_size, 0, args.image_size * 2, args.image_size))
            new_img.paste(transformed_image_A, (args.image_size * 2, 0, args.image_size * 3, args.image_size))
            new_img.paste(real_sample_B, (0, args.image_size, args.image_size, args.image_size * 2))
            new_img.paste(fake_sample_A, (args.image_size, args.image_size, args.image_size * 2, args.image_size * 2))
            new_img.paste(transformed_image_B, (args.image_size * 2, args.image_size, args.image_size * 3, args.image_size * 2))
            new_img.save(f"{args.outf}/{args.dataset}/results_at_epoch_{epoch}_idx_{i}.png")

    # do check pointing
    torch.save(netG_A2B.state_dict(), f"weights/{args.dataset}/netG_A2B_epoch_{epoch}.pth")
    torch.save(netG_B2A.state_dict(), f"weights/{args.dataset}/netG_B2A_epoch_{epoch}.pth")
    torch.save(netD_A.state_dict(), f"weights/{args.dataset}/netD_A_epoch_{epoch}.pth")
    torch.save(netD_B.state_dict(), f"weights/{args.dataset}/netD_B_epoch_{epoch}.pth")

    # Update learning rates
    lr_scheduler_G.step()
    lr_scheduler_D_A.step()
    lr_scheduler_D_B.step()

Save the final checkpoints

torch.save(netG_A2B.state_dict(), f"weights/{args.dataset}/netG_A2B.pth")
torch.save(netG_B2A.state_dict(), f"weights/{args.dataset}/netG_B2A.pth")
torch.save(netD_A.state_dict(), f"weights/{args.dataset}/netD_A.pth")
torch.save(netD_B.state_dict(), f"weights/{args.dataset}/netD_B.pth")
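For completeness, reloading the final generator for inference might look like this (a sketch; map_location matters if you train on GPU and evaluate on CPU):

netG_A2B = Generator().to(device)
netG_A2B.load_state_dict(torch.load(f"weights/{args.dataset}/netG_A2B.pth", map_location=device))
netG_A2B.eval()
with torch.no_grad():
    fake_B = netG_A2B(real_image_A)   # Tanh output in [-1, 1]
    fake_B = 0.5 * (fake_B + 1.0)     # back to [0, 1] for saving/visualization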

junyanz commented 2 years ago

It seems that you are using custom code. Does the original code work for your datasets?