Open — issue opened by prnvsheth 1 year ago.
Perhaps a single-image inference script could be written like this, as `inference.py`:
import argparse
from ast import arg
import os
import csv
import torch
import torchvision.transforms as transforms
import torch.utils.data
import numpy as np
from sklearn.metrics import average_precision_score, precision_recall_curve, accuracy_score
from torch.utils.data import Dataset
import sys
from models import get_model
from PIL import Image
import pickle
from tqdm import tqdm
from io import BytesIO
from copy import deepcopy
from dataset_paths import DATASET_PATHS
import random
import shutil
from scipy.ndimage.filters import gaussian_filter
import torchvision
# Default seed used by set_seed() when no explicit seed is given.
SEED = 0


def set_seed(seed=SEED):
    """Seed the torch, CUDA, NumPy and Python ``random`` RNGs for reproducibility.

    Args:
        seed: Integer seed applied to every RNG. Defaults to the module-level
            ``SEED`` (0), so calling ``set_seed()`` behaves exactly as before.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
# Per-channel normalization statistics, keyed by backbone family
# ("imagenet" or "clip"); consumed by transforms.Normalize in the script below.
MEAN = dict(
    imagenet=[0.485, 0.456, 0.406],
    clip=[0.48145466, 0.4578275, 0.40821073],
)

# Per-channel standard deviations matching the MEAN entries above.
STD = dict(
    imagenet=[0.229, 0.224, 0.225],
    clip=[0.26862954, 0.26130258, 0.27577711],
)
def inference(model, img):
    """Run ``model`` on ``img`` with autograd disabled and return sigmoid scores.

    Args:
        model: Callable that maps ``img`` to a tensor of logits.
        img: Input tensor passed unchanged to ``model``.

    Returns:
        A NumPy array on the CPU holding ``sigmoid(model(img))``, flattened to
        1-D and then squeezed, so a single score comes back as a 0-d array.
    """
    with torch.no_grad():
        logits = model(img)
        scores = torch.sigmoid(logits)
        flat = scores.flatten().squeeze()
        return flat.cpu().numpy()
if __name__ == '__main__':
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--image', type=str, default='./test_images/real.png')
    parser.add_argument('--arch', type=str, default='res50')
    parser.add_argument('--ckpt', type=str, default='./pretrained_weights/fc_weights.pth')
    opt = parser.parse_args()

    # Use the GPU when one is present, otherwise fall back to CPU instead of
    # crashing on an unconditional .cuda() call.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = get_model(opt.arch)
    # The checkpoint is loaded into the final ``fc`` layer only; everything else
    # comes from get_model(). Load on CPU first, then move with the model.
    state_dict = torch.load(opt.ckpt, map_location='cpu')
    model.fc.load_state_dict(state_dict)
    model.eval()
    model.to(device)

    # Pick normalization stats: ImageNet stats for archs whose name starts with
    # "imagenet" (case-insensitive), CLIP stats for everything else.
    stat_from = "imagenet" if opt.arch.lower().startswith("imagenet") else "clip"
    transform = transforms.Compose([
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=MEAN[stat_from], std=STD[stat_from]),
    ])

    # The context manager closes the underlying file once the image is decoded.
    with Image.open(opt.image) as im:
        img_tensor = transform(im.convert("RGB")).unsqueeze(0).to(device)

    y_pred = inference(model, img_tensor)
    # NOTE(review): y_pred is sigmoid(model output). Whether a high score means
    # "fake" or "real" depends on how the fc head was trained — confirm the label
    # convention before thresholding (e.g. at 0.5).
    print("Prediction: ", y_pred)
Example invocation (shell):

    python inference.py --arch=CLIP:ViT-L/14 --ckpt=pretrained_weights/fc_weights.pth --image real/image-02.png
Is there already code for inferring whether a specific image is fake or real? Could you point me to the relevant documentation in the code base?