SunnyHaze / IML-ViT

Official repository of paper “IML-ViT: Benchmarking Image manipulation localization by Vision Transformer”
MIT License

Inference #1

Closed thayanesimoes closed 10 months ago

thayanesimoes commented 10 months ago

Hello, do you need ground truth for inference? How can I remove it from the model predictor?

SunnyHaze commented 10 months ago

Thank you for your interest in our work 😄!

I see your point, and I am sorry that combining the loss computation inside the iml_vit_model.forward() function causes confusion.

However, we only use the ground truth (GT) to compute the loss, so the GT does not affect the inference outcome. You can simply work around this issue by passing a torch.zeros_like(img) as the ground truth and edge mask into our model. Of course, if you want to train the model, a GT is still needed.
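For example, here is a minimal sketch of that workaround (assuming a batched image tensor img and a model already on the same device; dummy_gt and dummy_edge are illustrative names, and the single-channel shapes follow the correction given later in this thread):

import torch

# Illustrative only: gt and edge_mask are single-channel, so (1, 1, H, W) zero tensors
# are used here rather than torch.zeros_like(img).
dummy_gt = torch.zeros(img.shape[0], 1, img.shape[2], img.shape[3], device=img.device)
dummy_edge = torch.zeros_like(dummy_gt)

model.eval()
with torch.no_grad():
    # The returned losses are meaningless without a real GT; only mask_pred matters.
    predict_loss, mask_pred, edge_loss = model(img, dummy_gt, dummy_edge)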

Furthermore, we will consider decoupling this part in the final release of the code. Please stay tuned.

thayanesimoes commented 10 months ago

Got it, thanks !

Where in the script do I pass the ground truth and edge mask as torch.zeros_like(img)? I tried this before making the prediction in Demo.ipynb, but it didn't return anything.

SunnyHaze commented 10 months ago

I have tested it on my machine and found that the solution below works much better.

For instance, a possible inference block to meet your needs in our Demo.ipynb could look like this:

import torch
import matplotlib.pyplot as plt

# model, dataset, and device are assumed to be defined in earlier cells of Demo.ipynb
results = []
model.eval()
with torch.no_grad():
    for img, gt, edge_mask, shape in dataset:    # gt and edge_mask are not needed for inference
        img, gt, edge_mask = img.to(device), gt.to(device), edge_mask.to(device)
        # Since there is no DataLoader, manually create a batch of size 1
        img = img.unsqueeze(0) # CHW -> 1CHW
        gt = gt.unsqueeze(0)
        edge_mask = edge_mask.unsqueeze(0)
        print("shape:", shape)
        # inference with ground truth (will output the loss)
        # predict_loss, mask_pred, edge_loss = model(img, gt, edge_mask)

        # inference without ground truth (loss is meaningless) 
        predict_loss, mask_pred, edge_loss = model(img, torch.zeros(1, 1, 1024, 1024, device=device), torch.zeros(1, 1, 1024, 1024, device=device))
        print(f"Predict Loss:{predict_loss}, including edge loss: {edge_loss}")
        output = mask_pred

        # visualize
        plt.subplot(1, 3, 1)
        plt.title("Input image")
        plt.imshow(img[0].permute(1, 2, 0).cpu().numpy())
        plt.subplot(1, 3, 2)
        plt.title("Prediction")
        plt.imshow(output.cpu().numpy()[0][0], cmap='gray')

        # Crop the original area from the padded image
        output = output[0, :, 0:shape[0], 0:shape[1]].permute(1, 2, 0).cpu().numpy()
        results.append(output)
        plt.subplot(1, 3, 3)
        plt.title("Cropped region")
        plt.imshow(output, cmap='gray')
        plt.show()
print("Done!")

Note that the GT and edge_mask in our implementation have only a single channel, so my previous comment was incorrect. Please follow this one.
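For reference, these are the tensor shapes the block above assumes after unsqueeze(0) (a sketch; the 1024×1024 padded size is taken from the zero tensors used above):

# The image is 3-channel RGB; gt and edge_mask are single-channel.
print(img.shape)        # torch.Size([1, 3, 1024, 1024])
print(gt.shape)         # torch.Size([1, 1, 1024, 1024])
print(edge_mask.shape)  # torch.Size([1, 1, 1024, 1024])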

This is the corresponding output:

[Output figure: input image, prediction, and cropped region]

Hope this helps you.

thayanesimoes commented 10 months ago

Yes, it helped, it worked for me.

Thank you very much.

SunnyHaze commented 10 months ago

You are welcome. 😃

SunnyHaze commented 10 months ago

We have added a Google Colab demo with an example of running IML-ViT inference using only input images from the Internet. This commit should officially resolve this issue.

IML-ViT Colab Demo: Colab