UBCDingXin / improved_CcGAN

Continuous Conditional Generative Adversarial Networks (CcGAN)
https://arxiv.org/abs/2011.07466
MIT License

How to use #12

Foundsheep opened this issue 6 months ago

Foundsheep commented 6 months ago

Hi there,

Thanks for your work in generative AI. I'm currently trying to implement CcGAN and am just wondering: is there any way to train the model easily?

I can see there are some Python files that receive arguments through shell scripts, but the repo lacks a requirements.txt and, as far as I can tell, needs some configuration to run.

It'd be much appreciated if you could provide a way to use your model and see how it works with other custom datasets.

Thanks in advance.

Foundsheep commented 6 months ago

By the way, my local machine runs Windows 11.

ashamy97 commented 4 months ago

Hey, I am stuck on the same issue. Did you figure it out by any chance?

Foundsheep commented 4 months ago

I managed to run the shell script and used my custom datasets.

What is the issue you're stuck at? @ashamy97

ashamy97 commented 4 months ago

I am just stuck on what to do once I have cloned the repository. I am using a Jupyter notebook. Can you tell me what steps you followed? I have my own custom dataset that I want to train the CcGAN on, and instead of one conditional input, I have two. So I am not sure what I need to change or how to run it. Thank you so much for your help.

Foundsheep commented 4 months ago
  1. Understand the structure of the repository
    • Cell-200, RC-49, SteeringAngle... these top-level folders share a similar structure underneath, by which the authors organized the code by image size and by model (if you read the paper, you will understand why it is structured this way; I guess they simply based the layout on their experiments).
    • So 64x64 or 128x128 is the image size used in training, and CcGAN, CcGAN-improved and cGAN-concat are the models to be trained.
  2. Find the run_train.sh in the scripts folder of the dataset and the model you want to use
    • For example, I used the RC-49_128x128.h5 dataset (you can download it easily through the link in README.md) and CcGAN-improved, so I went for the run_train.sh script under the .../RC-49/RC-49_128x128/scripts folder.
    • You need to choose the dataset and the model before training because the preprocessing code is customized for a specific dataset and model.
  3. Prepare the dataset, if you want to use your own (see the .h5 sketch after this list)
  4. Modify run_train.sh
    • In particular, you will probably need to change ROOT_PATH and DATA_PATH to make the shell script run.
    • Also, the python main.py bit doesn't work as it stands, since main.py is not on the same level but one level up (you will see what I mean when you run the script and hit a file-not-found error).
    • Comment out unnecessary parts of run_train.sh to speed up training.
      • For example, I did it like below:
        
## Path
ROOT_PATH="../../"
DATA_PATH="/c/Users/msi/Desktop/workspace/001_practice/improved_CcGAN/RC-49/RC-49_128x128/CcGAN-improved/scripts/datasets"
EVAL_PATH="/c/Users/msi/Desktop/workspace/001_practice/improved_CcGAN/RC-49/RC-49_128x128/CcGAN-improved/output/eval_models"

SEED=2020
NUM_WORKERS=0
MIN_LABEL=0
MAX_LABEL=360

# added
MIN_LABEL_SCALE=0.0
MAX_LABEL_SCALE=1.0

IMG_SIZE=128
MAX_N_IMG_PER_LABEL=25
MAX_N_IMG_PER_LABEL_AFTER_REPLICA=0

NITERS=15000
BATCH_SIZE_G=36
BATCH_SIZE_D=36
NUM_D_STEPS=2
SIGMA=-1.0
KAPPA=-2.0
LR_G=1e-4
LR_D=1e-4
GAN_ARCH="SAGAN"
LOSS_TYPE="hinge"

NUM_EVAL_LABELS=-1
NFAKE_PER_LABEL=200
SAMP_BATCH_SIZE=1000
FID_RADIUS=0
FID_NUM_CENTERS=-1

python pretrain_AE.py \
    --root_path $ROOT_PATH --data_path $DATA_PATH --seed $SEED --num_workers $NUM_WORKERS \
    --dim_bottleneck 512 --epochs 200 --resume_epoch 0 \
    --batch_size_train 256 --batch_size_valid 10 \
    --base_lr 1e-3 --lr_decay_epochs 50 --lr_decay_factor 0.1 \
    --lambda_sparsity 0 --weight_dacay 1e-4 \
    --img_size $IMG_SIZE --min_label $MIN_LABEL --max_label $MAX_LABEL \
    2>&1 | tee output_AE.txt

python pretrain_CNN_class.py \
    --root_path $ROOT_PATH --data_path $DATA_PATH --seed $SEED --num_workers $NUM_WORKERS \
    --CNN ResNet34_class \
    --epochs 200 --batch_size_train 256 --batch_size_valid 10 \
    --base_lr 0.01 --weight_dacay 1e-4 \
    --img_size $IMG_SIZE --min_label $MIN_LABEL --max_label $MAX_LABEL \
    2>&1 | tee output_CNN_class.txt

python pretrain_CNN_regre.py \
    --root_path $ROOT_PATH --data_path $DATA_PATH --seed $SEED --num_workers $NUM_WORKERS \
    --CNN ResNet34_regre \
    --epochs 200 --batch_size_train 256 --batch_size_valid 10 \
    --base_lr 0.01 --weight_dacay 1e-4 \
    --img_size $IMG_SIZE --min_label $MIN_LABEL --max_label $MAX_LABEL \
    2>&1 | tee output_CNN_regre.txt

GAN="CcGAN"
DIM_GAN=256
DIM_EMBED=128
resume_niters_gan=2800
python ../main.py \
    --root_path $ROOT_PATH --data_path $DATA_PATH --eval_ckpt_path $EVAL_PATH --seed $SEED --num_workers $NUM_WORKERS \
    --min_label $MIN_LABEL --max_label $MAX_LABEL --img_size $IMG_SIZE \
    --min_label_scale $MIN_LABEL_SCALE --max_label_scale $MAX_LABEL_SCALE \
    --max_num_img_per_label $MAX_N_IMG_PER_LABEL --max_num_img_per_label_after_replica $MAX_N_IMG_PER_LABEL_AFTER_REPLICA \
    --GAN $GAN --GAN_arch $GAN_ARCH --niters_gan $NITERS --resume_niters_gan $resume_niters_gan --loss_type_gan $LOSS_TYPE \
    --save_niters_freq 200 --visualize_freq 200 \
    --batch_size_disc $BATCH_SIZE_D --batch_size_gene $BATCH_SIZE_G --num_D_steps $NUM_D_STEPS \
    --lr_g $LR_G --lr_d $LR_D --dim_gan $DIM_GAN --dim_embed $DIM_EMBED \
    --kernel_sigma $SIGMA --threshold_type soft --kappa $KAPPA \
    --gan_DiffAugment --gan_DiffAugment_policy color,translation,cutout \
    --visualize_fake_images \
    --comp_FID --samp_batch_size $SAMP_BATCH_SIZE --FID_radius $FID_RADIUS --FID_num_centers $FID_NUM_CENTERS \
    --num_eval_labels $NUM_EVAL_LABELS --nfake_per_label $NFAKE_PER_LABEL \
    --dump_fake_for_NIQE \
    2>&1 | tee output_CcGAN_30K.txt

GAN="cGAN"
DIM_GAN=128
resume_niters_gan=0
python main.py \
    --root_path $ROOT_PATH --data_path $DATA_PATH --eval_ckpt_path $EVAL_PATH --seed $SEED --num_workers $NUM_WORKERS \
    --min_label $MIN_LABEL --max_label $MAX_LABEL --img_size $IMG_SIZE \
    --max_num_img_per_label $MAX_N_IMG_PER_LABEL --max_num_img_per_label_after_replica $MAX_N_IMG_PER_LABEL_AFTER_REPLICA \
    --GAN $GAN --GAN_arch $GAN_ARCH --cGAN_num_classes 150 --niters_gan $NITERS --resume_niters_gan $resume_niters_gan --loss_type_gan $LOSS_TYPE \
    --save_niters_freq 2000 --visualize_freq 1000 \
    --batch_size_disc $BATCH_SIZE_D --batch_size_gene $BATCH_SIZE_G --num_D_steps $NUM_D_STEPS \
    --lr_g $LR_G --lr_d $LR_D --dim_gan $DIM_GAN \
    --gan_DiffAugment --gan_DiffAugment_policy color,translation,cutout \
    --visualize_fake_images \
    --comp_FID --samp_batch_size $SAMP_BATCH_SIZE --FID_radius $FID_RADIUS --FID_num_centers $FID_NUM_CENTERS \
    --num_eval_labels $NUM_EVAL_LABELS --nfake_per_label $NFAKE_PER_LABEL \
    --dump_fake_for_NIQE \
    2>&1 | tee output_cGAN_150classes_30K.txt

GAN="cGAN-concat"
DIM_GAN=128
resume_niters_gan=0
python main.py \
    --root_path $ROOT_PATH --data_path $DATA_PATH --eval_ckpt_path $EVAL_PATH --seed $SEED --num_workers $NUM_WORKERS \
    --min_label $MIN_LABEL --max_label $MAX_LABEL --img_size $IMG_SIZE \
    --max_num_img_per_label $MAX_N_IMG_PER_LABEL --max_num_img_per_label_after_replica $MAX_N_IMG_PER_LABEL_AFTER_REPLICA \
    --GAN $GAN --GAN_arch $GAN_ARCH --niters_gan $NITERS --resume_niters_gan $resume_niters_gan --loss_type_gan $LOSS_TYPE \
    --save_niters_freq 2000 --visualize_freq 1000 \
    --batch_size_disc $BATCH_SIZE_D --batch_size_gene $BATCH_SIZE_G --num_D_steps $NUM_D_STEPS \
    --lr_g $LR_G --lr_d $LR_D --dim_gan $DIM_GAN \
    --gan_DiffAugment --gan_DiffAugment_policy color,translation,cutout \
    --visualize_fake_images \
    --comp_FID --samp_batch_size $SAMP_BATCH_SIZE --FID_radius $FID_RADIUS --FID_num_centers $FID_NUM_CENTERS \
    --num_eval_labels $NUM_EVAL_LABELS --nfake_per_label $NFAKE_PER_LABEL \
    --dump_fake_for_NIQE \
    2>&1 | tee output_cGAN-concat_30K.txt



5. Modify the code if needed, especially if you use your own custom dataset
   - I slightly changed the input condition, so I had to modify `main.py` and `train_ccgan.py`
6. Check the generated results, which will appear in the `output` folder
7. If you're using a Jupyter notebook, set everything up as described above and invoke the shell script in a notebook cell like `!bash run_train.sh` (it's a shell script, so `!python run_train.sh` won't work; on Windows I ran it through Git Bash, which is why the paths above look like /c/Users/...)
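
As for step 3: the data loader in main.py (pasted further down) reads three keys from a single .h5 file: images, labels and indx_train. Here is a minimal sketch of packing a custom dataset that way; the file name, shapes and the two-condition layout are my assumptions read off that loader, not an official spec:

import h5py
import numpy as np

N, img_size = 1000, 128
# uint8 images in channels-first (N, 3, H, W) layout, values in [0, 255]
images = np.random.randint(0, 256, size=(N, 3, img_size, img_size), dtype=np.uint8)
# one column per condition, e.g. an angle in [0, 360) and a scale in [0, 1];
# use shape (N,) instead if you only have a single condition
labels = np.stack([np.random.uniform(0, 360, N), np.random.uniform(0, 1, N)], axis=1)
indx_train = np.arange(N)  # indices selected when --data_split train is used

with h5py.File("my_dataset_{0}x{0}.h5".format(img_size), "w") as hf:
    hf.create_dataset("images", data=images, dtype="uint8")
    hf.create_dataset("labels", data=labels)
    hf.create_dataset("indx_train", data=indx_train)

Put the file under DATA_PATH and point the data_filename line in main.py at it.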

Hope this helps.
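
P.S. There is no requirements.txt; judging from the imports across the scripts (my guess, not an official list), an environment with `pip install torch torchvision numpy h5py matplotlib tqdm pillow` covered everything I needed.
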
ashamy97 commented 4 months ago

Thank you so much! This is very helpful. So the dataset is loaded from main.py, correct? Specifically lines 70-78? And just curious, do you know how I would pass in two conditional inputs instead of one? Do I need to pass the second input into the CNN to create an embedding and all of that (just like the first conditional input)?

Foundsheep commented 4 months ago

Here are the files I modified to handle two conditions (note the NUM_CONDITIONS = 2 constant near the top of each file):

ResNet_embed.py

'''
ResNet-based model to map an image from pixel space to a feature space.
Needs to be pretrained on the dataset.

if isometric_map = True, there is an extra step (self.classifier_1 = nn.Linear(512, 32*32*3)) to increase the dimension of the feature map from 512 to 32*32*3. This option is for density-ratio estimation in feature space.

code is based on
@article{
zhang2018mixup,
title={mixup: Beyond Empirical Risk Minimization},
author={Hongyi Zhang, Moustapha Cisse, Yann N. Dauphin, David Lopez-Paz},
journal={International Conference on Learning Representations},
year={2018},
url={https://openreview.net/forum?id=r1Ddp1-Rb},
}
'''

import torch
import torch.nn as nn
import torch.nn.functional as F

NC = 3
IMG_SIZE = 128
DIM_EMBED = 128

NUM_CONDITIONS = 2

#------------------------------------------------------------------------------
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion*planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

class ResNet_embed(nn.Module):
    def __init__(self, block, num_blocks, nc=NC, dim_embed=DIM_EMBED):
        super(ResNet_embed, self).__init__()
        self.in_planes = 64

        self.main = nn.Sequential(
            nn.Conv2d(nc, 64, kernel_size=3, stride=1, padding=1, bias=False),  # h=h
            # nn.Conv2d(nc, 64, kernel_size=4, stride=2, padding=1, bias=False),  # h=h/2
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2,2), #h=h/2 64
            # self._make_layer(block, 64, num_blocks[0], stride=1),  # h=h
            self._make_layer(block, 64, num_blocks[0], stride=2),  # h=h/2 32
            self._make_layer(block, 128, num_blocks[1], stride=2), # h=h/2 16
            self._make_layer(block, 256, num_blocks[2], stride=2), # h=h/2 8
            self._make_layer(block, 512, num_blocks[3], stride=2), # h=h/2 4
            # nn.AvgPool2d(kernel_size=4)
            nn.AdaptiveAvgPool2d((1, 1))
        )

        self.x2h_res = nn.Sequential(
            nn.Linear(512, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),

            nn.Linear(512, dim_embed),
            nn.BatchNorm1d(dim_embed),
            nn.ReLU(),
        )

        self.h2y = nn.Sequential(
            nn.Linear(dim_embed, NUM_CONDITIONS),
            nn.ReLU()
        )

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):

        features = self.main(x)
        features = features.view(features.size(0), -1)
        features = self.x2h_res(features)
        out = self.h2y(features)

        return out, features

def ResNet18_embed(dim_embed=DIM_EMBED):
    return ResNet_embed(BasicBlock, [2,2,2,2], dim_embed=dim_embed)

def ResNet34_embed(dim_embed=DIM_EMBED):
    return ResNet_embed(BasicBlock, [3,4,6,3], dim_embed=dim_embed)

def ResNet50_embed(dim_embed=DIM_EMBED):
    return ResNet_embed(Bottleneck, [3,4,6,3], dim_embed=dim_embed)

#------------------------------------------------------------------------------
# map labels to the embedding space
class model_y2h(nn.Module):
    def __init__(self, dim_embed=DIM_EMBED):
        super(model_y2h, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(NUM_CONDITIONS, dim_embed),
            # nn.BatchNorm1d(dim_embed),
            nn.GroupNorm(8, dim_embed),
            nn.ReLU(),

            nn.Linear(dim_embed, dim_embed),
            # nn.BatchNorm1d(dim_embed),
            nn.GroupNorm(8, dim_embed),
            nn.ReLU(),

            nn.Linear(dim_embed, dim_embed),
            # nn.BatchNorm1d(dim_embed),
            nn.GroupNorm(8, dim_embed),
            nn.ReLU(),

            nn.Linear(dim_embed, dim_embed),
            # nn.BatchNorm1d(dim_embed),
            nn.GroupNorm(8, dim_embed),
            nn.ReLU(),

            nn.Linear(dim_embed, dim_embed),
            nn.ReLU()
        )

    def forward(self, y):
        # y = y.view(-1, 1) +1e-8
        y = y + 1e-8
        # y = torch.exp(y.view(-1, 1))
        return self.main(y)

if __name__ == "__main__":
    net = ResNet34_embed(dim_embed=128).cuda()
    x = torch.randn(16,NC,IMG_SIZE,IMG_SIZE).cuda()
    out, features = net(x)
    print(out.size())
    print(features.size())

    net_y2h = model_y2h().cuda()
    y_hat = net_y2h(out)
    print(f"{y_hat.size() = }")

train_net_for_label_embed.py


import torch
import torch.nn as nn
from torchvision.utils import save_image
import numpy as np
import os
import timeit
from PIL import Image

NUM_CONDITIONS = 2

#-------------------------------------------------------------
def train_net_embed(net, net_name, trainloader, testloader, epochs=200, resume_epoch = 0, lr_base=0.01, lr_decay_factor=0.1, lr_decay_epochs=[80, 140], weight_decay=1e-4, path_to_ckpt = None):

    ''' learning rate decay '''
    def adjust_learning_rate_1(optimizer, epoch):
        """decrease the learning rate """
        lr = lr_base

        num_decays = len(lr_decay_epochs)
        for decay_i in range(num_decays):
            if epoch >= lr_decay_epochs[decay_i]:
                lr = lr * lr_decay_factor
            #end if epoch
        #end for decay_i
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    net = net.cuda()
    criterion = nn.MSELoss()
    optimizer = torch.optim.SGD(net.parameters(), lr = lr_base, momentum= 0.9, weight_decay=weight_decay)

    # resume training; load checkpoint
    if path_to_ckpt is not None and resume_epoch>0:
        save_file = path_to_ckpt + "/embed_x2y_ckpt_in_train/embed_x2y_checkpoint_epoch_{}.pth".format(resume_epoch)
        checkpoint = torch.load(save_file)
        net.load_state_dict(checkpoint['net_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        torch.set_rng_state(checkpoint['rng_state'])
    #end if

    start_tmp = timeit.default_timer()
    for epoch in range(resume_epoch, epochs):
        net.train()
        train_loss = 0
        adjust_learning_rate_1(optimizer, epoch)
        for _, (batch_train_images, batch_train_labels) in enumerate(trainloader):

            # batch_train_images = nn.functional.interpolate(batch_train_images, size = (299,299), scale_factor=None, mode='bilinear', align_corners=False)

            batch_train_images = batch_train_images.type(torch.float).cuda()
            batch_train_labels = batch_train_labels.type(torch.float).view(-1,NUM_CONDITIONS).cuda()

            #Forward pass
            outputs, _ = net(batch_train_images)
            loss = criterion(outputs, batch_train_labels)

            #backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.cpu().item()
        #end for batch_idx
        train_loss = train_loss / len(trainloader)

        if testloader is None:
            print('Train net_x2y for embedding: [epoch %d/%d] train_loss:%f Time:%.4f' % (epoch+1, epochs, train_loss, timeit.default_timer()-start_tmp))
        else:
            net.eval()  # eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance)
            with torch.no_grad():
                test_loss = 0
                for batch_test_images, batch_test_labels in testloader:
                    batch_test_images = batch_test_images.type(torch.float).cuda()
                    batch_test_labels = batch_test_labels.type(torch.float).view(-1,NUM_CONDITIONS).cuda()
                    outputs,_ = net(batch_test_images)
                    loss = criterion(outputs, batch_test_labels)
                    test_loss += loss.cpu().item()
                test_loss = test_loss/len(testloader)

                print('Train net_x2y for label embedding: [epoch %d/%d] train_loss:%f test_loss:%f Time:%.4f' % (epoch+1, epochs, train_loss, test_loss, timeit.default_timer()-start_tmp))

        #save checkpoint
        if path_to_ckpt is not None and (((epoch+1) % 50 == 0) or (epoch+1==epochs)):
            save_file = path_to_ckpt + "/embed_x2y_ckpt_in_train/embed_x2y_checkpoint_epoch_{}.pth".format(epoch+1)
            os.makedirs(os.path.dirname(save_file), exist_ok=True)
            torch.save({
                    'epoch': epoch,
                    'net_state_dict': net.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'rng_state': torch.get_rng_state()
            }, save_file)
    #end for epoch

    return net

###################################################################################
class label_dataset(torch.utils.data.Dataset):
    def __init__(self, labels):
        super(label_dataset, self).__init__()

        self.labels = labels
        self.n_samples = len(self.labels)

    def __getitem__(self, index):

        y = self.labels[index]
        return y

    def __len__(self):
        return self.n_samples

def train_net_y2h(unique_labels_norm, net_y2h, net_embed, epochs=500, lr_base=0.01, lr_decay_factor=0.1, lr_decay_epochs=[150, 250, 350], weight_decay=1e-4, batch_size=128):
    '''
    unique_labels_norm: an array of normalized unique labels
    '''

    ''' learning rate decay '''
    def adjust_learning_rate_2(optimizer, epoch):
        """decrease the learning rate """
        lr = lr_base

        num_decays = len(lr_decay_epochs)
        for decay_i in range(num_decays):
            if epoch >= lr_decay_epochs[decay_i]:
                lr = lr * lr_decay_factor
            #end if epoch
        #end for decay_i
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    # assume unique_labels_norm.shape == (B, NUM_CONDITIONS)
    assert np.max(unique_labels_norm)<=1 and np.min(unique_labels_norm)>=0
    trainset = label_dataset(unique_labels_norm)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)

    net_embed.eval()
    net_h2y=net_embed.module.h2y #convert embedding labels to original labels
    optimizer_y2h = torch.optim.SGD(net_y2h.parameters(), lr = lr_base, momentum= 0.9, weight_decay=weight_decay)

    start_tmp = timeit.default_timer()
    for epoch in range(epochs):
        net_y2h.train()
        train_loss = 0
        adjust_learning_rate_2(optimizer_y2h, epoch)
        for _, batch_labels in enumerate(trainloader):

            batch_labels = batch_labels.type(torch.float).view(-1,NUM_CONDITIONS).cuda()

            # generate noises which will be added to labels
            batch_size_curr = len(batch_labels)
            batch_gamma = np.random.normal(0, 0.2, (batch_size_curr, NUM_CONDITIONS))
            batch_gamma = torch.from_numpy(batch_gamma).view(-1,NUM_CONDITIONS).type(torch.float).cuda()

            # add noise to labels
            batch_labels_noise = torch.clamp(batch_labels+batch_gamma, 0.0, 1.0)

            #Forward pass
            batch_hiddens_noise = net_y2h(batch_labels_noise)
            batch_rec_labels_noise = net_h2y(batch_hiddens_noise)

            loss = nn.MSELoss()(batch_rec_labels_noise, batch_labels_noise)

            #backward pass
            optimizer_y2h.zero_grad()
            loss.backward()
            optimizer_y2h.step()

            train_loss += loss.cpu().item()
        #end for batch_idx
        train_loss = train_loss / len(trainloader)

        print('\n Train net_y2h: [epoch %d/%d] train_loss:%f Time:%.4f' % (epoch+1, epochs, train_loss, timeit.default_timer()-start_tmp))
    #end for epoch

    return net_y2h
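
For reference, a minimal sketch of wiring the two trainers above together for the two-condition case. The nn.DataParallel wrap matters because train_net_y2h reaches into net_embed.module.h2y; the import path, trainloader and train_labels are my assumptions, not code from the repo:

import numpy as np
import torch.nn as nn
from ResNet_embed import ResNet34_embed, model_y2h  # assumed module path

# assumed: trainloader yields (image, label) batches with labels of shape (NUM_CONDITIONS,),
# and train_labels is the (N, NUM_CONDITIONS) array of normalized training labels
net_embed = nn.DataParallel(ResNet34_embed(dim_embed=128).cuda())
net_embed = train_net_embed(net_embed, "ResNet34_embed", trainloader, testloader=None, epochs=200)

unique_labels_norm = np.unique(train_labels, axis=0)  # unique (cond1, cond2) rows
net_y2h = nn.DataParallel(model_y2h(dim_embed=128).cuda())
net_y2h = train_net_y2h(unique_labels_norm, net_y2h, net_embed, epochs=500)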

train_ccgan.py

import torch
import numpy as np
import os
import timeit
from PIL import Image
from torchvision.utils import save_image
import torch.cuda as cutorch

from utils import SimpleProgressBar, IMGs_dataset
from opts import parse_opts
from DiffAugment_pytorch import DiffAugment

NUM_CONDITIONS = 2

''' Settings '''
args = parse_opts()

# some parameters in opts
gan_arch = args.GAN_arch
loss_type = args.loss_type_gan
niters = args.niters_gan
resume_niters = args.resume_niters_gan
dim_gan = args.dim_gan
lr_g = args.lr_g_gan
lr_d = args.lr_d_gan
save_niters_freq = args.save_niters_freq
batch_size_disc = args.batch_size_disc
batch_size_gene = args.batch_size_gene
# batch_size_max = max(batch_size_disc, batch_size_gene)
num_D_steps = args.num_D_steps

visualize_freq = args.visualize_freq

num_workers = args.num_workers

threshold_type = args.threshold_type
nonzero_soft_weight_threshold = args.nonzero_soft_weight_threshold

num_channels = args.num_channels
img_size = args.img_size
max_label = args.max_label

use_DiffAugment = args.gan_DiffAugment
policy = args.gan_DiffAugment_policy

## normalize images
def normalize_images(batch_images):
    batch_images = batch_images/255.0
    batch_images = (batch_images - 0.5)/0.5
    return batch_images

def train_ccgan(kernel_sigma, kappa, train_images, train_labels, netG, netD, net_y2h, save_images_folder, save_models_folder = None, clip_label=False):

    '''
    Note that train_images are not normalized to [-1,1]
    '''

    netG = netG.cuda()
    netD = netD.cuda()
    net_y2h = net_y2h.cuda()
    net_y2h.eval()

    optimizerG = torch.optim.Adam(netG.parameters(), lr=lr_g, betas=(0.5, 0.999))
    optimizerD = torch.optim.Adam(netD.parameters(), lr=lr_d, betas=(0.5, 0.999))

    if save_models_folder is not None and resume_niters>0:
        save_file = save_models_folder + "/CcGAN_{}_{}_nDsteps_{}_checkpoint_intrain/CcGAN_checkpoint_niters_{}.pth".format(gan_arch, threshold_type, num_D_steps, resume_niters)
        checkpoint = torch.load(save_file)
        netG.load_state_dict(checkpoint['netG_state_dict'])
        netD.load_state_dict(checkpoint['netD_state_dict'])
        optimizerG.load_state_dict(checkpoint['optimizerG_state_dict'])
        optimizerD.load_state_dict(checkpoint['optimizerD_state_dict'])
        torch.set_rng_state(checkpoint['rng_state'])
        print(f"Got model from {save_file} successfully")
    #end if

    #################
    # --- unique_train_labels = np.sort(np.array(list(set(train_labels))))
    unique_train_labels_1 = np.sort(np.array(list(set(train_labels[:, 0]))))
    unique_train_labels_2 = np.sort(np.array(list(set(train_labels[:, 1]))))
    # print(f"------------{unique_train_labels_1 = }")
    # print(f"------------{unique_train_labels_2 = }")

    # printed images with labels between the 5-th quantile and 95-th quantile of training labels
    n_row=10; n_col = n_row
    z_fixed = torch.randn(n_row*n_col, dim_gan, dtype=torch.float).cuda()
    # --- start_label = np.quantile(train_labels, 0.05)
    # --- end_label = np.quantile(train_labels, 0.95)
    # --- selected_labels = np.linspace(start_label, end_label, num=n_row)
    start_label_1 = np.quantile(train_labels[:, 0], 0.05)
    end_label_1 = np.quantile(train_labels[:, 0], 0.95)
    selected_labels_1 = np.linspace(start_label_1, end_label_1, num=n_row)

    start_label_2 = np.quantile(train_labels[:, 1], 0.05)
    end_label_2 = np.quantile(train_labels[:, 1], 0.95)
    selected_labels_2 = np.linspace(start_label_2, end_label_2, num=n_col)

    # --- y_fixed = np.zeros(n_row*n_col)
    # --- for i in range(n_row):
    # ---     curr_label = selected_labels[i]
    # ---     for j in range(n_col):
    # ---         y_fixed[i*n_col+j] = curr_label

    y_fixed = np.zeros((n_row * n_col, NUM_CONDITIONS))
    for i in range(n_row):
        curr_label_1 = selected_labels_1[i]
        for j in range(n_col):
            curr_label_2 = selected_labels_2[j]
            y_fixed[i*n_col+j, 0] = curr_label_1
            y_fixed[i*n_col+j, 1] = curr_label_2

    y_fixed = torch.from_numpy(y_fixed).type(torch.float).view(-1,NUM_CONDITIONS).cuda()

    # print(f"{y_fixed = }")

    start_time = timeit.default_timer()
    for niter in range(resume_niters, niters):

        '''  Train Discriminator   '''
        for _ in range(num_D_steps):

            # ## randomly draw batch_size_disc y's from unique_train_labels
            # batch_target_labels_in_dataset = np.random.choice(unique_train_labels, size=batch_size_disc, replace=True)
            # ## add Gaussian noise; we estimate image distribution conditional on these labels
            # batch_epsilons = np.random.normal(0, kernel_sigma, batch_size_disc)
            # batch_target_labels = batch_target_labels_in_dataset + batch_epsilons

            # ## find index of real images with labels in the vicinity of batch_target_labels
            # ## generate labels for fake image generation; these labels are also in the vicinity of batch_target_labels
            # batch_real_indx = np.zeros(batch_size_disc, dtype=int) #index of images in the datata; the labels of these images are in the vicinity
            # batch_fake_labels = np.zeros(batch_size_disc)

            # for j in range(batch_size_disc):
            #     ## index for real images
            #     if threshold_type == "hard":
            #         indx_real_in_vicinity = np.where(np.abs(train_labels-batch_target_labels[j])<= kappa)[0]
            #     else:
            #         # reverse the weight function for SVDL
            #         indx_real_in_vicinity = np.where((train_labels-batch_target_labels[j])**2 <= -np.log(nonzero_soft_weight_threshold)/kappa)[0]

            #     ## if the max gap between two consecutive ordered unique labels is large, it is possible that len(indx_real_in_vicinity)<1
            #     while len(indx_real_in_vicinity)<1:
            #         batch_epsilons_j = np.random.normal(0, kernel_sigma, 1)
            #         batch_target_labels[j] = batch_target_labels_in_dataset[j] + batch_epsilons_j
            #         if clip_label:
            #             batch_target_labels = np.clip(batch_target_labels, 0.0, 1.0)
            #         ## index for real images
            #         if threshold_type == "hard":
            #             indx_real_in_vicinity = np.where(np.abs(train_labels-batch_target_labels[j])<= kappa)[0]
            #         else:
            #             # reverse the weight function for SVDL
            #             indx_real_in_vicinity = np.where((train_labels-batch_target_labels[j])**2 <= -np.log(nonzero_soft_weight_threshold)/kappa)[0]
            #     #end while len(indx_real_in_vicinity)<1

            #     assert len(indx_real_in_vicinity)>=1

            #     batch_real_indx[j] = np.random.choice(indx_real_in_vicinity, size=1)[0]

            #     ## labels for fake images generation
            #     if threshold_type == "hard":
            #         lb = batch_target_labels[j] - kappa
            #         ub = batch_target_labels[j] + kappa
            #     else:
            #         lb = batch_target_labels[j] - np.sqrt(-np.log(nonzero_soft_weight_threshold)/kappa)
            #         ub = batch_target_labels[j] + np.sqrt(-np.log(nonzero_soft_weight_threshold)/kappa)
            #     lb = max(0.0, lb); ub = min(ub, 1.0)
            #     assert lb<=ub
            #     assert lb>=0 and ub>=0
            #     assert lb<=1 and ub<=1
            #     batch_fake_labels[j] = np.random.uniform(lb, ub, size=1)[0]
            # #end for j

# ----------------------------------------------------------------------------------------------------------------

            ## randomly draw batch_size_disc y's from unique_train_labels
            batch_target_labels_in_dataset_1 = np.random.choice(unique_train_labels_1, size=batch_size_disc, replace=True)
            batch_target_labels_in_dataset_2 = np.random.choice(unique_train_labels_2, size=batch_size_disc, replace=True)

            ## add Gaussian noise; we estimate image distribution conditional on these labels
            batch_epsilons = np.random.normal(0, kernel_sigma, batch_size_disc)
            batch_target_labels_1 = batch_target_labels_in_dataset_1 + batch_epsilons
            batch_target_labels_2 = batch_target_labels_in_dataset_2 + batch_epsilons

            ## find index of real images with labels in the vicinity of batch_target_labels
            ## generate labels for fake image generation; these labels are also in the vicinity of batch_target_labels
            batch_real_indx = np.zeros(batch_size_disc, dtype=int) #index of images in the dataset; the labels of these images are in the vicinity
            batch_fake_labels = np.zeros((batch_size_disc, NUM_CONDITIONS))

            for j in range(batch_size_disc):
                ## index for real images
                if threshold_type == "hard":
                    indx_real_in_vicinity_1 = np.where(np.abs(train_labels[:, 0]-batch_target_labels_1[j])<= kappa)[0]
                    indx_real_in_vicinity_2 = np.where(np.abs(train_labels[:, 1]-batch_target_labels_2[j])<= kappa)[0]
                else:
                    # reverse the weight function for SVDL
                    indx_real_in_vicinity_1 = np.where((train_labels[:, 0]-batch_target_labels_1[j])**2 <= -np.log(nonzero_soft_weight_threshold)/kappa)[0]
                    indx_real_in_vicinity_2 = np.where((train_labels[:, 1]-batch_target_labels_2[j])**2 <= -np.log(nonzero_soft_weight_threshold)/kappa)[0]

                indx_real_in_vicinity = np.array(list(set.intersection(set(indx_real_in_vicinity_1), set(indx_real_in_vicinity_2))))

                ## if the max gap between two consecutive ordered unique labels is large, it is possible that len(indx_real_in_vicinity)<1
                while len(indx_real_in_vicinity)<1:
                    batch_epsilons_j = np.random.normal(0, kernel_sigma, 1)
                    batch_target_labels_1[j] = batch_target_labels_in_dataset_1[j] + batch_epsilons_j
                    batch_target_labels_2[j] = batch_target_labels_in_dataset_2[j] + batch_epsilons_j
                    if clip_label:
                        batch_target_labels_1 = np.clip(batch_target_labels_1, 0.0, 1.0)
                        batch_target_labels_2 = np.clip(batch_target_labels_2, 0.0, 1.0)
                    ## index for real images
                    # if threshold_type == "hard":
                    #     indx_real_in_vicinity = np.where(np.abs(train_labels-batch_target_labels[j])<= kappa)[0]
                    # else:
                    #     # reverse the weight function for SVDL
                    #     indx_real_in_vicinity = np.where((train_labels-batch_target_labels[j])**2 <= -np.log(nonzero_soft_weight_threshold)/kappa)[0]
                    if threshold_type == "hard":
                        indx_real_in_vicinity_1 = np.where(np.abs(train_labels[:, 0]-batch_target_labels_1[j])<= kappa)[0]
                        indx_real_in_vicinity_2 = np.where(np.abs(train_labels[:, 1]-batch_target_labels_2[j])<= kappa)[0]
                    else:
                        # reverse the weight function for SVDL
                        indx_real_in_vicinity_1 = np.where((train_labels[:, 0]-batch_target_labels_1[j])**2 <= -np.log(nonzero_soft_weight_threshold)/kappa)[0]
                        indx_real_in_vicinity_2 = np.where((train_labels[:, 1]-batch_target_labels_2[j])**2 <= -np.log(nonzero_soft_weight_threshold)/kappa)[0]

                    indx_real_in_vicinity = np.array(list(set.intersection(set(indx_real_in_vicinity_1), set(indx_real_in_vicinity_2))))

                #end while len(indx_real_in_vicinity)<1

                assert len(indx_real_in_vicinity)>=1

                batch_real_indx[j] = np.random.choice(indx_real_in_vicinity, size=1)[0]

                ## labels for fake images generation
                if threshold_type == "hard":
                    lb_1 = batch_target_labels_1[j] - kappa
                    ub_1 = batch_target_labels_1[j] + kappa

                    lb_2 = batch_target_labels_2[j] - kappa
                    ub_2 = batch_target_labels_2[j] + kappa

                else:
                    lb_1 = batch_target_labels_1[j] - np.sqrt(-np.log(nonzero_soft_weight_threshold)/kappa)
                    ub_1 = batch_target_labels_1[j] + np.sqrt(-np.log(nonzero_soft_weight_threshold)/kappa)

                    lb_2 = batch_target_labels_2[j] - np.sqrt(-np.log(nonzero_soft_weight_threshold)/kappa)
                    ub_2 = batch_target_labels_2[j] + np.sqrt(-np.log(nonzero_soft_weight_threshold)/kappa)
                lb_1 = max(0.0, lb_1); ub_1 = min(ub_1, 1.0)
                lb_2 = max(0.0, lb_2); ub_2 = min(ub_2, 1.0)

                assert lb_1<=ub_1
                assert lb_2<=ub_2
                assert lb_1>=0 and ub_1>=0
                assert lb_1<=1 and ub_1<=1
                assert lb_2>=0 and ub_2>=0
                assert lb_2<=1 and ub_2<=1

                batch_fake_labels[j] = np.array([np.random.uniform(lb_1, ub_1, size=1)[0], np.random.uniform(lb_2, ub_2, size=1)[0]])

            batch_target_labels = np.stack([batch_target_labels_1, batch_target_labels_2], axis=1)
            #end for j

# ----------------------------------------------------------------------------------------------------------------
            # print(f"===1 {batch_target_labels.shape = }")
            # print(f"===1 {np.min(batch_target_labels) = }, {np.max(batch_target_labels) = }")
            # print(f"===2 {batch_fake_labels.shape = }")
            # print(f"===2 {np.min(batch_fake_labels) = }, {np.max(batch_fake_labels) = }")

            ## draw real image/label batch from the training set
            batch_real_images = torch.from_numpy(normalize_images(train_images[batch_real_indx]))
            batch_real_images = batch_real_images.type(torch.float).cuda()
            batch_real_labels = train_labels[batch_real_indx]
            batch_real_labels = torch.from_numpy(batch_real_labels).type(torch.float).cuda()

            # print(f"===3 {batch_real_labels.shape = }")
            # print(f"===3 {torch.min(batch_real_labels) = }, {torch.max(batch_real_labels) = }")

            # print(f"===4 {batch_real_images.shape = }")
            # print(f"===4 {torch.min(batch_real_images) = }, {torch.max(batch_real_images) = }")

            ## generate the fake image batch
            batch_fake_labels = torch.from_numpy(batch_fake_labels).type(torch.float).cuda()
            z = torch.randn(batch_size_disc, dim_gan, dtype=torch.float).cuda()
            batch_fake_images = netG(z, net_y2h(batch_fake_labels))

            # print(f"===5 {batch_fake_images.shape = }")
            # print(f"===5 {torch.min(batch_fake_images) = }, {torch.max(batch_fake_images) = }")

            ## target labels on gpu
            batch_target_labels = torch.from_numpy(batch_target_labels).type(torch.float).cuda()

            ## weight vector
            if threshold_type == "soft":
                real_weights = torch.exp(-kappa*(batch_real_labels-batch_target_labels)**2).cuda()
                fake_weights = torch.exp(-kappa*(batch_fake_labels-batch_target_labels)**2).cuda()
            else:
                real_weights = torch.ones(batch_size_disc, dtype=torch.float).cuda()
                fake_weights = torch.ones(batch_size_disc, dtype=torch.float).cuda()
            #end if threshold type
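            # Note: the soft threshold implements the SVDL weight w = exp(-kappa*(y - y_target)^2)
            # per condition (hence the per-column weights averaged further below), while the
            # hard threshold gives every sample picked from the vicinity a weight of 1.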

            # forward pass
            if use_DiffAugment:
                real_dis_out = netD(DiffAugment(batch_real_images, policy=policy), net_y2h(batch_target_labels))
                fake_dis_out = netD(DiffAugment(batch_fake_images.detach(), policy=policy), net_y2h(batch_target_labels))
            else:
                real_dis_out = netD(batch_real_images, net_y2h(batch_target_labels))
                fake_dis_out = netD(batch_fake_images.detach(), net_y2h(batch_target_labels))

            # print(f"===6 {real_dis_out.shape = }")
            # print(f"===6 {torch.min(real_dis_out) = }, {torch.max(real_dis_out) = }")
            # print(f"===7 {real_weights.shape = }")
            # print(f"===7 {torch.min(real_weights) = }, {torch.max(real_weights) = }")

            if loss_type == "vanilla":
                real_dis_out = torch.nn.Sigmoid()(real_dis_out)
                fake_dis_out = torch.nn.Sigmoid()(fake_dis_out)
                d_loss_real = - torch.log(real_dis_out+1e-20)
                d_loss_fake = - torch.log(1-fake_dis_out+1e-20)
            elif loss_type == "hinge":
                d_loss_real = torch.nn.ReLU()(1.0 - real_dis_out)
                d_loss_fake = torch.nn.ReLU()(1.0 + fake_dis_out)
            else:
                raise ValueError('Not supported loss type!!!')

            # TODO: added this, but I'm not sure whether it is better to reduce the dimension
            # of real_weights (as done here) or to enlarge the output dimension of netD
            real_weights = torch.mean(real_weights, axis=1)
            fake_weights = torch.mean(fake_weights, axis=1)
            d_loss = torch.mean(real_weights.view(-1) * d_loss_real.view(-1)) + torch.mean(fake_weights.view(-1) * d_loss_fake.view(-1))

            optimizerD.zero_grad()
            d_loss.backward()
            optimizerD.step()

        #end for step_D_index

        '''  Train Generator   '''
        netG.train()

        # generate fake images
        ## randomly draw batch_size_gene y's from unique_train_labels
        batch_target_labels_in_dataset_1 = np.random.choice(unique_train_labels_1, size=batch_size_gene, replace=True)
        batch_target_labels_in_dataset_2 = np.random.choice(unique_train_labels_2, size=batch_size_gene, replace=True)

        ## add Gaussian noise; we estimate image distribution conditional on these labels
        batch_epsilons = np.random.normal(0, kernel_sigma, batch_size_gene)
        batch_target_labels_1 = batch_target_labels_in_dataset_1 + batch_epsilons
        batch_target_labels_2 = batch_target_labels_in_dataset_2 + batch_epsilons

        batch_target_labels = np.stack([batch_target_labels_1, batch_target_labels_2], axis=1)

        batch_target_labels = torch.from_numpy(batch_target_labels).type(torch.float).cuda()

        z = torch.randn(batch_size_gene, dim_gan, dtype=torch.float).cuda()
        batch_fake_images = netG(z, net_y2h(batch_target_labels))

        # loss
        if use_DiffAugment:
            dis_out = netD(DiffAugment(batch_fake_images, policy=policy), net_y2h(batch_target_labels))
        else:
            dis_out = netD(batch_fake_images, net_y2h(batch_target_labels))
        if loss_type == "vanilla":
            dis_out = torch.nn.Sigmoid()(dis_out)
            g_loss = - torch.mean(torch.log(dis_out+1e-20))
        elif loss_type == "hinge":
            g_loss = - dis_out.mean()

        # backward
        optimizerG.zero_grad()
        g_loss.backward()
        optimizerG.step()

        # print loss
        if (niter+1) % 20 == 0:
            print ("CcGAN,%s: [Iter %d/%d] [D loss: %.4e] [G loss: %.4e] [real prob: %.3f] [fake prob: %.3f] [Time: %.4f]" % (gan_arch, niter+1, niters, d_loss.item(), g_loss.item(), real_dis_out.mean().item(), fake_dis_out.mean().item(), timeit.default_timer()-start_time))

        if (niter+1) % visualize_freq == 0:
            netG.eval()
            with torch.no_grad():
                gen_imgs = netG(z_fixed, net_y2h(y_fixed))
                gen_imgs = gen_imgs.detach().cpu()
                save_image(gen_imgs.data, save_images_folder + '/{}.png'.format(niter+1), nrow=n_row, normalize=True)

        if save_models_folder is not None and ((niter+1) % save_niters_freq == 0 or (niter+1) == niters):
            save_file = save_models_folder + "/CcGAN_{}_{}_nDsteps_{}_checkpoint_intrain/CcGAN_checkpoint_niters_{}.pth".format(gan_arch, threshold_type, num_D_steps, niter+1)
            os.makedirs(os.path.dirname(save_file), exist_ok=True)
            torch.save({
                    'netG_state_dict': netG.state_dict(),
                    'netD_state_dict': netD.state_dict(),
                    'optimizerG_state_dict': optimizerG.state_dict(),
                    'optimizerD_state_dict': optimizerD.state_dict(),
                    'rng_state': torch.get_rng_state()
            }, save_file)
    #end for niter
    return netG, netD

def sample_ccgan_given_labels(netG, net_y2h, labels, batch_size = 500, to_numpy=True, denorm=True, verbose=True):
    '''
    netG: pretrained generator network
    labels: float. normalized labels.
    '''

    nfake = len(labels)
    if batch_size>nfake:
        batch_size=nfake

    fake_images = []
    fake_labels = np.concatenate((labels, labels[0:batch_size]))
    netG=netG.cuda()
    netG.eval()
    net_y2h = net_y2h.cuda()
    net_y2h.eval()
    with torch.no_grad():
        if verbose:
            pb = SimpleProgressBar()
        n_img_got = 0
        while n_img_got < nfake:
            z = torch.randn(batch_size, dim_gan, dtype=torch.float).cuda()
            y = torch.from_numpy(fake_labels[n_img_got:(n_img_got+batch_size)]).type(torch.float).view(-1,NUM_CONDITIONS).cuda()
            batch_fake_images = netG(z, net_y2h(y))
            if denorm: #denorm imgs to save memory
                assert batch_fake_images.max().item()<=1.0 and batch_fake_images.min().item()>=-1.0
                batch_fake_images = batch_fake_images*0.5+0.5
                batch_fake_images = batch_fake_images*255.0
                batch_fake_images = batch_fake_images.type(torch.uint8)
                # assert batch_fake_images.max().item()>1
            fake_images.append(batch_fake_images.cpu())
            n_img_got += batch_size
            if verbose:
                pb.update(min(float(n_img_got)/nfake, 1)*100)
        ##end while

    fake_images = torch.cat(fake_images, dim=0)
    #remove extra entries
    fake_images = fake_images[0:nfake]
    fake_labels = fake_labels[0:nfake]

    if to_numpy:
        fake_images = fake_images.numpy()

    return fake_images, fake_labels

if __name__ == "__main__":
    from models import CcGAN_SAGAN_Generator, CcGAN_SAGAN_Discriminator
    from models.ResNet_embed import model_y2h
    B = 500
    images_train = np.random.normal(size=(B, 3, 128, 128))
    labels_train = np.random.normal(size=(B, 2))
    kernel_sigma = -1.0
    if kernel_sigma < 0:
        kernel_sigma = 1.06 * np.std(labels_train) * (len(labels_train))**(-1/5)
    kappa = -1
    if kappa < 0:
        unique_labels_norm = np.unique(labels_train[:, 0])
        n_unique = len(unique_labels_norm)

        diff_list = []
        for i in range(1, n_unique):
            diff_list.append(unique_labels_norm[i] - unique_labels_norm[i-1])
        kappa_base = np.abs(kappa) * np.max(np.array(diff_list))

        # there is a branch on threshold_type, but here we proceed with soft
        kappa = 1 / kappa_base ** 2

    dim_embed = 128
    netG = CcGAN_SAGAN_Generator(dim_z=dim_gan, dim_embed=128)
    netD = CcGAN_SAGAN_Discriminator(dim_embed=dim_embed)
    net_y2h = model_y2h(dim_embed=dim_embed)
    save_image_in_train_folder = "."
    save_models_folder = "."

    # images_train = torch.from_numpy(images_train).type(torch.float)
    # labels_train = torch.from_numpy(labels_train).type(torch.float)

    netG, netD = train_ccgan(kernel_sigma,
                             kappa,
                             images_train,
                             labels_train,
                             netG,
                             netD,
                             net_y2h,
                             save_image_in_train_folder,
                             save_models_folder)
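
As a quick follow-up to the smoke test above, sampling from the trained model with sample_ccgan_given_labels over a grid of the two normalized conditions could look like this (the grid values are arbitrary, just for illustration):

    conds_1 = np.linspace(0.05, 0.95, 10)   # first condition, normalized
    conds_2 = np.linspace(0.05, 0.95, 10)   # second condition, normalized
    grid = np.array([[c1, c2] for c1 in conds_1 for c2 in conds_2])  # shape (100, NUM_CONDITIONS)
    fake_images, fake_labels = sample_ccgan_given_labels(netG, net_y2h, grid, batch_size=50)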

main.py

print("\n===================================================================================================")

import argparse
import copy
import gc
import numpy as np
import matplotlib.pyplot as plt
plt.switch_backend('agg')
import matplotlib as mpl
import h5py
import os
import random
from tqdm import tqdm
import torch
import torchvision
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torchvision.utils import save_image
import timeit
from PIL import Image
import sys

### import my stuffs ###
from opts import parse_opts
args = parse_opts()
wd = args.root_path
os.chdir(wd)
from utils import *
from models import *
from train_cgan import train_cgan, sample_cgan_given_labels
from train_cgan_concat import train_cgan_concat, sample_cgan_concat_given_labels
from train_ccgan import train_ccgan, sample_ccgan_given_labels
from train_net_for_label_embed import train_net_embed, train_net_y2h
from eval_metrics import cal_FID, cal_labelscore

#######################################################################################
'''                                   Settings                                      '''
#######################################################################################
#-------------------------------
# seeds
random.seed(args.seed)
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = True
cudnn.benchmark = False
np.random.seed(args.seed)

NUM_CONDITIONS = 2

#-------------------------------
# output folders
path_to_output = os.path.join(wd, "output/output_{}_arch_{}".format(args.GAN, args.GAN_arch))

print(f"{path_to_output = }")

os.makedirs(path_to_output, exist_ok=True)
save_models_folder = os.path.join(path_to_output, 'saved_models')
os.makedirs(save_models_folder, exist_ok=True)
save_images_folder = os.path.join(path_to_output, 'saved_images')
os.makedirs(save_images_folder, exist_ok=True)

path_to_embed_models = os.path.join(wd, 'output/embed_models')
os.makedirs(path_to_embed_models, exist_ok=True)

#-------------------------------
# Embedding
base_lr_x2y = 0.01
base_lr_y2h = 0.01

#######################################################################################
'''                                    Data loader                                 '''
#######################################################################################
# data loader
# data_filename = args.data_path + '/RC-49_{}x{}.h5'.format(args.img_size, args.img_size)

# new dataset
data_filename = args.data_path + '/RC-49_{}x{}_downscale.h5'.format(args.img_size, args.img_size)

hf = h5py.File(data_filename, 'r')
labels_all = hf['labels'][:]
labels_all = labels_all.astype(float)
images_all = hf['images'][:]
indx_train = hf['indx_train'][:]
hf.close()
print("\n RC-49 dataset shape: {}x{}x{}x{}".format(images_all.shape[0], images_all.shape[1], images_all.shape[2], images_all.shape[3]))

# data split
if args.data_split == "train":
    images_train = images_all[indx_train]
    labels_train_raw = labels_all[indx_train]
else:
    images_train = copy.deepcopy(images_all)
    labels_train_raw = copy.deepcopy(labels_all)

# only take images with label in (q1, q2)
q1 = args.min_label
q2 = args.max_label

# in case of multiple conditions
if labels_train_raw.shape[-1] == NUM_CONDITIONS:
    scale_q1 = args.min_label_scale
    scale_q2 = args.max_label_scale
    indx_1 = np.where((labels_train_raw[:, 0]>q1)*(labels_train_raw[:, 0]<q2)==True)[0]
    indx_2 = np.where((labels_train_raw[:, 1]>scale_q1)*(labels_train_raw[:, 1]<scale_q2)==True)[0]
    indx = np.array(list(set.intersection(set(indx_1), set(indx_2))))
else:
    indx = np.where((labels_train_raw>q1)*(labels_train_raw<q2)==True)[0]

labels_train_raw = labels_train_raw[indx]
images_train = images_train[indx]
assert len(labels_train_raw)==len(images_train)

if args.visualize_fake_images or args.comp_FID:
    indx = np.where((labels_all>q1)*(labels_all<q2)==True)[0]
    labels_all = labels_all[indx]
    images_all = images_all[indx]
    assert len(labels_all)==len(images_all)

### show some real  images
if args.show_real_imgs:
    unique_labels_show = np.array(sorted(list(set(labels_all))))
    indx_show = np.arange(0, len(unique_labels_show), len(unique_labels_show)//9)
    unique_labels_show = unique_labels_show[indx_show]
    nrow = len(unique_labels_show); ncol = 1
    sel_labels_indx = []
    for i in range(nrow):
        curr_label = unique_labels_show[i]
        indx_curr_label = np.where(labels_all==curr_label)[0]
        np.random.shuffle(indx_curr_label)
        indx_curr_label = indx_curr_label[0:ncol]
        sel_labels_indx.extend(list(indx_curr_label))
    sel_labels_indx = np.array(sel_labels_indx)
    images_show = images_all[sel_labels_indx]
    print(images_show.mean())
    images_show = (images_show/255.0-0.5)/0.5
    images_show = torch.from_numpy(images_show)
    save_image(images_show.data, save_images_folder +'/real_images_grid_{}x{}.png'.format(nrow, ncol), nrow=ncol, normalize=True)

# for each angle, take no more than args.max_num_img_per_label images
# image_num_threshold = args.max_num_img_per_label
# print("\n Original set has {} images; For each angle, take no more than {} images>>>".format(len(images_train), image_num_threshold))
# unique_labels_tmp = np.sort(np.array(list(set(labels_train_raw))))
# for i in tqdm(range(len(unique_labels_tmp))):
#     indx_i = np.where(labels_train_raw == unique_labels_tmp[i])[0]
#     if len(indx_i)>image_num_threshold:
#         np.random.shuffle(indx_i)
#         indx_i = indx_i[0:image_num_threshold]
#     if i == 0:
#         sel_indx = indx_i
#     else:
#         sel_indx = np.concatenate((sel_indx, indx_i))
# images_train = images_train[sel_indx]
# labels_train_raw = labels_train_raw[sel_indx]
print("{} images left and there are {}, {} unique labels".format(len(images_train), len(set(labels_train_raw[:, 0])), len(set(labels_train_raw[:, 1]))))

# normalize labels_train_raw
print("\n Range of unnormalized labels for first axis: ({},{})".format(np.min(labels_train_raw[:, 0]), np.max(labels_train_raw[:, 0])))
print("\n Range of unnormalized labels for second axis: ({},{})".format(np.min(labels_train_raw[:, 1]), np.max(labels_train_raw[:, 1])))

if args.GAN == "cGAN": #treated as classification; convert angles to class labels
    unique_labels = np.sort(np.array(list(set(labels_train_raw))))
    num_unique_labels = len(unique_labels)
    print("{} unique labels are split into {} classes".format(num_unique_labels, args.cGAN_num_classes))

    ## convert steering angles to class labels and vice versa
    ### step 1: prepare two dictionaries
    label2class = dict()
    class2label = dict()
    num_labels_per_class = num_unique_labels//args.cGAN_num_classes
    class_cutoff_points = [unique_labels[0]] #the cutoff points on [min_label, max_label] to determine classes
    curr_class = 0
    for i in range(num_unique_labels):
        label2class[unique_labels[i]]=curr_class
        if (i+1)%num_labels_per_class==0 and (curr_class+1)!=args.cGAN_num_classes:
            curr_class += 1
            class_cutoff_points.append(unique_labels[i+1])
    class_cutoff_points.append(unique_labels[-1])
    assert len(class_cutoff_points)-1 == args.cGAN_num_classes

    for i in range(args.cGAN_num_classes):
        class2label[i] = (class_cutoff_points[i]+class_cutoff_points[i+1])/2

    ### step 2: convert angles to class labels
    labels_new = -1*np.ones(len(labels_train_raw))
    for i in range(len(labels_train_raw)):
        labels_new[i] = label2class[labels_train_raw[i]]
    assert np.sum(labels_new<0)==0
    labels_train = labels_new
    del labels_new; gc.collect()
    unique_labels = np.sort(np.array(list(set(labels_train)))).astype(int)
    assert len(unique_labels) == args.cGAN_num_classes
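
    # Worked example of the cutoff logic above, with hypothetical numbers: given 6 unique
    # labels [0, 10, 20, 30, 40, 50] and cGAN_num_classes=3, num_labels_per_class is 2,
    # class_cutoff_points becomes [0, 20, 40, 50], and class2label is {0: 10.0, 1: 30.0, 2: 45.0}.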

elif args.GAN == "CcGAN":
    if labels_train_raw.shape[-1] == NUM_CONDITIONS:
        labels_train = labels_train_raw / [args.max_label, args.max_label_scale]
    else:
        labels_train = labels_train_raw / args.max_label

    print("\n Range of normalized labels: ({},{})".format(np.min(labels_train), np.max(labels_train)))

    # proceed with the two normalized conditions
    # unique_labels_norm = np.sort(np.array(list(set(labels_train[:, 0]))))

    # if args.kernel_sigma<0:
    #     std_label = np.std(labels_train)
    #     args.kernel_sigma = 1.06*std_label*(len(labels_train))**(-1/5)

    #     print("\n Use rule-of-thumb formula to compute kernel_sigma >>>")
    #     print("\n The std of {} labels is {} so the kernel sigma is {}".format(len(labels_train), std_label, args.kernel_sigma))

    # if args.kappa<0:
    #     n_unique = len(unique_labels_norm)

    #     diff_list = []
    #     for i in range(1,n_unique):
    #         diff_list.append(unique_labels_norm[i] - unique_labels_norm[i-1])
    #     kappa_base = np.abs(args.kappa)*np.max(np.array(diff_list))

    #     if args.threshold_type=="hard":
    #         args.kappa = kappa_base
    #     else:
    #         args.kappa = 1/kappa_base**2

    unique_labels_norm_1 = np.sort(np.array(list(set(labels_train[:, 0]))))
    unique_labels_norm_2 = np.sort(np.array(list(set(labels_train[:, 1]))))
    unique_labels_norm = np.zeros((len(unique_labels_norm_1) * len(unique_labels_norm_2), 2))
    for idx_1, c_1 in enumerate(unique_labels_norm_1):
        for idx_2, c_2 in enumerate(unique_labels_norm_2):
            unique_labels_norm[idx_1*len(unique_labels_norm_2) + idx_2, 0] = c_1
            unique_labels_norm[idx_1*len(unique_labels_norm_2) + idx_2, 1] = c_2

    if args.kernel_sigma<0:
        std_label = np.std(labels_train)
        args.kernel_sigma = 1.06*std_label*(len(labels_train))**(-1/5)

        print("\n Use rule-of-thumb formula to compute kernel_sigma >>>")
        print("\n The std of {} labels is {} so the kernel sigma is {}".format(len(labels_train), std_label, args.kernel_sigma))

    # TODO: for now this part proceeds as in the single-condition case
    if args.kappa<0:
        n_unique = len(unique_labels_norm)

        diff_list = []
        for i in range(1,n_unique):
            diff_list.append(unique_labels_norm[i] - unique_labels_norm[i-1])
        kappa_base = np.abs(args.kappa)*np.max(np.array(diff_list))

        if args.threshold_type=="hard":
            args.kappa = kappa_base
        else:
            args.kappa = 1/kappa_base**2

elif args.GAN == "cGAN-concat":
    labels_train = labels_train_raw / args.max_label
    print("\n Range of normalized labels: ({},{})".format(np.min(labels_train), np.max(labels_train)))
else:
    raise ValueError('Not supported')
## end if args.GAN

#######################################################################################
'''               Pre-trained CNN and GAN for label embedding                       '''
#######################################################################################
if args.GAN == "CcGAN":
    net_embed_filename_ckpt = os.path.join(path_to_embed_models, 'ckpt_{}_epoch_{}_seed_{}.pth'.format(args.net_embed, args.epoch_cnn_embed, args.seed))
    net_y2h_filename_ckpt = os.path.join(path_to_embed_models, 'ckpt_net_y2h_epoch_{}_seed_{}.pth'.format(args.epoch_net_y2h, args.seed))

    print("\n "+net_embed_filename_ckpt)
    print("\n "+net_y2h_filename_ckpt)

    trainset = IMGs_dataset(images_train, labels_train, normalize=True)
    trainloader_embed_net = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size_embed, shuffle=True, num_workers=args.num_workers)

    if args.net_embed == "ResNet18_embed":
        net_embed = ResNet18_embed(dim_embed=args.dim_embed)
    elif args.net_embed == "ResNet34_embed":
        net_embed = ResNet34_embed(dim_embed=args.dim_embed)
    elif args.net_embed == "ResNet50_embed":
        net_embed = ResNet50_embed(dim_embed=args.dim_embed)
    net_embed = net_embed.cuda()
    net_embed = nn.DataParallel(net_embed)

    net_y2h = model_y2h(dim_embed=args.dim_embed)
    net_y2h = net_y2h.cuda()
    net_y2h = nn.DataParallel(net_y2h)

    ## (1). Train net_embed first: x2h+h2y
    if not os.path.isfile(net_embed_filename_ckpt):
        print("\n Start training CNN for label embedding >>>")
        net_embed = train_net_embed(net=net_embed, net_name=args.net_embed, trainloader=trainloader_embed_net, testloader=None, epochs=args.epoch_cnn_embed, resume_epoch = args.resumeepoch_cnn_embed, lr_base=base_lr_x2y, lr_decay_factor=0.1, lr_decay_epochs=[80, 140], weight_decay=1e-4, path_to_ckpt = path_to_embed_models)
        # save model
        torch.save({
        'net_state_dict': net_embed.state_dict(),
        }, net_embed_filename_ckpt)
    else:
        print("\n net_embed ckpt already exists")
        print("\n Loading...")
        checkpoint = torch.load(net_embed_filename_ckpt)
        net_embed.load_state_dict(checkpoint['net_state_dict'])
    #end not os.path.isfile

    ## (2). Train y2h
    #train a net which maps a label back to the embedding space
    if not os.path.isfile(net_y2h_filename_ckpt):
        print("\n Start training net_y2h >>>")
        net_y2h = train_net_y2h(unique_labels_norm, net_y2h, net_embed, epochs=args.epoch_net_y2h, lr_base=base_lr_y2h, lr_decay_factor=0.1, lr_decay_epochs=[150, 250, 350], weight_decay=1e-4, batch_size=128)
        # save model
        torch.save({
        'net_state_dict': net_y2h.state_dict(),
        }, net_y2h_filename_ckpt)
    else:
        print("\n net_y2h ckpt already exists")
        print("\n Loading...")
        checkpoint = torch.load(net_y2h_filename_ckpt)
        net_y2h.load_state_dict(checkpoint['net_state_dict'])
    #end not os.path.isfile

    ##some simple test
    indx_tmp = np.arange(len(unique_labels_norm))
    np.random.shuffle(indx_tmp)
    indx_tmp = indx_tmp[:10]
    labels_tmp = unique_labels_norm[indx_tmp].reshape(-1,NUM_CONDITIONS)
    labels_tmp = torch.from_numpy(labels_tmp).type(torch.float).cuda()
    epsilons_tmp = np.random.normal(0, 0.2, (len(labels_tmp), NUM_CONDITIONS))
    epsilons_tmp = torch.from_numpy(epsilons_tmp).view(-1,NUM_CONDITIONS).type(torch.float).cuda()
    labels_tmp = torch.clamp(labels_tmp+epsilons_tmp, 0.0, 1.0)
    net_embed.eval()
    net_h2y = net_embed.module.h2y
    net_y2h.eval()
    with torch.no_grad():
        labels_rec_tmp = net_h2y(net_y2h(labels_tmp)).cpu().numpy().reshape(-1,NUM_CONDITIONS)
    # results = np.concatenate((labels_tmp.cpu().numpy(), labels_rec_tmp), axis=1)
    print("\n labels vs reconstructed labels")
    # print(results)
    print(labels_tmp)
    print()
    print(labels_rec_tmp)

    #put models on cpu
    net_embed = net_embed.cpu()
    net_h2y = net_h2y.cpu()
    del net_embed, net_h2y; gc.collect()
    net_y2h = net_y2h.cpu()

#######################################################################################
'''                                    GAN training                                 '''
#######################################################################################
if args.GAN == 'CcGAN':
    print("CcGAN: {}, {}, Sigma is {}, Kappa is {}.".format(args.GAN_arch, args.threshold_type, args.kernel_sigma, args.kappa))
    save_images_in_train_folder = save_images_folder + '/{}_{}_{}_{}_in_train'.format(args.GAN_arch, args.threshold_type, args.kernel_sigma, args.kappa)
elif args.GAN == "cGAN":
    print("cGAN: {}, {} classes.".format(args.GAN_arch, args.cGAN_num_classes))
    save_images_in_train_folder = save_images_folder + '/{}_{}_in_train'.format(args.GAN_arch, args.cGAN_num_classes)
elif args.GAN == "cGAN-concat":
    print("cGAN-concat: {}.".format(args.GAN_arch))
    save_images_in_train_folder = save_images_folder + '/{}_in_train'.format(args.GAN_arch)
os.makedirs(save_images_in_train_folder, exist_ok=True)

start = timeit.default_timer()
print("\n Begin Training %s:" % args.GAN)
#----------------------------------------------
# cGAN: treated as a classification dataset
if args.GAN == "cGAN":
    Filename_GAN = save_models_folder + '/ckpt_{}_niters_{}_nDsteps_{}_nclass_{}_seed_{}.pth'.format(args.GAN_arch, args.niters_gan, args.num_D_steps, args.cGAN_num_classes, args.seed)
    print(Filename_GAN)

    if not os.path.isfile(Filename_GAN):
        print("There are {} unique labels".format(len(unique_labels)))

        if args.GAN_arch=="SAGAN":
            netG = cGAN_SAGAN_Generator(z_dim=args.dim_gan, num_classes=args.cGAN_num_classes)
            netD = cGAN_SAGAN_Discriminator(num_classes=args.cGAN_num_classes)
        else:
            raise ValueError('Do not support!!!')
        netG = nn.DataParallel(netG)
        netD = nn.DataParallel(netD)

        # Start training
        netG, netD = train_cgan(images_train, labels_train, netG, netD, save_images_folder=save_images_in_train_folder, save_models_folder = save_models_folder)

        # store model
        torch.save({
            'netG_state_dict': netG.state_dict(),
        }, Filename_GAN)
    else:
        print("Loading pre-trained generator >>>")
        checkpoint = torch.load(Filename_GAN)
        netG = cGAN_SAGAN_Generator(z_dim=args.dim_gan, num_classes=args.cGAN_num_classes).cuda()
        netG = nn.DataParallel(netG)
        netG.load_state_dict(checkpoint['netG_state_dict'])

    # function for sampling from a trained GAN
    def fn_sampleGAN_given_labels(labels, batch_size):
        labels = labels*args.max_label
        fake_images, fake_labels = sample_cgan_given_labels(netG, labels, class_cutoff_points=class_cutoff_points, batch_size = batch_size)
        fake_labels = fake_labels / args.max_label
        return fake_images, fake_labels

#----------------------------------------------
# cGAN: simple concatenation
elif args.GAN == "cGAN-concat":
    Filename_GAN = save_models_folder + '/ckpt_{}_niters_{}_nDsteps_{}_seed_{}.pth'.format(args.GAN_arch, args.niters_gan, args.num_D_steps, args.seed)
    print(Filename_GAN)

    if not os.path.isfile(Filename_GAN):
        if args.GAN_arch=="SAGAN":
            netG = cGAN_concat_SAGAN_Generator(z_dim=args.dim_gan)
            netD = cGAN_concat_SAGAN_Discriminator()
        else:
            raise ValueError('Do not support!!!')
        netG = nn.DataParallel(netG)
        netD = nn.DataParallel(netD)

        # Start training
        netG, netD = train_cgan_concat(images_train, labels_train, netG, netD, save_images_folder=save_images_in_train_folder, save_models_folder = save_models_folder)

        # store model
        torch.save({
            'netG_state_dict': netG.state_dict(),
        }, Filename_GAN)
    else:
        print("Loading pre-trained generator >>>")
        checkpoint = torch.load(Filename_GAN)
        netG = cGAN_concat_SAGAN_Generator(z_dim=args.dim_gan).cuda()
        netG = nn.DataParallel(netG)
        netG.load_state_dict(checkpoint['netG_state_dict'])

    # function for sampling from a trained GAN
    def fn_sampleGAN_given_labels(labels, batch_size):
        labels = labels*args.max_label
        fake_images, fake_labels = sample_cgan_concat_given_labels(netG, labels, batch_size = batch_size, denorm=True, to_numpy=True, verbose=True)
        fake_labels = fake_labels / args.max_label
        return fake_images, fake_labels

#----------------------------------------------
# Continuous cGAN
elif args.GAN == "CcGAN":
    Filename_GAN = save_models_folder + '/ckpt_{}_niters_{}_nDsteps_{}_seed_{}_{}_{}_{}.pth'.format(args.GAN_arch, args.niters_gan, args.num_D_steps, args.seed, args.threshold_type, args.kernel_sigma, args.kappa)
    print(Filename_GAN)

    if not os.path.isfile(Filename_GAN):
        netG = CcGAN_SAGAN_Generator(dim_z=args.dim_gan, dim_embed=args.dim_embed)
        netD = CcGAN_SAGAN_Discriminator(dim_embed=args.dim_embed)
        netG = nn.DataParallel(netG)
        netD = nn.DataParallel(netD)

        # Start training
        netG, netD = train_ccgan(args.kernel_sigma, args.kappa, images_train, labels_train, netG, netD, net_y2h, save_images_folder=save_images_in_train_folder, save_models_folder = save_models_folder)

        # store model
        torch.save({
            'netG_state_dict': netG.state_dict(),
        }, Filename_GAN)

    else:
        print("Loading pre-trained generator >>>")
        checkpoint = torch.load(Filename_GAN)
        netG = CcGAN_SAGAN_Generator(dim_z=args.dim_gan, dim_embed=args.dim_embed).cuda()
        netG = nn.DataParallel(netG)
        netG.load_state_dict(checkpoint['netG_state_dict'])

    def fn_sampleGAN_given_labels(labels, batch_size):
        fake_images, fake_labels = sample_ccgan_given_labels(netG, net_y2h, labels, batch_size = batch_size, to_numpy=True, denorm=True, verbose=True)
        return fake_images, fake_labels

stop = timeit.default_timer()
print("GAN training finished; Time elapses: {}s".format(stop - start))

#######################################################################################
'''                                  Evaluation                                     '''
#######################################################################################
if args.comp_FID:
    print("\n Evaluation in Mode {}...".format(args.eval_mode))

    PreNetFID = encoder(dim_bottleneck=512).cuda()
    PreNetFID = nn.DataParallel(PreNetFID)
    Filename_PreCNNForEvalGANs = args.eval_ckpt_path + '/ckpt_AE_epoch_200_seed_2020_CVMode_False.pth'
    checkpoint_PreNet = torch.load(Filename_PreCNNForEvalGANs)
    PreNetFID.load_state_dict(checkpoint_PreNet['net_encoder_state_dict'])

    # Diversity: entropy of predicted chair types within each eval center
    PreNetDiversity = ResNet34_class_eval(num_classes=49, ngpu = torch.cuda.device_count()).cuda() #49 chair types
    Filename_PreCNNForEvalGANs_Diversity = args.eval_ckpt_path + '/ckpt_PreCNNForEvalGANs_ResNet34_class_epoch_200_seed_2020_classify_49_chair_types_CVMode_False.pth'
    checkpoint_PreNet = torch.load(Filename_PreCNNForEvalGANs_Diversity)
    PreNetDiversity.load_state_dict(checkpoint_PreNet['net_state_dict'])

    # for LS
    PreNetLS = ResNet34_regre_eval(ngpu = torch.cuda.device_count()).cuda()
    Filename_PreCNNForEvalGANs_LS = args.eval_ckpt_path + '/ckpt_PreCNNForEvalGANs_ResNet34_regre_epoch_200_seed_2020_CVMode_False.pth'
    checkpoint_PreNet = torch.load(Filename_PreCNNForEvalGANs_LS)
    PreNetLS.load_state_dict(checkpoint_PreNet['net_state_dict'])

    #####################
    # generate nfake images
    print("\n Start sampling {} fake images per label from GAN >>>".format(args.nfake_per_label))

    if args.eval_mode == 1: #Mode 1: eval on unique labels used for GAN training
        eval_labels = np.sort(np.array(list(set(labels_train_raw)))) #not normalized
    elif args.eval_mode in [2, 3]: #Mode 2 and 3: eval on all unique labels in the dataset
        eval_labels = np.sort(np.array(list(set(labels_all)))) #not normalized
    else: #Mode 4: eval on an interval [min_label, max_label] with num_eval_labels labels
        eval_labels = np.linspace(np.min(labels_all), np.max(labels_all), args.num_eval_labels) #not normalized

    unique_eval_labels = list(set(eval_labels))
    print("\n There are {} unique eval labels.".format(len(unique_eval_labels)))

    eval_labels_norm = eval_labels/args.max_label #normalized

    for i in range(len(eval_labels)):
        curr_label = eval_labels_norm[i]
        if i == 0:
            fake_labels_assigned = np.ones(args.nfake_per_label)*curr_label
        else:
            fake_labels_assigned = np.concatenate((fake_labels_assigned, np.ones(args.nfake_per_label)*curr_label))
    fake_images, _ = fn_sampleGAN_given_labels(fake_labels_assigned, args.samp_batch_size)
    assert len(fake_images) == args.nfake_per_label*len(eval_labels)
    assert len(fake_labels_assigned) == args.nfake_per_label*len(eval_labels)
    assert fake_images.min()>=0 and fake_images.max()<=255.0

    ## dump fake images for computing NIQE
    if args.dump_fake_for_NIQE:
        print("\n Dumping fake images for NIQE...")
        dump_fake_images_folder = save_images_folder + '/fake_images_for_NIQE_nfake_{}'.format(len(fake_images))
        os.makedirs(dump_fake_images_folder, exist_ok=True)
        for i in tqdm(range(len(fake_images))):
            label_i = fake_labels_assigned[i]*args.max_label
            filename_i = dump_fake_images_folder + "/{}_{}.png".format(i, label_i)
            os.makedirs(os.path.dirname(filename_i), exist_ok=True)
            image_i = fake_images[i].astype(np.uint8)
            # image_i = ((image_i*0.5+0.5)*255.0).astype(np.uint8)
            image_i_pil = Image.fromarray(image_i.transpose(1,2,0))
            image_i_pil.save(filename_i)
        #end for i
        # sys.exit()

    print("End sampling! We got {} fake images.".format(len(fake_images)))

    #####################
    # prepare real/fake images and labels
    if args.eval_mode in [1, 3]:
        # real_images = (images_train/255.0-0.5)/0.5
        real_images = images_train
        real_labels = labels_train_raw #not normalized
    else: #for both mode 2 and 4
        # real_images = (images_all/255.0-0.5)/0.5
        real_images = images_all
        real_labels = labels_all #not normalized
    # fake_images = (fake_images/255.0-0.5)/0.5

    #######################
    # For each label take nreal_per_label images
    unique_labels_real = np.sort(np.array(list(set(real_labels))))
    indx_subset = []
    for i in range(len(unique_labels_real)):
        label_i = unique_labels_real[i]
        indx_i = np.where(real_labels==label_i)[0]
        np.random.shuffle(indx_i)
        if args.nreal_per_label>1:
            indx_i = indx_i[0:args.nreal_per_label]
        indx_subset.append(indx_i)
    indx_subset = np.concatenate(indx_subset)
    real_images = real_images[indx_subset]
    real_labels = real_labels[indx_subset]

    nfake_all = len(fake_images)
    nreal_all = len(real_images)

    #####################
    # Evaluate FID within a sliding window of radius R over the label range (not normalized, i.e., [min_label, max_label]).
    # The centers of the sliding windows lie on [min_label+R, ..., max_label-R].
    if args.eval_mode == 1:
        center_start = np.min(labels_train_raw)+args.FID_radius ##bug???
        center_stop = np.max(labels_train_raw)-args.FID_radius
    else:
        center_start = np.min(labels_all)+args.FID_radius
        center_stop = np.max(labels_all)-args.FID_radius

    if args.FID_num_centers<=0 and args.FID_radius==0: #completely overlap
        centers_loc = eval_labels #not normalized
    elif args.FID_num_centers>0:
        centers_loc = np.linspace(center_start, center_stop, args.FID_num_centers) #not normalized
    else:
        raise ValueError("FID_num_centers must be positive unless FID_radius is 0.")
    FID_over_centers = np.zeros(len(centers_loc))
    entropies_over_centers = np.zeros(len(centers_loc)) # entropy at each center
    labelscores_over_centers = np.zeros(len(centers_loc)) #label score at each center
    num_realimgs_over_centers = np.zeros(len(centers_loc))
    for i in range(len(centers_loc)):
        center = centers_loc[i]
        interval_start = (center - args.FID_radius)#/args.max_label
        interval_stop = (center + args.FID_radius)#/args.max_label
        indx_real = np.where((real_labels>=interval_start)*(real_labels<=interval_stop)==True)[0]
        np.random.shuffle(indx_real)
        real_images_curr = real_images[indx_real]
        real_images_curr = (real_images_curr/255.0-0.5)/0.5
        num_realimgs_over_centers[i] = len(real_images_curr)
        indx_fake = np.where((fake_labels_assigned>=(interval_start/args.max_label))*(fake_labels_assigned<=(interval_stop/args.max_label))==True)[0]
        np.random.shuffle(indx_fake)
        fake_images_curr = fake_images[indx_fake]
        fake_images_curr = (fake_images_curr/255.0-0.5)/0.5
        fake_labels_assigned_curr = fake_labels_assigned[indx_fake]
        # FID
        FID_over_centers[i] = cal_FID(PreNetFID, real_images_curr, fake_images_curr, batch_size = 200, resize = None)
        # Entropy of predicted class labels
        predicted_class_labels = predict_class_labels(PreNetDiversity, fake_images_curr, batch_size=200, num_workers=args.num_workers)
        entropies_over_centers[i] = compute_entropy(predicted_class_labels)
        # Label score
        labelscores_over_centers[i], _ = cal_labelscore(PreNetLS, fake_images_curr, fake_labels_assigned_curr, min_label_before_shift=0, max_label_after_shift=args.max_label, batch_size = 500, resize = None, num_workers=args.num_workers)

        print("\n [{}/{}] Center:{}; Real:{}; Fake:{}; FID:{}; LS:{}; ET:{}.".format(i+1, len(centers_loc), center, len(real_images_curr), len(fake_images_curr), FID_over_centers[i], labelscores_over_centers[i], entropies_over_centers[i]))
    # end for i
    # average over all centers
    print("\n {} SFID: {}({}); min/max: {}/{}.".format(args.GAN_arch, np.mean(FID_over_centers), np.std(FID_over_centers), np.min(FID_over_centers), np.max(FID_over_centers)))
    print("\n {} LS over centers: {}({}); min/max: {}/{}.".format(args.GAN_arch, np.mean(labelscores_over_centers), np.std(labelscores_over_centers), np.min(labelscores_over_centers), np.max(labelscores_over_centers)))
    print("\n {} entropy over centers: {}({}); min/max: {}/{}.".format(args.GAN_arch, np.mean(entropies_over_centers), np.std(entropies_over_centers), np.min(entropies_over_centers), np.max(entropies_over_centers)))

    # dump FID, LS and entropy over centers (plus real-image counts per center) to npz
    dump_fid_ls_entropy_over_centers_filename = os.path.join(path_to_output, 'fid_ls_entropy_over_centers')
    np.savez(dump_fid_ls_entropy_over_centers_filename, fids=FID_over_centers, labelscores=labelscores_over_centers, entropies=entropies_over_centers, nrealimgs=num_realimgs_over_centers, centers=centers_loc)

    #####################
    # FID: Evaluate FID on all fake images
    indx_shuffle_real = np.arange(nreal_all); np.random.shuffle(indx_shuffle_real)
    indx_shuffle_fake = np.arange(nfake_all); np.random.shuffle(indx_shuffle_fake)
    FID = cal_FID(PreNetFID, real_images[indx_shuffle_real], fake_images[indx_shuffle_fake], batch_size = 200, resize = None, norm_img = True)
    print("\n {}: FID of {} fake images: {}.".format(args.GAN_arch, nfake_all, FID))

    #####################
    # Overall LS: abs(y_assigned - y_predicted)
    ls_mean_overall, ls_std_overall = cal_labelscore(PreNetLS, fake_images, fake_labels_assigned, min_label_before_shift=0, max_label_after_shift=args.max_label, batch_size = 200, resize = None, norm_img = True, num_workers=args.num_workers)
    print("\n {}: overall LS of {} fake images: {}({}).".format(args.GAN_arch, nfake_all, ls_mean_overall, ls_std_overall))

    #####################
    # Dump evaluation results
    eval_results_logging_fullpath = os.path.join(path_to_output, 'eval_results_{}.txt'.format(args.GAN_arch))
    if not os.path.isfile(eval_results_logging_fullpath):
        eval_results_logging_file = open(eval_results_logging_fullpath, "w")
        eval_results_logging_file.close()
    with open(eval_results_logging_fullpath, 'a') as eval_results_logging_file:
        eval_results_logging_file.write("\n===================================================================================================")
        eval_results_logging_file.write("\n Eval Mode: {}; Radius: {}; # Centers: {}.  \n".format(args.eval_mode, args.FID_radius, args.FID_num_centers))
        print(args, file=eval_results_logging_file)
        eval_results_logging_file.write("\n SFID: {}({}).".format(np.mean(FID_over_centers), np.std(FID_over_centers)))
        eval_results_logging_file.write("\n LS: {}({}).".format(np.mean(labelscores_over_centers), np.std(labelscores_over_centers)))
        eval_results_logging_file.write("\n Diversity: {}({}).".format(np.mean(entropies_over_centers), np.std(entropies_over_centers)))

#######################################################################################
'''               Visualize fake images of the trained GAN                          '''
#######################################################################################
if args.visualize_fake_images:

    # First, visualize conditional generation in a vertical grid
    ## 10 rows (one label per row); 10 columns (10 samples per label)
    n_row = 10
    n_col = 10

    displayed_unique_labels = np.sort(np.array(list(set(labels_all))))
    displayed_labels_indx = (np.linspace(0.05, 0.95, n_row)*len(displayed_unique_labels)).astype(int)
    displayed_labels = displayed_unique_labels[displayed_labels_indx] #not normalized
    displayed_normalized_labels = displayed_labels/args.max_label

    ### output fake images from a trained GAN
    filename_fake_images = os.path.join(save_images_folder, 'fake_images_grid_{}x{}.png').format(n_row, n_col)
    fake_labels_assigned = []
    for tmp_i in range(len(displayed_normalized_labels)):
        curr_label = displayed_normalized_labels[tmp_i]
        fake_labels_assigned.append(np.ones(shape=[n_col, 1])*curr_label)
    fake_labels_assigned = np.concatenate(fake_labels_assigned, axis=0)
    images_show, _ = fn_sampleGAN_given_labels(fake_labels_assigned, args.samp_batch_size)
    images_show = (images_show/255.0-0.5)/0.5
    images_show = torch.from_numpy(images_show)
    save_image(images_show.data, filename_fake_images, nrow=n_col, normalize=True)

    if args.GAN == "CcGAN":
        # Second, fix z but increase y; check whether there is a continuous change, only for CcGAN
        n_continuous_labels = 10
        normalized_continuous_labels = np.linspace(0.05, 0.95, n_continuous_labels)
        z = torch.randn(1, args.dim_gan, dtype=torch.float).cuda()
        continuous_images_show = torch.zeros(n_continuous_labels, args.num_channels, args.img_size, args.img_size, dtype=torch.float)
        netG.eval()
        with torch.no_grad():
            for i in range(n_continuous_labels):
                y = np.ones(1) * normalized_continuous_labels[i]
                y = torch.from_numpy(y).type(torch.float).view(-1,1).cuda()
                fake_image_i = netG(z, net_y2h(y))
                continuous_images_show[i,:,:,:] = fake_image_i.cpu()
        filename_continuous_fake_images = os.path.join(save_images_folder, 'continuous_fake_images_grid.png')
        save_image(continuous_images_show.data, filename_continuous_fake_images, nrow=n_continuous_labels, normalize=True)
        print("Continuous ys: ", (normalized_continuous_labels*args.max_label))

    ### output some real images as baseline
    filename_real_images = save_images_folder + '/real_images_grid_{}x{}.png'.format(n_row, n_col)
    if not os.path.isfile(filename_real_images):
        images_show = np.zeros((n_row*n_col, args.num_channels, args.img_size, args.img_size))
        for i_row in range(n_row):
            # pick n_col real images with the current label
            curr_label = displayed_labels[i_row]
            for j_col in range(n_col):
                indx_curr_label = np.where(labels_all==curr_label)[0]
                np.random.shuffle(indx_curr_label)
                indx_curr_label = indx_curr_label[0]
                images_show[i_row*n_col+j_col] = images_all[indx_curr_label]
        #end for i_row
        images_show = (images_show/255.0-0.5)/0.5
        images_show = torch.from_numpy(images_show)
        save_image(images_show.data, filename_real_images, nrow=n_col, normalize=True)

print("\n===================================================================================================")
ashamy97 commented 4 months ago

Oh that's cool! I am glad we are both trying to solve similar problems haha. I will implement what you kindly suggested and let you know if it works. Thank you so much! Also did you happen to modify CcGAN_SAGAN.py at all?

Foundsheep commented 4 months ago

I thought I had modified CcGAN_SAGAN.py, but I can't see any trace of it in the commit history or git diff. Perhaps the file didn't need to be modified for our situation.

ashamy97 commented 4 months ago

I see, that makes sense. And once you trained it, how did you use the generator model?

Foundsheep commented 4 months ago

I used a Jupyter notebook for inference (generating images on given conditions).

  1. Bring the model architecture code into the notebook
    • It seems like netG and net_y2h are needed. Or perhaps other models as well...
    • Use the checkpoint to load_state_dict for those models
      • For me, the layer names in the stored state_dict and in the model I was building in the Jupyter notebook didn't match, so I had to make them match. For example, I had to remove 'module.' from the names of all the layers inside the state_dict
  2. Make conditions and feed them into the model
  3. Visualise the output

    Below are some of the core parts (in my opinion) of the inference notebook I wrote

Checkpoint

# x2y
x2y_ckpt_path = "../../../../output/embed_models/embed_x2y_ckpt_in_train/embed_x2y_checkpoint_epoch_200.pth"
x2y_state_dict = torch.load(x2y_ckpt_path)

x2y_state_dict["new_net_state_dict"] = dict()
for k, v in x2y_state_dict["net_state_dict"].items():
    x2y_state_dict["new_net_state_dict"].update({k.replace("module.", ""):v})

net_x2y.load_state_dict(x2y_state_dict["new_net_state_dict"])

Inference

labels = torch.tensor([[30, 1.0],
                      [60, 1.3],
                      [90, 1.6],
                      [120, 1.9],
                      [150, 2.1],
                      [180, 2.4],
                      [210, 2.7]]).cuda()

with torch.no_grad():
    h = net_y2h(labels)
    print(f"{h.shape = }")

    z = torch.randn((h.shape[0], DIM_GAN)).cuda()
    outputs = netG(z, h)
    print(f"{outputs.shape = }")

    labels_hat, embedded_features = net_x2y(outputs)
UBCDingXin commented 4 months ago

(quoting @Foundsheep's original question)

Hi,

Thanks a lot for your interest in our work. In this repository, we only provide the .sh shell script for training the model on Linux. If you want to train the model on Windows, you may need to convert the .sh files into .bat files suitable for Windows. Please refer to https://github.com/UBCDingXin/Dual-NDA or https://github.com/UBCDingXin/CCDM, where the Windows batch scripts are provided.

ashamy97 commented 4 months ago

Thank you @Foundsheep and @UBCDingXin for your help! So I tried to run the main.py and I got the following error

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
File ~\Desktop\Research\Adding ML TO AR\GAN and cGAN and fastGAN and WGAN\Continuous Conditional GAN\improved_CcGAN\RC-49\RC-49_128x128\CcGAN-improved\main.py:307
    305 if not os.path.isfile(net_embed_filename_ckpt):
    306     print("\n Start training CNN for label embedding >>>")
--> 307     net_embed = train_net_embed(net=net_embed, net_name=args.net_embed, trainloader=trainloader_embed_net, testloader=None, epochs=args.epoch_cnn_embed, resume_epoch = args.resumeepoch_cnn_embed, lr_base=base_lr_x2y, lr_decay_factor=0.1, lr_decay_epochs=[80, 140], weight_decay=1e-4, path_to_ckpt = path_to_embed_models)
    308     # save model
    309     torch.save({
    310     'net_state_dict': net_embed.state_dict(),
    311     }, net_embed_filename_ckpt)

File ~\Desktop\Research\Adding ML TO AR\GAN and cGAN and fastGAN and WGAN\Continuous Conditional GAN\improved_CcGAN\RC-49\RC-49_128x128\CcGAN-improved\train_net_for_label_embed.py:59, in train_net_embed(net, net_name, trainloader, testloader, epochs, resume_epoch, lr_base, lr_decay_factor, lr_decay_epochs, weight_decay, path_to_ckpt)
     57 #Forward pass
     58 outputs, _ = net(batch_train_images)
---> 59 loss = criterion(outputs, batch_train_labels)
     61 #backward pass
     62 optimizer.zero_grad()

File ~\anaconda3\envs\ContinuousCGAN\lib\site-packages\torch\nn\modules\module.py:1102, in Module._call_impl(self, *input, **kwargs)
   1098 # If we don't have any hooks, we want to skip the rest of the logic in
   1099 # this function, and just call forward.
   1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102     return forward_call(*input, **kwargs)
   1103 # Do not call functions when jit is used
   1104 full_backward_hooks, non_full_backward_hooks = [], []

File ~\anaconda3\envs\ContinuousCGAN\lib\site-packages\torch\nn\modules\loss.py:520, in MSELoss.forward(self, input, target)
    519 def forward(self, input: Tensor, target: Tensor) -> Tensor:
--> 520     return F.mse_loss(input, target, reduction=self.reduction)

File ~\anaconda3\envs\ContinuousCGAN\lib\site-packages\torch\nn\functional.py:3111, in mse_loss(input, target, size_average, reduce, reduction)
   3108 if size_average is not None or reduce is not None:
   3109     reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 3111 expanded_input, expanded_target = torch.broadcast_tensors(input, target)
   3112 return torch._C._nn.mse_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction))

File ~\anaconda3\envs\ContinuousCGAN\lib\site-packages\torch\functional.py:72, in broadcast_tensors(*tensors)
     70 if has_torch_function(tensors):
     71     return handle_torch_function(broadcast_tensors, tensors, *tensors)
---> 72 return _VF.broadcast_tensors(tensors)

RuntimeError: The size of tensor a (256) must match the size of tensor b (1809408) at non-singleton dimension 0

I know this error occurs because the dimensions of the conditional inputs don't match, but I don't understand where the 1809408 comes from: if I take a batch of 256 images, the conditional inputs should also be of size 256. Do you think you know where the problem is?

Foundsheep commented 4 months ago

@UBCDingXin Thanks for the guide to .bat files. I didn't know about that approach, so I used git bash to invoke the shell script haha.

@ashamy97 It seems like the loss function is flattening and broadcasting the input and target tensors. My guess, possibly a wrong one, is that you are using a different image size from the one the model expects; for example, the model might expect 128x128, but you are feeding 256x256 or something like that.
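
If it helps, here's a quick sanity check you could drop into train_net_for_label_embed.py right before the loss (variable names taken from your traceback):

# both tensors should have shape (batch_size, NUM_CONDITIONS); a mismatch only
# surfaces later as a confusing broadcast error inside F.mse_loss
assert outputs.shape == batch_train_labels.shape, \
    f"outputs {tuple(outputs.shape)} vs labels {tuple(batch_train_labels.shape)}"
loss = criterion(outputs, batch_train_labels)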

ashamy97 commented 4 months ago

yup you were absolutely right. My labels were the wrong dimensions but the images were the right dimensions. Silly me! Thank you so much :)

Also I believe for this section in the main.py code

# q1/q2: bounds on the first label (defined earlier in main.py);
# scale_q1/scale_q2: the analogous bounds for the second (scale) condition
if labels_train_raw.shape[-1] == NUM_CONDITIONS:
    scale_q1 = args.min_label_scale
    scale_q2 = args.max_label_scale
    indx_1 = np.where((labels_train_raw[:, 0]>q1)*(labels_train_raw[:, 0]<q2)==True)[0]
    indx_2 = np.where((labels_train_raw[:, 1]>scale_q1)*(labels_train_raw[:, 1]<scale_q2)==True)[0]
    indx = np.array(list(set.intersection(set(indx_1), set(indx_2))))
else:
    indx = np.where((labels_train_raw>q1)*(labels_train_raw<q2)==True)[0]

I think min_label_scale and max_label_scale are meant to be the min and max for the second label? They weren't added to opts.py, so it gave an error. I added them there and it is running fine.
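
For reference, this is roughly what I added to opts.py (a minimal sketch; the defaults are just what suited my data):

parser.add_argument('--min_label_scale', type=float, default=0.0,
                    help='lower bound of the second (scale) condition')
parser.add_argument('--max_label_scale', type=float, default=1.0,
                    help='upper bound of the second (scale) condition')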

Foundsheep commented 4 months ago

@ashamy97 Yes, you're correct :) Because the original code assumed only one condition, the min-max filtering mechanism existed only for the single label set, and I thought it needed to be expanded to cover several conditions as well.

If the code is running well, you've probably implemented it correctly. As I mentioned earlier, to expand the number of conditions, we need to find and modify every bit of code where only one condition is assumed. Hope your experiment goes well :)

ashamy97 commented 4 months ago

Will let you know. The CNN embed part seems to be taking a while to train but I am happy it's training haha. Thanks for your help @Foundsheep .

Also @UBCDingXin if you think our approach is correct when it comes to including more than one conditional input, let us know.

ashamy97 commented 4 months ago

@Foundsheep I am not sure if you remember this or not, but after how many epochs of training the CcGAN with the following parameters did you get results?

[screenshot of the training parameters]

Because so far here are my generated images after 800 epochs

[grid of generated images after 800 epochs]

This is what my images are supposed to look like

[example of a real training image]

Foundsheep commented 4 months ago

@ashamy97 As you can see, 800 epochs is almost the very beginning of the training. I recall that I waited until about 2000 epochs to see anything noticeable, and most of the pictures in the grid became recognizable after about 5000 epochs. The grid finally looked somewhat like real pictures after about 12000 epochs, and even when I waited until about 30000 epochs, it didn't get better than it was around 12000 epochs.

It might vary depending on the dataset we're training on, but all the training (as I mentioned, I skipped some of the training parts in the shell script, such as cGAN training or training the evaluation network etc.) took about 3~4 days on my RTX 4070 Laptop GPU.

To be honest with you, our result didn't match our expectation. Even though we gave two conditions, the model completely ignored one of them and only tried to give us something realistic-looking rather than something that followed the conditions, which I guess is a form of mode collapse. We might have needed to prepare the dataset more carefully so the model could learn more easily, but that is where we stopped with this model.

UBCDingXin commented 4 months ago

@Foundsheep @ashamy97 Hi guys,

The current CcGAN is designed specifically for univariate conditions, so it might face challenges when directly transferred to multi-dimensional conditions. For multi-dimensional conditions, you may need to re-design the condition input mechanism and re-define the hard/soft vicinity.
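
To make this concrete, one natural generalization (just a sketch of the idea, nothing we have validated) is to replace the absolute label difference in the hard/soft vicinity with a norm over the whole condition vector:

import numpy as np

def vicinal_weights(labels, center, kappa, threshold_type="hard", nu=None):
    # labels: (N, d) normalized condition vectors; center: (d,) target condition
    dist = np.linalg.norm(labels - center, axis=1)  # Euclidean distance in condition space
    if threshold_type == "hard":
        return (dist <= kappa).astype(float)        # hard vicinity: 0/1 indicator
    nu = 1.0 / kappa**2 if nu is None else nu
    return np.exp(-nu * dist**2)                    # soft vicinity: Gaussian-style decay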

We are working on such a problem but lack suitable datasets. Is there any public image dataset that is labeled with multi-dimensional conditions?

ashamy97 commented 4 months ago

@UBCDingXin I don't know if this will help, but I found this link (https://height-weight-chart.com/) that contains images of individuals with their corresponding height and weight. Both are continuous variables, and it has a limited number of images.

Foundsheep commented 4 months ago

@UBCDingXin Thanks for the instruction. That is exactly what I did when trying to convert it to a multi-dimensional conditioning model.

I'm not sure if this answers your question, but we actually used the RC-49 dataset, and the way we used it was to scale the size of the chair and use the scaling ratio as another condition. For example, the condition would be an (angle, scaling ratio) vector, and a scaling ratio of 0.5 would mean the chair is resized to half its size (the result is then cropped and pasted onto a white background of the same size as the original image, e.g., 128x128).
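
Roughly, the preprocessing looked like the sketch below (from memory; the function name is made up):

from PIL import Image

def rescale_on_white(img, ratio):
    # e.g. ratio=0.5 shrinks the chair to half size while keeping the original canvas
    w, h = img.size
    scaled = img.resize((max(1, int(w * ratio)), max(1, int(h * ratio))))
    canvas = Image.new("RGB", (w, h), (255, 255, 255))  # white background, original size
    # center-paste; PIL clips at the borders, so ratio > 1 is effectively center-cropped
    canvas.paste(scaled, ((w - scaled.width) // 2, (h - scaled.height) // 2))
    return canvas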

ashamy97 commented 4 months ago

@UBCDingXin @Foundsheep I am curious whether you guys have tried training this CcGAN on a limited dataset? Like, for example, on 100 images?

Foundsheep commented 4 months ago

@ashamy97 I haven't tried it, but I'm curious about the result.

ashamy97 commented 4 months ago

@Foundsheep yeah, I think I will try that some other time.

Question: so I trained the CcGAN and I am trying to do inference, but I don't quite know how many models we need. Doesn't x2y map the input image to a latent space and then map that latent space to the regression label y? But since we want to use the generator, in that case we just need net_y2h and the generator, no? Or am I understanding this wrong?

Foundsheep commented 4 months ago

@ashamy97 yes, your guess is in the right direction. I mentioned how I used the generator for inference up there; here is the link: https://github.com/UBCDingXin/improved_CcGAN/issues/12#issuecomment-2097120415

ashamy97 commented 4 months ago

@Foundsheep So I am running into an error when trying to find the reconstructed labels. I copied the same code you used with some modifications.

# x2y
DIM_EMBED = 128
x2y_ckpt_path = ".../output/embed_models/embed_x2y_ckpt_in_train/embed_x2y_checkpoint_epoch_200.pth"
x2y_state_dict = torch.load(x2y_ckpt_path)
# for k, v in x2y_state_dict["net_state_dict"].items():
#     x2y_state_dict["new_net_state_dict"].update({k.replace("module.", ""):v})

net_embed = ResNet34_embed(dim_embed=DIM_EMBED)
net_embed = net_embed.cuda()
net_embed = nn.DataParallel(net_embed)
net_embed.load_state_dict(x2y_state_dict['net_state_dict'])

#y2h
y2h_ckpt_path = ".../output/embed_models/ckpt_net_y2h_epoch_500_seed_2020.pth"
y2h_state_dict = torch.load(y2h_ckpt_path)
net_y2h = model_y2h(dim_embed=DIM_EMBED)
net_y2h = net_y2h.cuda()
net_y2h = nn.DataParallel(net_y2h)
net_y2h.load_state_dict(y2h_state_dict['net_state_dict'])

#netG
DIM_GAN = 256
Filename_GAN = ".../output/output_CcGAN_arch_SAGAN/saved_models/SAGAN_soft_2_checkpoint_intrain/checkpoint_20400.pth"
checkpoint = torch.load(Filename_GAN)
netG = CcGAN_SAGAN_Generator(dim_z=DIM_GAN, dim_embed=DIM_EMBED).cuda()
netG = nn.DataParallel(netG)
netG.load_state_dict(checkpoint['netG_state_dict'])

This is the inference:

labels = torch.tensor([[9, 10]]).cuda()

with torch.no_grad():
    h = net_y2h(labels)
    print(f" shape of h: {h.shape}")

    z = torch.randn((h.shape[0], DIM_GAN)).cuda()
    outputs = netG(z, h)
    print(f"Output shape: {outputs.shape}")

    labels_hat, embedded_features = net_embed(outputs)

print(labels_hat.shape)
print(embedded_features.shape)

This is the error message log:

shape of h: torch.Size([1, 128])
Output shape: torch.Size([1, 3, 128, 128])
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[6], line 16
     13     outputs = netG(z, h)
     14     print(f"Output shape: {outputs.shape}")
---> 16     labels_hat, embedded_features = net_embed(outputs)
     18 print(labels_hat.shape)
     19 print(embedded_features.shape)

File ~\anaconda3\envs\ContinuousCGAN\lib\site-packages\torch\nn\modules\module.py:1102, in Module._call_impl(self, *input, **kwargs)
   1098 # If we don't have any hooks, we want to skip the rest of the logic in
   1099 # this function, and just call forward.
   1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102     return forward_call(*input, **kwargs)
   1103 # Do not call functions when jit is used
   1104 full_backward_hooks, non_full_backward_hooks = [], []

File ~\anaconda3\envs\ContinuousCGAN\lib\site-packages\torch\nn\parallel\data_parallel.py:166, in DataParallel.forward(self, *inputs, **kwargs)
    163     kwargs = ({},)
    165 if len(self.device_ids) == 1:
--> 166     return self.module(*inputs[0], **kwargs[0])
    167 replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
    168 outputs = self.parallel_apply(replicas, inputs, kwargs)

File ~\anaconda3\envs\ContinuousCGAN\lib\site-packages\torch\nn\modules\module.py:1102, in Module._call_impl(self, *input, **kwargs)
   1098 # If we don't have any hooks, we want to skip the rest of the logic in
   1099 # this function, and just call forward.
   1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102     return forward_call(*input, **kwargs)
   1103 # Do not call functions when jit is used
   1104 full_backward_hooks, non_full_backward_hooks = [], []

File ~\Desktop\Research\Adding ML TO AR\GAN and cGAN and fastGAN and WGAN\Continuous Conditional GAN\improved_CcGAN\RC-49\RC-49_128x128\CcGAN-improved\models\ResNet_embed.py:131, in ResNet_embed.forward(self, x)
    129 features = self.main(x)
    130 features = features.view(features.size(0), -1)
--> 131 features = self.x2h_res(features)
    132 out = self.h2y(features)
    134 return out, features

File ~\anaconda3\envs\ContinuousCGAN\lib\site-packages\torch\nn\modules\module.py:1102, in Module._call_impl(self, *input, **kwargs)
   1098 # If we don't have any hooks, we want to skip the rest of the logic in
   1099 # this function, and just call forward.
   1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102     return forward_call(*input, **kwargs)
   1103 # Do not call functions when jit is used
   1104 full_backward_hooks, non_full_backward_hooks = [], []

File ~\anaconda3\envs\ContinuousCGAN\lib\site-packages\torch\nn\modules\container.py:141, in Sequential.forward(self, input)
    139 def forward(self, input):
    140     for module in self:
--> 141         input = module(input)
    142     return input

File ~\anaconda3\envs\ContinuousCGAN\lib\site-packages\torch\nn\modules\module.py:1102, in Module._call_impl(self, *input, **kwargs)
   1098 # If we don't have any hooks, we want to skip the rest of the logic in
   1099 # this function, and just call forward.
   1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102     return forward_call(*input, **kwargs)
   1103 # Do not call functions when jit is used
   1104 full_backward_hooks, non_full_backward_hooks = [], []

File ~\anaconda3\envs\ContinuousCGAN\lib\site-packages\torch\nn\modules\batchnorm.py:168, in _BatchNorm.forward(self, input)
    161     bn_training = (self.running_mean is None) and (self.running_var is None)
    163 r"""
    164 Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
    165 passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
    166 used for normalization (i.e. in eval mode when buffers are not None).
    167 """
--> 168 return F.batch_norm(
    169     input,
    170     # If buffers are not to be tracked, ensure that they won't be updated
    171     self.running_mean
    172     if not self.training or self.track_running_stats
    173     else None,
    174     self.running_var if not self.training or self.track_running_stats else None,
    175     self.weight,
    176     self.bias,
    177     bn_training,
    178     exponential_average_factor,
    179     self.eps,
    180 )

File ~\anaconda3\envs\ContinuousCGAN\lib\site-packages\torch\nn\functional.py:2280, in batch_norm(input, running_mean, running_var, weight, bias, training, momentum, eps)
   2267     return handle_torch_function(
   2268         batch_norm,
   2269         (input, running_mean, running_var, weight, bias),
   (...)
   2277         eps=eps,
   2278     )
   2279 if training:
-> 2280     _verify_batch_size(input.size())
   2282 return torch.batch_norm(
   2283     input, weight, bias, running_mean, running_var, training, momentum, eps, torch.backends.cudnn.enabled
   2284 )

File ~\anaconda3\envs\ContinuousCGAN\lib\site-packages\torch\nn\functional.py:2248, in _verify_batch_size(size)
   2246     size_prods *= size[i + 2]
   2247 if size_prods == 1:
-> 2248     raise ValueError("Expected more than 1 value per channel when training, got input size {}".format(size))

ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 512])

It's weird because I didn't get error messages when training x2y.

Do you think you might know where the problem is? I think this will then be my last question. Thanks so much for your help :)

Foundsheep commented 3 months ago

@ashamy97 It seems related to the shape of a tensor, which changed because of our modifications to the conditioning inputs.

Could you check whether there's any difference between the code below and the code you're using in ResNet_embed.py? Perhaps those commented-out lines have something to do with this error.

Basically, I think I faced a similar kind of error, and I resolved it by checking line by line what shape the model expects inside. If you keep running into this error, that approach might be helpful :)
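
Also, judging from the last frames of your traceback (BatchNorm complaining about an input of size [1, 512] while training=True), another cheap thing to try: the models may still be in training mode. The simple test in main.py calls net_embed.eval() before pushing anything through it, so something like this (my guess, with your variable names):

net_embed.eval()   # BatchNorm layers reject batch size 1 in training mode
net_y2h.eval()
netG.eval()
with torch.no_grad():
    h = net_y2h(labels)
    z = torch.randn((h.shape[0], DIM_GAN)).cuda()
    outputs = netG(z, h)
    labels_hat, embedded_features = net_embed(outputs)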

My version of ResNet_embed.py

import torch
import torch.nn as nn
import torch.nn.functional as F

NC = 3
IMG_SIZE = 128
DIM_EMBED = 128

NUM_CONDITIONS = 2

#------------------------------------------------------------------------------
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion*planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

class ResNet_embed(nn.Module):
    def __init__(self, block, num_blocks, nc=NC, dim_embed=DIM_EMBED):
        super(ResNet_embed, self).__init__()
        self.in_planes = 64

        self.main = nn.Sequential(
            nn.Conv2d(nc, 64, kernel_size=3, stride=1, padding=1, bias=False),  # h=h
            # nn.Conv2d(nc, 64, kernel_size=4, stride=2, padding=1, bias=False),  # h=h/2
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2,2), #h=h/2 64
            # self._make_layer(block, 64, num_blocks[0], stride=1),  # h=h
            self._make_layer(block, 64, num_blocks[0], stride=2),  # h=h/2 32
            self._make_layer(block, 128, num_blocks[1], stride=2), # h=h/2 16
            self._make_layer(block, 256, num_blocks[2], stride=2), # h=h/2 8
            self._make_layer(block, 512, num_blocks[3], stride=2), # h=h/2 4
            # nn.AvgPool2d(kernel_size=4)
            nn.AdaptiveAvgPool2d((1, 1))
        )

        self.x2h_res = nn.Sequential(
            nn.Linear(512, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),

            nn.Linear(512, dim_embed),
            nn.BatchNorm1d(dim_embed),
            nn.ReLU(),
        )

        self.h2y = nn.Sequential(
            nn.Linear(dim_embed, NUM_CONDITIONS),
            nn.ReLU()
        )

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):

        features = self.main(x)
        features = features.view(features.size(0), -1)
        features = self.x2h_res(features)
        out = self.h2y(features)

        return out, features

def ResNet18_embed(dim_embed=DIM_EMBED):
    return ResNet_embed(BasicBlock, [2,2,2,2], dim_embed=dim_embed)

def ResNet34_embed(dim_embed=DIM_EMBED):
    return ResNet_embed(BasicBlock, [3,4,6,3], dim_embed=dim_embed)

def ResNet50_embed(dim_embed=DIM_EMBED):
    return ResNet_embed(Bottleneck, [3,4,6,3], dim_embed=dim_embed)

#------------------------------------------------------------------------------
# map labels to the embedding space
class model_y2h(nn.Module):
    def __init__(self, dim_embed=DIM_EMBED):
        super(model_y2h, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(NUM_CONDITIONS, dim_embed),
            # nn.BatchNorm1d(dim_embed),
            nn.GroupNorm(8, dim_embed),
            nn.ReLU(),

            nn.Linear(dim_embed, dim_embed),
            # nn.BatchNorm1d(dim_embed),
            nn.GroupNorm(8, dim_embed),
            nn.ReLU(),

            nn.Linear(dim_embed, dim_embed),
            # nn.BatchNorm1d(dim_embed),
            nn.GroupNorm(8, dim_embed),
            nn.ReLU(),

            nn.Linear(dim_embed, dim_embed),
            # nn.BatchNorm1d(dim_embed),
            nn.GroupNorm(8, dim_embed),
            nn.ReLU(),

            nn.Linear(dim_embed, dim_embed),
            nn.ReLU()
        )

    def forward(self, y):
        # y now has shape (N, NUM_CONDITIONS), so the original view(-1, 1) reshape is dropped
        # y = y.view(-1, 1) + 1e-8
        y = y + 1e-8
        # y = torch.exp(y.view(-1, 1))
        return self.main(y)

if __name__ == "__main__":
    net = ResNet34_embed(dim_embed=128).cuda()
    x = torch.randn(16,NC,IMG_SIZE,IMG_SIZE).cuda()
    out, features = net(x)
    print(out.size())
    print(features.size())

    net_y2h = model_y2h().cuda()
    y_hat = net_y2h(out)
    print(f"{y_hat.size() = }")