SamsungLabs / semi-supervised-NFs

Code for the paper Semi-Conditional Normalizing Flows for Semi-Supervised Learning
https://arxiv.org/abs/1905.00505
BSD 2-Clause "Simplified" License

Conditional generation from prior #2

Open sgalkina opened 3 months ago

sgalkina commented 3 months ago

Dear authors,

Thank you for making the code available. I am interested in semi-supervised normalizing flows and am trying to reproduce the results. I couldn't find the code for generating new samples from scratch conditioned on a label (e.g., generating an MNIST image of the digit 3). I quickly modified the FlowPDF code in the following way, based on the architecture description in Figure 1:

import torch
import torch.nn as nn


class FlowPDF(nn.Module):
    def __init__(self, flow, prior):
        super().__init__()
        self.flow = flow
        self.prior = prior

    def log_prob(self, x):
        log_det, z = self.flow(x)
        return log_det + self.prior.log_prob(z)

    def conditional_sample(self, y, device):
        N = y.shape[0]  # batch size (one sample per label in y)
        H = 196         # dimensionality of the label-conditioned latent z_f
        M = 784         # dimensionality of a flattened 28x28 MNIST image
        # sample the deep latent from a standard normal and push it through
        # the conditional prior flow, conditioned on the labels y
        z_h = torch.normal(0, 1, (N, H)).to(device)
        z_f = self.prior.flow.g(z_h, y)[:, :, 0, 0]
        # sample the remaining (unconditioned) dimensions from a standard normal
        z_aux = torch.normal(0, 1, (N, M - H)).to(device)
        # invert the main flow to map the full latent back to image space
        x = self.flow.module.g(torch.cat([z_f, z_aux], 1))
        return x

but the results are quite poor, nothing like the samples in Figure 2. Could you please share the code that can be used to generate images conditioned on labels?
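
For reference, this is roughly how I invoke the method above (a minimal sketch under my own assumptions: model is the trained FlowPDF instance and the labels are one-hot vectors; neither is code from this repo):

import torch
import torch.nn.functional as F

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# generate ten samples conditioned on the digit 3 (one-hot label encoding is my assumption)
N = 10
labels = torch.full((N,), 3, dtype=torch.long)
y = F.one_hot(labels, num_classes=10).float().to(device)

with torch.no_grad():
    x = model.conditional_sample(y, device)  # model: trained FlowPDF instance (assumption)

images = x.view(N, 1, 28, 28)  # reshape the flat 784-dim outputs into 28x28 images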

sgalkina commented 3 months ago

Hi again,

I've fixed the code to sample from the learned priors instead of standard normals:

import torch
import torch.nn as nn


class FlowPDF(nn.Module):
    def __init__(self, flow, prior):
        super().__init__()
        self.flow = flow
        self.prior = prior

    def log_prob(self, x):
        log_det, z = self.flow(x)
        return log_det + self.prior.log_prob(z)

    def conditional_sample(self, y, device):
        N = y.shape[0]  # batch size (one sample per label in y)
        H = 196         # dimensionality of the label-conditioned latent z_f
        M = 784         # dimensionality of a flattened 28x28 MNIST image
        # sample the deep latent from the learned deep prior (not a standard normal)
        mean1 = self.prior.deep_prior.mean
        sigma1 = torch.exp(self.prior.deep_prior.logsigma)
        z_h = torch.normal(mean1.repeat([N, 1]), sigma1.repeat([N, 1])).to(device)
        # push it through the conditional prior flow, conditioned on the labels y
        z_f = self.prior.flow.g(z_h, y)[:, :, 0, 0]
        # sample the auxiliary dimensions from the learned shallow prior
        mean2 = self.prior.shallow_prior.mean
        sigma2 = torch.exp(self.prior.shallow_prior.logsigma)
        z_aux = torch.normal(mean2.repeat([N, 1]), sigma2.repeat([N, 1])).to(device)
        # invert the main flow; note the concatenation order [z_aux, z_f] here
        x = self.flow.module.g(torch.cat([z_aux, z_f], 1))
        return x

and got these images, which look rather decent:

So the issue is basically closed for me, but let me know if you have any additional comments.
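
In case it is useful to others, this is roughly how I render the conditional samples into a grid (again a sketch under my own assumptions: one-hot labels, a trained model instance, and save_image from torchvision for saving):

import torch
import torch.nn.functional as F
from torchvision.utils import save_image

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# one row per digit: ten samples for each of the ten classes (one-hot labels assumed)
labels = torch.arange(10).repeat_interleave(10)
y = F.one_hot(labels, num_classes=10).float().to(device)

with torch.no_grad():
    x = model.conditional_sample(y, device)  # model: trained FlowPDF instance (assumption)

save_image(x.view(-1, 1, 28, 28), 'conditional_samples.png', nrow=10, normalize=True)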