shawnricecake closed this issue 2 years ago.
Hi,
Thanks for sharing the code for BANet.
How should the model be evaluated? Specifically, how do you obtain the PSNR reported for the model?
I tried to evaluate the BANet model you provided on the GoPro dataset and got PSNR 31.7 and SSIM 0.846, which are lower than the values reported in the paper.
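For context, PSNR between a restored image and its sharp reference is 10 · log10(MAX² / MSE), where MAX is the peak pixel value (255 for 8-bit images) and MSE is the mean squared pixel error. Below is a minimal NumPy sketch. The function name `psnr` is hypothetical, and this is not necessarily how the BANet repository computes the metric. Details such as the assumed data range, rounding to 8 bits, or a color-space conversion can all shift the reported value.

```python
import numpy as np


def psnr(restored: np.ndarray, reference: np.ndarray, data_range: float = 255.0) -> float:
    """Hypothetical helper: PSNR in dB, 10 * log10(data_range**2 / MSE)."""
    if restored.shape != reference.shape:
        raise ValueError("images must have the same shape")
    # Work in float64 so that uint8 inputs do not wrap around when subtracted.
    diff = restored.astype(np.float64) - reference.astype(np.float64)
    mse = np.mean(diff ** 2)
    if mse == 0:
        # Identical images: PSNR is unbounded.
        return float("inf")
    return float(10.0 * np.log10(data_range ** 2 / mse))
```

Because the sketch assumes 8-bit images by default, `data_range` would have to be changed to 1.0 for images stored as floats in [0, 1].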
For reference, this is the evaluation code I used:
```python
import numpy as np
import torch
import tqdm
import yaml
from functools import partial
from torch.utils.data import DataLoader

from dataset import PairedDataset
from joblib import cpu_count
from models.models import get_model
from models.networks import get_nets


def validate(path):
    with open('config/config.yaml', 'r') as f:
        # yaml.load() without a Loader fails on PyYAML >= 6; safe_load is enough for a config file
        config = yaml.safe_load(f)

    batch_size = config.pop('batch_size')
    # Shuffling is unnecessary for evaluation
    get_dataloader = partial(DataLoader, batch_size=batch_size, num_workers=cpu_count(),
                             shuffle=False, drop_last=False)

    datasets = map(config.pop, ('train', 'val'))
    datasets = map(PairedDataset.from_config, datasets)
    train, val = map(get_dataloader, datasets)

    model = get_model(config['model'])
    netG = get_nets(config['model'])

    # Load the released weights on top of the freshly built generator
    training_state = torch.load(path)
    new_weight = netG.state_dict()
    new_weight.update(training_state)
    netG.load_state_dict(new_weight)
    netG.eval()  # inference mode (affects normalization/dropout layers, if any)

    epoch_size = config.get('val_batches_per_epoch') or len(val)
    tq = tqdm.tqdm(val, total=epoch_size)
    tq.set_description('Validation')

    i = 0
    r_psnr = []
    r_ssim = []
    for data in tq:
        with torch.no_grad():
            inputs, targets = model.get_input(data)
            outputs = netG(inputs)
            curr_psnr, curr_ssim, img_for_vis = model.get_images_and_metrics(inputs, outputs, targets)
        r_psnr.append(curr_psnr)
        r_ssim.append(curr_ssim)
        i += 1
        if i >= epoch_size:  # stop after exactly epoch_size batches
            break
    tq.close()

    print('PSNR: {:.3f}'.format(np.mean(r_psnr)))
    print('SSIM: {:.3f}'.format(np.mean(r_ssim)))


path = 'checkpoints/BANet_GoPro.pth'
validate(path)
```
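If the paper's number is an average over test images, it may help to accumulate one PSNR value per image instead of one per batch. With `drop_last=False` the final batch can be smaller, so a mean of per-batch values would weight its images differently. The sketch below assumes outputs and targets are float arrays in NCHW layout normalized to [-1, 1] (an assumption; check the repository's data pipeline) and that metrics are computed on 8-bit images. `to_uint8_hwc`, `psnr_uint8`, and `PerImageMean` are hypothetical helpers, not part of the BANet code.

```python
import numpy as np


def to_uint8_hwc(batch_nchw: np.ndarray) -> np.ndarray:
    """Map an NCHW float batch to NHWC uint8.

    ASSUMPTION: values are normalized to [-1, 1].
    """
    img = (np.transpose(batch_nchw, (0, 2, 3, 1)) + 1.0) / 2.0 * 255.0
    return np.clip(np.round(img), 0, 255).astype(np.uint8)


def psnr_uint8(a: np.ndarray, b: np.ndarray) -> float:
    """PSNR in dB for two 8-bit images of the same shape."""
    mse = np.mean((a.astype(np.float64) - b.astype(np.float64)) ** 2)
    return float("inf") if mse == 0 else float(10.0 * np.log10(255.0 ** 2 / mse))


class PerImageMean:
    """Stores one PSNR value per image, so every image gets equal weight."""

    def __init__(self):
        self.values = []

    def update(self, outputs_nchw: np.ndarray, targets_nchw: np.ndarray) -> None:
        out = to_uint8_hwc(outputs_nchw)
        ref = to_uint8_hwc(targets_nchw)
        # Iterate over the batch dimension: one PSNR per image pair.
        for o, r in zip(out, ref):
            self.values.append(psnr_uint8(o, r))

    def mean(self) -> float:
        return float(np.mean(self.values))
```

In a PyTorch loop, tensors would first be moved to NumPy, e.g. `meter.update(outputs.cpu().numpy(), targets.cpu().numpy())`, and `meter.mean()` read once after the last batch.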
It would be great if you could share your evaluation code!
Thanks
I solved this problem by using the original GoPro dataset.