Hi, I tried using your model to detect AI-generated photos like those from SD, SDXL, DALL-E, etc. However, most of the predictions are "real". Do you see any problem with my code?
import clip
import gradio as gr
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from model import LASTED
# Text prompts the image embedding is scored against. The order matters:
# classify() maps the argmax index of the similarity vector back into this list.
LABELS = [
    "Real Photo",
    "Synthetic Photo",
    "Real Painting",
    "Synthetic Painting",
]
def modify_state_dict(sd: dict) -> dict:
    """Return a copy of *sd* with any leading ``module.`` removed from each key.

    Checkpoints saved from a model wrapped in a ``DataParallel``-style
    container carry a ``module.`` prefix on every parameter name, which
    prevents loading them into the bare model.

    Only the *leading* prefix is stripped (repeatedly, in case of nested
    wrappers). A plain ``str.replace`` would also delete the substring from
    the middle of a key, e.g. turning ``"encoder.submodule.weight"`` into
    ``"encoder.subweight"``.

    Args:
        sd: Mapping of parameter names to values (typically tensors).

    Returns:
        A new dict with the same values under the cleaned keys; *sd* itself
        is not modified.
    """
    prefix = "module."
    cleaned = {}
    for key, value in sd.items():
        while key.startswith(prefix):
            key = key[len(prefix):]
        cleaned[key] = value
    return cleaned
def classify(image: Image.Image):
    """Score *image* against the ``LABELS`` text prompts and return the best match.

    The image is preprocessed with the module-level ``transform``, embedded
    with ``model.clip_model``, and compared (cosine similarity scaled by 100,
    then softmax) against the embeddings of every entry in ``LABELS``.

    Args:
        image: The PIL image supplied by the Gradio input component.

    Returns:
        A one-element list holding the label with the highest probability.

    Raises:
        ValueError: If no image was supplied (Gradio passes ``None`` when the
            form is submitted without an upload).
    """
    if image is None:
        raise ValueError("No image was provided.")

    # NOTE(review): this is zero-shot image-vs-text scoring. Confirm that it
    # matches the evaluation protocol the checkpoint was trained for before
    # trusting the predictions.
    with torch.inference_mode():
        # The Normalize step in `transform` is configured for three channels,
        # so RGBA / grayscale / palette images must be converted first.
        rgb_image = image.convert("RGB")
        tensor_in = transform(rgb_image).unsqueeze(0).to(device)
        text = clip.tokenize(LABELS).to(device)

        image_features = model.clip_model.encode_image(tensor_in)
        text_features = model.clip_model.encode_text(text)

        # L2-normalize both sides so the matrix product is a cosine similarity.
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        logits = 100.0 * image_features @ text_features.T
        similarity = logits.softmax(dim=-1).cpu().numpy()

    print(f"similarity: {similarity}")
    best_indices = np.argmax(similarity, axis=1)
    return np.array(LABELS)[best_indices].tolist()
if __name__ == "__main__":
    # classify() reads `device`, `transform` and `model` as module globals,
    # so they are deliberately assigned at module level here.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    print(f"Using device: {device}")

    # NOTE(review): these mean/std values are the commonly used ImageNet
    # statistics; CLIP's own preprocessing uses different constants. Confirm
    # which ones the checkpoint was trained with.
    # NOTE(review): Resize forces every image to exactly 448x448, ignoring
    # aspect ratio -- verify this matches the training-time preprocessing.
    preprocessing_steps = [
        transforms.Resize((448, 448)),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ]
    transform = transforms.Compose(preprocessing_steps)

    print("Loading model...")
    model = LASTED()
    # SECURITY: torch.load unpickles the file, which can execute arbitrary
    # code. Only load checkpoints obtained from a trusted source.
    model.load_state_dict(
        modify_state_dict(
            torch.load("LASTED_pretrained.pt", map_location="cpu"),
        ),
    )
    model.eval()
    model.to(device)
    print("Done!")

    image_input = gr.Image(label="input image", type="pil")
    label_output = gr.Text(label="predicted label")
    demo = gr.Interface(
        fn=classify,
        inputs=[image_input],
        outputs=[label_output],
    )
    # Listens on all interfaces; binding to port 80 usually needs elevated
    # privileges on Unix-like systems.
    demo.launch(server_name="0.0.0.0", server_port=80)
Hi, I tried using your model to detect AI-generated photos like those from SD, SDXL, DALL-E, etc. However, most of the predictions are "real". Do you see any problem with my code?