Zhanghb1688 / RatUNet


Code #1

Status: open. Z-FL opened this issue 2 years ago.

Z-FL commented 2 years ago

Hi, I would like to know when the code for this paper will be made public. Looking forward to your reply, thank you very much.

Zhanghb1688 commented 2 years ago

My code is also public. Since I am not familiar with GitHub, I did not put the code and the parameter files in the same repository. The code is at: https://github.com/Zhanghb1688/Zhanghb1688

------------------ Original message ------------------ From: "Zhanghb1688/RatUNet" @.>; Sent: Wednesday, June 22, 2022, 4:35 PM @.>; @.***>; Subject: [Zhanghb1688/RatUNet] Code (Issue #1)



DGUnggi commented 1 month ago

I wanted to use the pretrained model parameters named `model_15.pth` and `model_50.pth`, so I wrote the following code in test.py:

```python
def main():
    logs_dir = os.getcwd()
    # Build model
    print('Loading model ...\n')
    model = RatUNet(BasicBlock, 64)
    model = model.to(device)
    try:
        model.load_state_dict(torch.load('/workspace/RatUNet/model_50.pth', map_location=device))
        print("Model loaded successfully.")
    except Exception as e:
        print(f"Error loading model: {e}")
    model.eval()
```

However, this error occurred:

```
root@817071d8d489:/workspace/RatUNet_Model# python test.py
Loading model ...
Error loading model: invalid load key, 'v'.
```

I think the model file may be corrupted. Could you please look into this, or upload the model again?
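One common cause of `invalid load key, 'v'` (not confirmed anywhere in this thread) is that the `.pth` file on disk is not the weights at all but a Git LFS pointer: a small text file whose first line is `version https://git-lfs.github.com/spec/v1`. `torch.load` then tries to unpickle text starting with `v` and fails. This usually happens when a repository is cloned without Git LFS installed. The sketch below assumes that scenario; the helper names `looks_like_lfs_pointer` and `is_git_lfs_pointer` are hypothetical and not part of RatUNet.

```python
# Hypothetical helpers (not part of RatUNet): check whether a checkpoint file
# is really a Git LFS pointer text file instead of binary model weights.
LFS_POINTER_PREFIX = b"version https://git-lfs.github.com/spec/v1"


def looks_like_lfs_pointer(head: bytes) -> bool:
    """True if the leading bytes match the first line of a Git LFS pointer file."""
    return head.startswith(LFS_POINTER_PREFIX)


def is_git_lfs_pointer(path: str) -> bool:
    """Read only as many bytes as the pointer prefix and compare them."""
    with open(path, "rb") as fh:
        return looks_like_lfs_pointer(fh.read(len(LFS_POINTER_PREFIX)))
```

If `is_git_lfs_pointer('/workspace/RatUNet/model_50.pth')` returns `True`, the real weights were never downloaded. Running `git lfs install` and then `git lfs pull` inside the cloned repository, or downloading the file directly from the repository's web page, should fetch the actual binary.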

Zhanghb1688 commented 1 month ago

I tested the weight parameter files and there was no problem. Use the original models file and add this line of code to test.py: `from models import BasicBlock, RatUNet`

DGUnggi commented 1 month ago

Thank you for your answer, but I had already done that, as in the code below. Is something wrong with this code?

```python
import cv2
import os
import argparse
import glob
import numpy as np
import torch
import utils_image as util
from models import BasicBlock, RatUNet
from utils import batch_PSNR, batch_ssim
import torch.backends.cudnn as cudnn
from data import Dataset1
from PIL import Image
from torchvision import transforms
import pytorch_ssim

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
cudnn.benchmark = True
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

parser = argparse.ArgumentParser(description="RatUNet_Test")
parser.add_argument("--logdir", type=str, default="RatUNet/model_15.pth", help='path of log files')
parser.add_argument("--test_data", type=str, default='Zero-DiDCE/DiDCE/DiDCE_code/Bddugrp-1/result/zero-didce', help='test on Set12 or Set68 Urban100 CBSD68 McMaster')
parser.add_argument("--test_noiseL", type=float, default=50, help='noise level used on test set')
opt = parser.parse_args()


def main():
    logs_dir = os.getcwd()
    # Build model
    print('Loading model ...\n')
    model = RatUNet(BasicBlock, 64)
    model = model.to(device)
    try:
        model.load_state_dict(torch.load('/workspace/RatUNet/model_50.pth', map_location=device))
        print("Model loaded successfully.")
    except Exception as e:
        print(f"Error loading model: {e}")
        raise  # stop instead of evaluating a model with random weights
    model.eval()
    # load data info
    print('Loading data info ...\n')
    files_source = glob.glob(os.path.join(logs_dir, opt.test_data, '*'))
    files_source.sort()
    # process data
    psnr_val = 0
    ssim_val = 0
    for f in files_source:
        # image: grayscale, cropped so height and width are multiples of 8
        label = Image.open(f).convert('L')
        # PIL size is (width, height); RandomCrop expects (height, width)
        box = (label.size[1] - label.size[1] % 8, label.size[0] - label.size[0] % 8)
        label = transforms.RandomCrop(box)(label)
        label = transforms.ToTensor()(label)
        img_val = torch.unsqueeze(label, 0)

        # Gaussian noise with std = test_noiseL / 255, fixed seed for reproducibility
        torch.manual_seed(64)
        noise = torch.FloatTensor(img_val.size()).normal_(mean=0, std=opt.test_noiseL / 255.)
        imgn_val = img_val + noise

        # move tensors to the model's device (Variable is no longer needed)
        img_val, imgn_val = img_val.to(device), imgn_val.to(device)

        with torch.no_grad():
            out = model(imgn_val)
            out_val = torch.clamp(out, 0., 1.)
        psnr = batch_PSNR(out_val, img_val, 1.)
        ssim = batch_ssim(out_val, img_val, 1.)

        psnr_val += psnr
        ssim_val += ssim
        print("Image file: %s, PSNR_val: %.4f  SSIM_val: %.4f" % (f, psnr, ssim))
    psnr_val /= len(files_source)
    ssim_val /= len(files_source)
    print("PSNR_val: %.4f  SSIM_val: %.4f" % (psnr_val, ssim_val))


if __name__ == "__main__":
    main()
```
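`batch_PSNR` and `batch_ssim` come from the repository's `utils` module, which is not shown in this thread. As a hedged reference only, the sketch below mirrors two pieces of arithmetic the script relies on: the crop that makes height and width multiples of 8 (presumably so the image divides evenly through the network's downsampling stages, which is an assumption) and the standard PSNR formula, PSNR = 10 · log10(MAX² / MSE), with MAX = 1 for images scaled to [0, 1]. The helper names are hypothetical, and the repository's `batch_PSNR` may differ in details such as how it averages over a batch.

```python
import math

import numpy as np


def crop_size_multiple_of(width, height, multiple=8):
    """(height, width) of the largest crop whose sides are multiples of `multiple`.

    Mirrors the `box = (...)` line in test.py: PIL's Image.size is
    (width, height), while torchvision's RandomCrop expects (height, width).
    """
    return (height - height % multiple, width - width % multiple)


def psnr(clean, denoised, data_range=1.0):
    """Peak signal-to-noise ratio in dB: 10 * log10(data_range**2 / MSE)."""
    clean = np.asarray(clean, dtype=np.float64)
    denoised = np.asarray(denoised, dtype=np.float64)
    mse = np.mean((clean - denoised) ** 2)
    if mse == 0:
        return math.inf  # identical images
    return 10.0 * math.log10(data_range ** 2 / mse)
```

For example, a 481 × 321 image gives `crop_size_multiple_of(481, 321) == (320, 480)`. With `--test_noiseL 50`, the script's added noise has a standard deviation of 50/255 ≈ 0.196 on the [0, 1] intensity scale.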