I just want to get the dehazed images, but the original code test.py is too complex to use, so I tried to extract the core part.
from models.networks import ResnetGenerator,_UNetGenerator
import cv2
from pathlib import Path
from tqdm import tqdm
import torch
from torchvision.transforms import ToTensor,Resize,Compose
from torchvision.utils import save_image
# Minimal inference script: run the pretrained dehazing UNet over every image in
# IMAGE_DIR and write the dehazed result, resized back to the input's original
# H x W, into OUTPUT_DIR under the same file name.

# Use the first GPU when available; fall back to CPU so the script still runs
# (slowly) on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

IMAGE_DIR = Path('/home/shelgi/ssd/RTTS/JPEGImages')
OUTPUT_DIR = Path('results')
CHECKPOINT = './checkpoints/30_netR_Dehazing.pth'
IMAGE_EXTS = {'.jpg', '.jpeg', '.png'}  # compared case-insensitively below

model = _UNetGenerator(3, 3, 64, 4, 'batch', 'PReLU', 0, False, 0, 0.1)

# The checkpoint's keys carry a leading "module." prefix (presumably because it
# was saved from a DataParallel-wrapped model -- confirm). Strip only that
# leading prefix: str.replace would also mangle any key that merely contains
# "module." somewhere in the middle.
_PREFIX = 'module.'
# map_location lets a GPU-saved checkpoint load onto whatever `device` is.
checkpoint = torch.load(CHECKPOINT, map_location=device)
state_dict = {
    (k[len(_PREFIX):] if k.startswith(_PREFIX) else k): v
    for k, v in checkpoint.items()
}
model.load_state_dict(state_dict)
model.to(device)  # keep model and inputs on the same configured device
model.eval()

# ToTensor converts an H x W x C numpy array to a C x H x W tensor; it does NOT
# rescale float32 input, which is why pixels are divided by 255 manually in the
# loop. The network is then fed a fixed 256 x 256 input.
trans = Compose([
    ToTensor(),
    Resize((256, 256)),
])

# Sorted for a deterministic processing order; suffix check is case-insensitive
# so files such as "foo.JPG" are not silently skipped.
img_lists = sorted(
    str(p) for p in IMAGE_DIR.glob('*')
    if p.suffix.lower() in IMAGE_EXTS
)

# save_image does not create missing directories.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

with torch.no_grad():
    for img_path in tqdm(img_lists):
        img_name = Path(img_path).name
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread reports unreadable/corrupt files by returning None
            # instead of raising; skip them rather than crash in cvtColor.
            tqdm.write(f'skipping unreadable image: {img_path}')
            continue
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR
        img = img.astype('float32') / 255.
        H, W = img.shape[:2]
        img = trans(img).unsqueeze(0).to(device)  # add batch dimension
        # NOTE(review): the generator presumably returns a sequence of outputs
        # and the last one is the final dehazed image -- confirm in networks.py.
        out = model(img)[-1]
        out = Resize((H, W))(out)  # back to the original resolution
        out = out.clamp(0, 1)
        save_image(out, str(OUTPUT_DIR / img_name))
BTW, for calculating SSIM and PSNR, it is easy to build a simple dataloader on top of this code to get the results.
I just want to get the dehazed images, but the original code
test.py
is too complex to use, so I tried to extract the core part. BTW, for calculating SSIM and PSNR, it is easy to build a simple dataloader on top of this code to get the results.