lucidrains / lightweight-gan

Implementation of 'lightweight' GAN, proposed in ICLR 2021, in Pytorch. High resolution image generations that can be trained within a day or two
MIT License

Multiclass training and inference #50

Open virilo opened 3 years ago

virilo commented 3 years ago

I'd like to use lightweight-gan on a multiclass dataset.

The idea is to train the GAN with multiple class labels.

Then, at inference time, ask the GAN for an image with multiple tags, e.g. generate an image tagged as 'boat', 'sunset' and 'people'.

Is it possible with lightweight-gan?

lucidrains commented 3 years ago

@virilo no, not at the moment, the architecture isn't class-conditional

Dok11 commented 3 years ago

@lucidrains what do you think: is it enough to add the tag information into the latents before generating images, and into the discriminator checks? Or is it a more difficult task?

Something like:

import torch

batch_size, latent_dim = 8, 256  # example sizes
latents = torch.randn(batch_size, latent_dim)  # in the trainer: .cuda(self.rank)
latents[:, 0] = 0.0  # tag 1 is disabled
latents[:, 1] = 1.0  # tag 2 is enabled
latents[:, 2] = 0.8  # tag 3 is enabled with 80% probability

Mut1nyJD commented 3 years ago

It is pretty straightforward; I made a conditional version in my checkout.

You simply attach a one-hot vector the size of the number of classes to the latent code, beef up the discriminator so that it also outputs a class prediction rather than just the binary real/fake output, and finally add another cross-entropy loss to the overall losses.
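A minimal sketch of the conditioning step described above, assuming integer class labels and the Gaussian latents of size `latent_dim` that the trainer already samples. The helper name `condition_latents` is hypothetical and not part of lightweight-gan:

```python
import torch
import torch.nn.functional as F


def condition_latents(latents: torch.Tensor, labels: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Concatenate a one-hot class vector to each latent code (hypothetical helper).

    latents: (batch, latent_dim) noise
    labels:  (batch,) integer class indices in [0, num_classes)
    returns: (batch, latent_dim + num_classes)
    """
    one_hot = F.one_hot(labels, num_classes).to(latents.dtype)
    return torch.cat((latents, one_hot), dim=1)


batch_size, latent_dim, num_classes = 4, 256, 10
latents = torch.randn(batch_size, latent_dim)
labels = torch.randint(0, num_classes, (batch_size,))
conditioned = condition_latents(latents, labels, num_classes)
print(conditioned.shape)  # torch.Size([4, 266])
```

For the multi-tag request at the top of the thread, a multi-hot vector could be concatenated the same way, though nobody in the thread reports trying that.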

Dok11 commented 3 years ago

and beef up the discriminator to output also a class prediction instead of just the binary and finally you add another CrossEntropy loss to the overall losses

Sounds a little hard... 🤔 Can you show code or pseudo-code for how it can be implemented? Do the losses just add up, or is there other logic?

lucidrains commented 3 years ago

yea, it's doable! but i'm focused on Alphafold2 for this week and the next

I'll circle back to this in due time!

Mut1nyJD commented 3 years ago

and beef up the discriminator to output also a class prediction instead of just the binary and finally you add another CrossEntropy loss to the overall losses

Sounds little hard.. 🤔 can you show code or pseudo-code how it can be implemented?

Generator:

In the generator you simply add another parameter, the number of classes,

then change self.initial_conv to:

self.initial_conv = nn.Sequential(
    nn.ConvTranspose2d(latent_dim + num_classes, latent_dim * 2, 4),
    norm_class(latent_dim * 2),
    nn.GLU(dim = 1)
)
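To see what the widened first layer does with the concatenated input, here is a standalone shape check. It assumes `nn.BatchNorm2d` as the `norm_class` and reshapes the flat latent to 1x1 before the transposed convolution; it mirrors the snippet above but is not copied from the repo:

```python
import torch
import torch.nn.functional as F
from torch import nn

latent_dim, num_classes, batch_size = 256, 10, 4

# Stand-in for the modified self.initial_conv; nn.BatchNorm2d is assumed for norm_class.
initial_conv = nn.Sequential(
    nn.ConvTranspose2d(latent_dim + num_classes, latent_dim * 2, 4),  # 1x1 -> 4x4
    nn.BatchNorm2d(latent_dim * 2),
    nn.GLU(dim=1),  # gates half the channels with the other half: latent_dim * 2 -> latent_dim
)

latents = torch.randn(batch_size, latent_dim)
labels = torch.randint(0, num_classes, (batch_size,))
x = torch.cat((latents, F.one_hot(labels, num_classes).float()), dim=1)
x = x[:, :, None, None]  # (b, c) -> (b, c, 1, 1) so the transposed conv can upsample it
out = initial_conv(x)
print(out.shape)  # torch.Size([4, 256, 4, 4])
```

Only the input channel count changes; everything after the first layer sees the same shapes as in the unconditioned generator.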

In the Discriminator you should also add the number of classes as a new parameter. Beyond that, it is wise to split up to_logits, since it is best if the class output and the binary real/fake output go through the same feature branch: separate out the last conv layer (the one that reduces it to 1 output channel) and add another conv layer next to it for the classes, something like this:

self.to_logits = nn.Sequential(
    Blur(),
    nn.Conv2d(last_chan, last_chan, 3, stride = 2, padding = 1),
    nn.LeakyReLU(0.1),
)
self.realfakeConv = nn.Conv2d(last_chan, 1, 4)
self.classConv = nn.Conv2d(last_chan, self.num_classes, 4)

You then have to change the code in the forward function where to_logits is called to:

logits = self.to_logits(x)
out = self.realfakeConv(logits)
out_class = self.classConv(logits)

and also return out_class from forward.
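A minimal self-contained sketch of that split head, assuming an 8x8 final feature map with `last_chan` channels. `TwoHeadLogits` is a hypothetical name, and lightweight-gan's `Blur()` is left out because it is a repo-internal module:

```python
import torch
from torch import nn


class TwoHeadLogits(nn.Module):
    """Shared feature branch feeding a real/fake head and a class head (hypothetical name)."""

    def __init__(self, last_chan: int, num_classes: int):
        super().__init__()
        # Shared branch (Blur() omitted in this sketch).
        self.to_logits = nn.Sequential(
            nn.Conv2d(last_chan, last_chan, 3, stride=2, padding=1),
            nn.LeakyReLU(0.1),
        )
        self.realfakeConv = nn.Conv2d(last_chan, 1, 4)
        self.classConv = nn.Conv2d(last_chan, num_classes, 4)

    def forward(self, x: torch.Tensor):
        logits = self.to_logits(x)          # shared features, 8x8 -> 4x4
        out = self.realfakeConv(logits)     # (b, 1, 1, 1)
        out_class = self.classConv(logits)  # (b, num_classes, 1, 1)
        return out.flatten(1), out_class.flatten(1)


head = TwoHeadLogits(last_chan=64, num_classes=10)
features = torch.randn(4, 64, 8, 8)  # assumed 8x8 final feature map
out, out_class = head(features)
print(out.shape, out_class.shape)  # torch.Size([4, 1]) torch.Size([4, 10])
```

Because both heads read the same `logits`, the class signal shapes the shared features that the real/fake decision also depends on.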

The rest is just a bit of plumbing.

Losses just add up or it have other logic?

Yes, they just add up. Of course, you could also add a weighting hyperparameter to balance the real/fake loss against the class loss.
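A sketch of the summed discriminator loss with such a weight. The name `class_weight` is hypothetical, the class term is applied to real images only (one common choice), and the standard BCE-with-logits adversarial term is used purely for illustration; the repo's own adversarial loss may differ:

```python
import torch
import torch.nn.functional as F


def discriminator_loss(real_out, fake_out, real_class_logits, real_labels, class_weight=1.0):
    """Adversarial term plus a weighted class term (illustrative sketch).

    real_out, fake_out: (b, 1) real/fake logits for real and generated images
    real_class_logits:  (b, num_classes) class logits for the real images
    real_labels:        (b,) integer class indices
    class_weight:       hypothetical weighting hyperparameter
    """
    # Standard BCE-with-logits for illustration only; the repo's adversarial loss may differ.
    adv = (
        F.binary_cross_entropy_with_logits(real_out, torch.ones_like(real_out))
        + F.binary_cross_entropy_with_logits(fake_out, torch.zeros_like(fake_out))
    )
    # Class term on real images only.
    cls = F.cross_entropy(real_class_logits, real_labels)
    return adv + class_weight * cls


torch.manual_seed(0)
b, num_classes = 4, 10
real_out, fake_out = torch.randn(b, 1), torch.randn(b, 1)
class_logits = torch.randn(b, num_classes)
labels = torch.randint(0, num_classes, (b,))
loss = discriminator_loss(real_out, fake_out, class_logits, labels, class_weight=0.5)
print(loss.item())
```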

taucontrib commented 3 years ago

Thank you very much @Mut1nyJD, I was just about to write a class-conditional version myself. I thought about implementing class-conditional batch normalization and a projection discriminator as in BigGAN etc., because I had better experiences with that technique. Your way seems easier to implement, though. Did you already train a model with that class-conditional extension? How are the results? It would be cool to know before I try it out myself.

Mut1nyJD commented 3 years ago

@xnqio In my tests it works pretty well, but I only used datasets where the number of classes is relatively small (<20). Yes, I like the simplicity of this method compared to projection and class embeddings. An additional trick is to add one extra class, the fake class: when you feed examples from the generator to the discriminator network, you label them as the fake class. That gives it an even stronger learning signal.

taucontrib commented 3 years ago

@Mut1nyJD thanks for the tip. Do you add the fake class in addition to the real/fake classification, or do you get rid of the real_fake conv = nn.Conv2d(chan, 1, ..) and replace it with a class_conv = nn.Conv2d(chan, self.number_classes + 1, ..)?

Mut1nyJD commented 3 years ago

@xnqio Yes, I add the fake class in addition. And no, it is best to keep the binary classifier output in the discriminator as well, but let both go through the same feature branch beforehand, so you have one additional output compared to the unconditioned version. For that extra output you can use a standard CrossEntropyLoss with num_classes + 1 classes.
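A sketch of the discriminator-side class term for this fake-class trick, assuming the class head outputs `num_classes + 1` logits with index `num_classes` reserved for "fake". The function name is hypothetical; per the comment above, this term is added to the binary real/fake loss, which is kept:

```python
import torch
import torch.nn.functional as F


def d_class_loss_with_fake_class(real_class_logits, fake_class_logits, real_labels, num_classes):
    """Discriminator class term with an extra 'fake' class (hypothetical helper).

    Both logit tensors have shape (b, num_classes + 1); index num_classes is the fake class.
    Real images keep their true labels, generated images are all labelled as fake.
    """
    fake_labels = torch.full(
        (fake_class_logits.shape[0],), num_classes,
        dtype=torch.long, device=fake_class_logits.device,
    )
    logits = torch.cat((real_class_logits, fake_class_logits), dim=0)
    targets = torch.cat((real_labels, fake_labels), dim=0)
    return F.cross_entropy(logits, targets)


num_classes, b = 10, 4
real_logits = torch.randn(b, num_classes + 1)
fake_logits = torch.randn(b, num_classes + 1)
labels = torch.randint(0, num_classes, (b,))
loss = d_class_loss_with_fake_class(real_logits, fake_logits, labels, num_classes)
print(loss.item())
```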

taucontrib commented 3 years ago

Alright, got it @Mut1nyJD. One last question: the standard CrossEntropyLoss doesn't work for multi-label classification. This means that the generated images belong only to the fake class and not to the class they try to imitate, right?

Mut1nyJD commented 3 years ago

Yes. During the discriminator's training step you set the labels of the generator's outputs to the fake class; during the generator's training step you leave the real labels intact. Oh, one thing I did notice, though I have not verified whether it is a general problem with this architecture or just with my conditional version: training with AMP seems to lead to mode collapse more often.
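For the generator step, a minimal sketch under the same assumption (`num_classes + 1` class logits, index `num_classes` = fake). The generated images are scored against the labels they were conditioned on, not the fake class, so the generator is pushed to make the discriminator assign them to their intended real class. `g_class_loss` is a hypothetical name:

```python
import torch
import torch.nn.functional as F


def g_class_loss(fake_class_logits, cond_labels):
    """Generator class term (hypothetical helper).

    fake_class_logits: (b, num_classes + 1) discriminator class logits for generated images
    cond_labels:       (b,) the labels the generator was conditioned on, left intact
    """
    return F.cross_entropy(fake_class_logits, cond_labels)


num_classes = 10
cond_labels = torch.tensor([2, 7])
# Discriminator confidently says "fake" vs. confidently says "the intended class".
looks_fake = F.one_hot(torch.tensor([num_classes, num_classes]), num_classes + 1).float() * 5
looks_real = F.one_hot(cond_labels, num_classes + 1).float() * 5
print(g_class_loss(looks_fake, cond_labels) > g_class_loss(looks_real, cond_labels))  # tensor(True)
```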

Dok11 commented 3 years ago

Since we're now talking about tricks, I should put this link here =) https://github.com/soumith/ganhacks