HighwayWu / LASTED

Synthetic Image Detection
MIT License
51 stars 2 forks source link

Model just predicts most images as real (painting or photo) #15

Open xiankgx opened 5 months ago

xiankgx commented 5 months ago

Hi, I tried using your model to detect AI-generated photos, such as those from SD, SDXL, DALL·E, etc. However, most of the predictions are "real". Do you see any problem with my code?

import clip
import gradio as gr
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image

from model import LASTED

# Candidate class names. classify() tokenizes them as CLIP text prompts, so
# their order defines the column order of the similarity matrix and the argmax
# index maps straight back into this list.
# NOTE(review): presumably the wording/order must match the prompts LASTED was
# trained with — confirm against the training code.
LABELS = ["Real Photo", "Synthetic Photo", "Real Painting", "Synthetic Painting"]

def modify_state_dict(sd: dict) -> dict:
    """Strip a leading ``module.`` prefix from every state-dict key.

    Checkpoints saved from a wrapped model (typically ``nn.DataParallel`` /
    DDP) carry this prefix, which a bare model's ``load_state_dict`` rejects.
    Only a *leading* ``module.`` is removed; keys that merely contain the
    substring elsewhere (e.g. ``encoder.submodule.weight``) are left intact.

    Args:
        sd: State dict loaded from a checkpoint.

    Returns:
        A new dict with the same values and prefix-stripped keys. ``sd`` itself
        is not modified.
    """
    prefix = "module."
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in sd.items()
    }

def classify(image: Image.Image):
    """Zero-shot classify *image* against ``LABELS`` with LASTED's CLIP backbone.

    Uses the module-level ``model``, ``transform`` and ``device`` created in the
    ``__main__`` block, so this only works when the file is run as a script.

    Args:
        image: PIL image from the Gradio input; ``None`` when the user submits
            without uploading an image.

    Returns:
        The best-matching entry of ``LABELS`` as a plain string, or an empty
        string when no image was provided.
    """
    if image is None:
        return ""

    # Uploaded PNGs may be RGBA/LA/P and grayscale images are mode "L";
    # ToTensor would then yield 4 or 1 channels, which the 3-channel
    # Normalize in `transform` cannot handle.
    image = image.convert("RGB")

    with torch.inference_mode():
        tensor_in = transform(image).unsqueeze(0).to(device)
        text = clip.tokenize(LABELS).to(device)

        image_features = model.clip_model.encode_image(tensor_in)
        text_features = model.clip_model.encode_text(text)

        # L2-normalise both embeddings so the dot product is a cosine similarity.
        image_features /= image_features.norm(dim=-1, keepdim=True)
        text_features /= text_features.norm(dim=-1, keepdim=True)
        # The 100.0 factor acts as a softmax temperature: it sharpens the
        # printed probabilities but does not change which label wins.
        similarity = (
            (100.0 * image_features @ text_features.T)
            .softmax(dim=-1)
            .cpu()
            .numpy()
        )
        print(f"similarity: {similarity}")

    # Batch size is 1, so use row 0. Return a str, not a one-element list:
    # with a single gr.Text output, a list would be shown as "['Real Photo']".
    return LABELS[int(np.argmax(similarity[0]))]

if __name__ == "__main__":
    # `device`, `transform` and `model` are deliberately module globals:
    # classify() looks them up at call time.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Preprocessing: fixed 448x448 resize -> tensor -> per-channel normalisation
    # with the (0.485, 0.456, 0.406) / (0.229, 0.224, 0.225) statistics.
    resize = transforms.Resize((448, 448))
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    transform = transforms.Compose([resize, to_tensor, normalize])

    print("Loading model...")
    model = LASTED()
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    state_dict = modify_state_dict(
        torch.load("LASTED_pretrained.pt", map_location="cpu")
    )
    model.load_state_dict(state_dict)
    # The weights now live in the model; drop the loaded copy.
    del state_dict
    model.eval().to(device)
    print("Done!")

    image_input = gr.Image(label="input image", type="pil")
    label_output = gr.Text(label="predicted label")
    demo = gr.Interface(fn=classify, inputs=[image_input], outputs=[label_output])

    # NOTE(review): port 80 usually requires elevated privileges on Linux.
    demo.launch(server_name="0.0.0.0", server_port=80)