HUSTSYJ / DA_dahazing

Domain Adaptation for Image Dehazing, CVPR2020

The core code about dehazing #35

Open Yshelgi opened 9 months ago

Yshelgi commented 9 months ago

I just want to get dehazed images, but the original test.py is too complex to use, so I tried to extract the core part:

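```python
from pathlib import Path

import cv2
import torch
from torchvision.transforms import Compose, Resize, ToTensor
from torchvision.utils import save_image
from tqdm import tqdm

from models.networks import _UNetGenerator  # from the DA_dahazing repo

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Generator arguments must match the ones the checkpoint was trained with
model = _UNetGenerator(3, 3, 64, 4, 'batch', 'PReLU', 0, False, 0, 0.1)
# Checkpoints saved from nn.DataParallel prefix every key with 'module.'; strip it
ckpt = torch.load('./checkpoints/30_netR_Dehazing.pth', map_location=device)
state_dict = {k.replace('module.', ''): v for k, v in ckpt.items()}
model.load_state_dict(state_dict)
model.to(device).eval()

trans = Compose([
    ToTensor(),          # HWC float32 array in [0, 1] -> CHW float tensor
    Resize((256, 256)),  # fixed network input size chosen here
])

out_dir = Path('results')
out_dir.mkdir(parents=True, exist_ok=True)  # save_image fails if the folder is missing

img_lists = [p for p in Path('/home/shelgi/ssd/RTTS/JPEGImages').glob('*')
             if p.suffix.lower() in ('.jpg', '.jpeg', '.png')]

with torch.no_grad():
    for img_path in tqdm(img_lists):
        img = cv2.imread(str(img_path))
        if img is None:  # skip unreadable files
            continue
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype('float32') / 255.
        H, W = img.shape[:2]
        x = trans(img).unsqueeze(0).to(device)

        # The generator returns a list of outputs; take the last one
        out = model(x)[-1]
        # NOTE (unverified): if results look off, check whether the repo's test
        # pipeline normalizes inputs to [-1, 1]; if so, normalize x the same way
        # and map the output back with (out + 1) / 2 before clamping.
        out = Resize((H, W))(out).clamp(0, 1)  # back to the original resolution
        save_image(out, str(out_dir / img_path.name))
```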

BTW, to calculate SSIM and PSNR, you can build a simple dataloader on top of this code and get the results easily.
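A minimal NumPy-only sketch of the metric side, assuming each dehazed output and its ground truth are float RGB arrays in [0, 1] with the same size. SSIM here is the common Gaussian-window form (11x11 window, sigma 1.5) computed per channel and averaged; numbers may differ slightly from other implementations (e.g. skimage or MATLAB) that use different padding or convert to luminance first. The helper names `psnr`, `ssim` and `evaluate_pairs` are hypothetical, not part of the repo.

```python
import numpy as np


def psnr(pred, target, data_range=1.0):
    """Peak signal-to-noise ratio in dB."""
    mse = np.mean((pred.astype(np.float64) - target.astype(np.float64)) ** 2)
    if mse == 0:
        return float('inf')
    return float(10.0 * np.log10(data_range ** 2 / mse))


def _gaussian_kernel(size=11, sigma=1.5):
    """Normalized 1-D Gaussian kernel."""
    x = np.arange(size) - (size - 1) / 2.0
    k = np.exp(-(x ** 2) / (2 * sigma ** 2))
    return k / k.sum()


def _filter2d(img, kernel):
    """Separable Gaussian filtering of a 2-D array ('valid' region only)."""
    rows = np.apply_along_axis(lambda r: np.convolve(r, kernel, mode='valid'), 1, img)
    return np.apply_along_axis(lambda c: np.convolve(c, kernel, mode='valid'), 0, rows)


def ssim(pred, target, data_range=1.0):
    """Mean Gaussian-window SSIM, averaged over channels (images must be >= 11x11)."""
    c1 = (0.01 * data_range) ** 2
    c2 = (0.03 * data_range) ** 2
    k = _gaussian_kernel()
    pred = pred.astype(np.float64)
    target = target.astype(np.float64)
    if pred.ndim == 2:  # treat grayscale as a single channel
        pred, target = pred[..., None], target[..., None]
    scores = []
    for ch in range(pred.shape[-1]):
        x, y = pred[..., ch], target[..., ch]
        mu_x, mu_y = _filter2d(x, k), _filter2d(y, k)
        var_x = _filter2d(x * x, k) - mu_x ** 2
        var_y = _filter2d(y * y, k) - mu_y ** 2
        cov_xy = _filter2d(x * y, k) - mu_x * mu_y
        num = (2 * mu_x * mu_y + c1) * (2 * cov_xy + c2)
        den = (mu_x ** 2 + mu_y ** 2 + c1) * (var_x + var_y + c2)
        scores.append(np.mean(num / den))
    return float(np.mean(scores))


def evaluate_pairs(pairs):
    """Average PSNR and SSIM over an iterable of (dehazed, ground_truth) arrays."""
    psnrs, ssims = [], []
    for pred, gt in pairs:
        psnrs.append(psnr(pred, gt))
        ssims.append(ssim(pred, gt))
    return float(np.mean(psnrs)), float(np.mean(ssims))
```

To use it, load each image in `results/` and the clear image with the same file name from a ground-truth folder with `cv2.imread`, convert BGR to RGB, divide by 255, and pass the pairs to `evaluate_pairs`. This needs a paired test set; RTTS, used in the script above, has no clear ground-truth images.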

AerialL commented 3 months ago

Thanks, that's useful.