google / lecam-gan

Regularizing Generative Adversarial Networks under Limited Data (CVPR 2021)
Apache License 2.0

Low-shot training #5

Open · xinyouduogao opened 3 years ago

xinyouduogao commented 3 years ago

How do I train LeCam-GAN on the low-shot image generation datasets? Thanks.

hytseng0509 commented 3 years ago

Hi,

We modify the code of DiffAug to train on the low-shot image generation task.
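For readers trying this before a release, the core change is adding the LeCam regularization term to the discriminator loss of the DiffAug training code. Below is a minimal PyTorch sketch of the term, not the released implementation; the names `d_real`/`d_fake` (discriminator logits on real/generated images) and `ema_real`/`ema_fake` (EMA anchors of those logits) are illustrative:

```python
import torch
import torch.nn.functional as F

def lecam_reg(d_real, d_fake, ema_real, ema_fake):
    # One-sided (relu) squared penalty that discourages the discriminator's
    # outputs from drifting past the EMA anchor of the opposite class.
    return torch.mean(F.relu(d_real - ema_fake) ** 2) + \
           torch.mean(F.relu(ema_real - d_fake) ** 2)
```

The term is scaled by a small weight and added to the usual discriminator loss.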

liang-hou commented 3 years ago

> Hi,
>
> We modify the code of DiffAug to train on the low-shot image generation task.

Hi, would you mind sharing the source code of the low-shot generation experiments? It would help us a lot.

aprilycliu commented 3 years ago

Hello, I'm also trying to implement the LeCam loss on the DiffAug low-shot setup. Could you share the training script, please? Thanks!

SushkoVadim commented 3 years ago

Hi,

I would add my vote to this discussion; it would be very helpful to have a look at the training script or the modified training/loss.py file. I tried to implement the method myself on top of DiffAug, but I still did not manage to reproduce the results from Supplementary Table 4. Thanks in advance!

roadjiang commented 3 years ago

This experiment appears only in the supplementary material and is not part of the main paper. We'd love to release the code, but it may require additional approvals. We will try our best and see what we can do.

SushkoVadim commented 3 years ago

Hi, thanks a lot for answering! I understand that the clearance process for open-sourcing can be time-consuming and burdensome. To simplify things, could I ask you to comment on my attempt to reproduce the training? Perhaps I am missing some implementation details that turn out to be important. This could also benefit others trying to reproduce the low-shot results.

My modification to DiffAug was to add the LeCam regularization in the training/loss.py module. 1) Specifically, I added a simple EMA tracker for both the real and fake logits to the StyleGAN2Loss class:

```python
# Inside StyleGAN2Loss.__init__:
self.val_ema_real = val_EMA()
self.val_ema_fake = val_EMA()

class val_EMA():
    def __init__(self, ema_decay=0.99):
        self.ema_decay = ema_decay
        self.mem_value = 0  # running EMA of the (detached) logits

    def add_step(self, cur_values):
        # Standard exponential moving average update on detached values.
        self.mem_value = self.ema_decay * self.mem_value + (1 - self.ema_decay) * cur_values.detach()

    def get_cur_val(self):
        return self.mem_value
```
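For illustration, the tracker would be updated once per discriminator step and queried for the current anchor (the logits below are placeholders):

```python
import torch

tracker = val_EMA(ema_decay=0.99)
logits = torch.randn(8, 1)       # placeholder batch of discriminator logits
tracker.add_step(logits)         # EMA update on detached values
anchor = tracker.get_cur_val()   # anchor consumed by the regularizer
```

Note that mem_value holds a full detached tensor of per-sample logits rather than a scalar, which relies on a constant batch size for the broadcasting to stay consistent.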

2) During training, I add the new logit values to the EMA accumulators and then add the regularization to the objective functions:

```python
# Inside StyleGAN2Loss.accumulate_gradients():

# For fakes:
loss_emaCR_fake = 0
if do_emaCR:
    self.val_ema_fake.add_step(gen_logits)
    loss_emaCR_fake = self.cr_ema_lambda * torch.mean(torch.square(
        torch.nn.functional.relu(self.val_ema_real.get_cur_val() - gen_logits)))

with torch.autograd.profiler.record_function('Dgen_backward'):
    (loss_Dgen + loss_emaCR_fake).mean().mul(gain).backward()

# ...

# For reals:
loss_emaCR_real = 0
if do_emaCR:
    self.val_ema_real.add_step(real_logits)
    loss_emaCR_real = self.cr_ema_lambda * torch.mean(torch.square(
        torch.nn.functional.relu(real_logits - self.val_ema_fake.get_cur_val())))

with torch.autograd.profiler.record_function(name + '_backward'):
    (real_logits * 0 + loss_Dreal + loss_Dr1 + loss_emaCR_real).mean().mul(gain).backward()
```
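For comparison, the two terms can also be computed in a single pass with scalar, batch-mean anchors; this is a sketch of one plausible reading of the supplementary description, not necessarily what the authors did:

```python
# Sketch only: single-pass LeCam term with scalar (batch-mean) anchors.
self.val_ema_real.add_step(real_logits.mean())
self.val_ema_fake.add_step(gen_logits.mean())
loss_lecam = torch.mean(torch.square(torch.nn.functional.relu(
                 real_logits - self.val_ema_fake.get_cur_val()))) \
           + torch.mean(torch.square(torch.nn.functional.relu(
                 self.val_ema_real.get_cur_val() - gen_logits)))
loss_D = loss_Dreal + loss_Dgen + self.cr_ema_lambda * loss_lecam
```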
I ran the training for the same 300 kimg with self.cr_ema_lambda = 0.0001 and self.ema_decay = 0.99, which corresponds to the description in the supplementary material. After training, I measured the following best FID across epochs:

| Use LeCam CR? | Metrics | Animal Face Cat | Animal Face Dog | Obama | Panda | Grumpy Cat |
|---|---|---|---|---|---|---|
| - | reported | 42.10 | 58.47 | 47.09 | 12.10 | 27.21 |
| Yes | reported | 33.16 | 54.88 | 33.16 | 10.16 | 24.93 |
| - | reproduced | 40.20 | 67.12 | 48.31 | 14.44 | 27.09 |
| Yes | reproduced | 39.55 | 64.84 | 50.80 | 14.82 | 29.66 |

Thus, I am able to reproduce the original numbers from the DiffAug repository. However, the results after adding the LeCam CR do not match Table 4; the added regularization is even harmful on 3 of the 5 datasets.

It would indeed be very helpful if we could figure out where my misunderstanding lies. Regards, Vadim