WisconsinAIVision / UniversalFakeDetect

213 stars 26 forks source link

Inference method #7

Open prnvsheth opened 1 year ago

prnvsheth commented 1 year ago

Is there code already written to infer whether a specific image is fake or not? Can you point to the appropriate documentation in the code base?

oshita-n commented 9 months ago

Perhaps it could be written like this.

inference.py

import argparse
from ast import arg
import os
import csv
import torch
import torchvision.transforms as transforms
import torch.utils.data
import numpy as np
from sklearn.metrics import average_precision_score, precision_recall_curve, accuracy_score
from torch.utils.data import Dataset
import sys
from models import get_model
from PIL import Image 
import pickle
from tqdm import tqdm
from io import BytesIO
from copy import deepcopy
from dataset_paths import DATASET_PATHS
import random
import shutil
from scipy.ndimage.filters import gaussian_filter
import torchvision

# Fixed seed shared by every random number generator this script touches.
SEED = 0


def set_seed():
    """Seed the torch (CPU and CUDA), NumPy and stdlib RNGs with ``SEED``."""
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        np.random.seed,
        random.seed,
    )
    for seed_rng in seeders:
        seed_rng(SEED)

# Normalization means, one value per image channel, keyed by the name of the
# statistics set ("imagenet" or "clip").
# NOTE(review): presumably in R, G, B order for images scaled to [0, 1] -- confirm.
MEAN = {
    "imagenet": [0.485, 0.456, 0.406],
    "clip": [0.48145466, 0.4578275, 0.40821073],
}

# Normalization standard deviations, same keys and channel layout as MEAN.
STD = {
    "imagenet": [0.229, 0.224, 0.225],
    "clip": [0.26862954, 0.26130258, 0.27577711],
}

def inference(model, img):
    """Run ``model`` on ``img`` and return the sigmoid of its output.

    The forward pass runs with gradient tracking disabled. The result is
    flattened, stripped of size-1 dimensions, moved to the CPU and returned
    as a NumPy array (0-d when the model yields a single value).
    """
    with torch.no_grad():
        raw_output = model(img)
        scores = torch.sigmoid(raw_output)
        # flatten + squeeze collapses a single-value output down to a 0-d result.
        collapsed = scores.flatten().squeeze()
        y_pred = collapsed.cpu().numpy()
    return y_pred

if __name__ == '__main__':
    # Command-line entry point: score one image with a pretrained detector.
    # Each argument carries a help string because ArgumentDefaultsHelpFormatter
    # only shows a default for arguments that have help text.
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--image', type=str, default='./test_images/real.png',
                        help='path to the image to score')
    # NOTE(review): the example invocation uses --arch=CLIP:ViT-L/14; confirm
    # that the 'res50' default is a name get_model actually accepts.
    parser.add_argument('--arch', type=str, default='res50',
                        help='architecture name passed to get_model')
    parser.add_argument('--ckpt', type=str, default='./pretrained_weights/fc_weights.pth',
                        help='checkpoint holding the state dict of the model fc layer')

    opt = parser.parse_args()

    # Use the GPU when one is present, otherwise fall back to the CPU instead
    # of crashing on an unconditional .cuda() call.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = get_model(opt.arch)
    # Only the final fc layer is restored from the checkpoint.
    state_dict = torch.load(opt.ckpt, map_location='cpu')
    model.fc.load_state_dict(state_dict)
    model.eval()
    model.to(device)

    # Pick normalization statistics from the architecture name prefix.
    stat_from = "imagenet" if opt.arch.lower().startswith("imagenet") else "clip"

    transform = transforms.Compose([
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=MEAN[stat_from], std=STD[stat_from]),
    ])

    # The context manager closes the underlying file once the pixels have been
    # converted; unsqueeze(0) adds the batch dimension.
    with Image.open(opt.image) as raw_image:
        img_tensor = transform(raw_image.convert("RGB")).unsqueeze(0).to(device)
    y_pred = inference(model, img_tensor)

    # NOTE(review): presumably the probability that the image is fake --
    # confirm against the label convention used in training.
    print("Prediction: ", y_pred)
python inference.py  --arch=CLIP:ViT-L/14   --ckpt=pretrained_weights/fc_weights.pth   --image real/image-02.png