EoinKenny / AAAI-2021

Code for our paper

How to find input vector z for an image #1

Closed (ziweizhao1993 closed this issue 2 years ago)

ziweizhao1993 commented 2 years ago

Hi,

Thank you for sharing the official implementation! I see in the provided code that the input vector z is loaded from "misclassify_XX.pt". I was wondering whether the code to find z (Equation 1 in the paper) is also available?

Ziwei

EoinKenny commented 2 years ago

Hi Ziwei,

Thanks for the message. I have been looking, and it appears my code for that function was lost when my computer was wiped a few months ago.

It is one of the easiest functions to implement, though. I haven't tested this, but something like this should work:

import torch
import torch.optim as optim

def find_gan_approx(G, original_query_image, original_query_label, cnn,
                    latent_size=100, lr=0.01):  # defaults are placeholders; match your GAN

    # z needs requires_grad=True so Adam can optimise it directly
    z = torch.randn((1, latent_size), requires_grad=True)
    optimizer = optim.Adam([z], lr=lr)
    mse_loss = torch.nn.MSELoss()
    cce_loss = torch.nn.CrossEntropyLoss()

    # CrossEntropyLoss expects a class index, not a one-hot vector
    target_label = torch.tensor([original_query_label])

    lambda1 = 1.
    lambda2 = 1.

    for _ in range(10000):

        optimizer.zero_grad()

        current_img = G(z)
        pred_logits = cnn(current_img)

        loss1 = cce_loss(pred_logits, target_label) * lambda1          # keep the predicted class
        loss2 = mse_loss(current_img, original_query_image) * lambda2  # match the query in pixel space
        loss = loss1 + loss2
        loss.backward()

        optimizer.step()

    return z

You'll have to tweak lambda1 and lambda2.
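If it helps, here is a quick self-contained way to sanity-check that the loop runs. Note that G and cnn below are tiny stand-in modules and the image/label are random placeholders, not the actual models or data from this repo:

import torch
import torch.nn as nn

latent_size = 100

# Toy stand-ins just to exercise find_gan_approx; swap in your trained
# GAN generator and CNN classifier in practice.
G = nn.Sequential(nn.Linear(latent_size, 28 * 28), nn.Tanh(),
                  nn.Unflatten(1, (1, 28, 28)))
cnn = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))

query_img = torch.rand(1, 1, 28, 28)   # placeholder "query image"
query_label = 3                        # placeholder class index

z = find_gan_approx(G, query_img, query_label, cnn, latent_size=latent_size, lr=0.01)
print(G(z).shape)                      # reconstruction with the same shape as the query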

ziweizhao1993 commented 2 years ago

Thank you!

ziweizhao1993 commented 2 years ago

Hi Eoin,

So sorry for bothering you again...

I just tested this function and it worked very well for most test images. However, I encountered a few test cases where it failed to find the latent vector z. Is it sensitive to the randn initialization? Should I tweak lambda1 and lambda2 for each image individually?

Thank you!

Best, Ziwei

EoinKenny commented 2 years ago

Hi Ziwei,

No problem at all, happy to help.

Unfortunately, how to do this best is an open research problem. I remember I set up a for loop over all 100+ digits I used in this paper and let it optimize for a while. In the end I manually inspected the results and I think I had to redo a few of them (maybe 5-20). Some images do well with this function; for others I found it worked better to just use the MSE loss in pixel space and ignore the logits. I think the logit loss does help avoid local minima, but it can also cause problems unfortunately. You could also try an L1 loss instead of MSE.

I would try just using the pixel-space loss if you're struggling (i.e., set lambda1 = 0).
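To be concrete (again, untested), a pixel-space-only version with a few random restarts to get around a bad initialization could look roughly like this; the restart count and iteration budget are just placeholders to tune:

import torch
import torch.optim as optim

def find_z_pixel_only(G, original_query_image, latent_size=100, lr=0.01,
                      n_restarts=5, n_iters=10000):
    """Drop the logit loss (lambda1 = 0) and keep the best z over
    several random initialisations."""
    mse_loss = torch.nn.MSELoss()
    best_z, best_loss = None, float('inf')

    for _ in range(n_restarts):
        z = torch.randn((1, latent_size), requires_grad=True)
        optimizer = optim.Adam([z], lr=lr)

        for _ in range(n_iters):
            optimizer.zero_grad()
            loss = mse_loss(G(z), original_query_image)
            loss.backward()
            optimizer.step()

        # keep the restart that reconstructs the query best
        if loss.item() < best_loss:
            best_loss = loss.item()
            best_z = z.detach().clone()

    return best_z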

Sorry I can't help much more than that. If you figure out how to reliably and perfectly recover the latent z of a GAN, you will have a research paper waiting to be published :-)

I remember I tried to do this for ImageNet using BigGAN, and it really struggles, which is partly why I stuck to MNIST and CIFAR-10 (even the latter struggles). There are smarter ways, like training an encoder-decoder to help (e.g., see the AttGAN paper), but overall no one has a foolproof way to get z unfortunately.

ziweizhao1993 commented 2 years ago

Hi Eoin,

I followed your suggestions and found some papers on GAN inversion, it looks like a very interesting area of study. Thank you for your help!

Best, Ziwei