pp00704831 / BANet-TIP-2022

35 stars 2 forks source link

Evaluation Problem #1

Closed shawnricecake closed 2 years ago

shawnricecake commented 2 years ago

Hi,

Thanks for sharing the code for BANet.

I would like to ask how to evaluate the model — specifically, how do we compute its PSNR?

I tried to evaluate the BANet model you provided on the GoPro dataset and got PSNR 31.7 and SSIM 0.846, which are lower than the values reported in the paper.

For reference, I evaluated it with the following code:

import numpy as np
import torch
import tqdm
import yaml
from functools import partial
from torch.utils.data import DataLoader
from dataset import PairedDataset
from joblib import cpu_count
from models.models import get_model
from models.networks import get_nets

def validate(path):
    """Evaluate a generator checkpoint on the validation split and report mean PSNR/SSIM.

    Reads ``config/config.yaml`` for the batch size, the ``val`` dataset
    configuration and the ``model`` configuration. It then loads the generator
    weights from ``path`` and runs inference over the validation loader without
    gradients. Finally it averages the per-batch metrics returned by
    ``model.get_images_and_metrics``.

    Args:
        path: Path of a checkpoint loadable with ``torch.load``. Its entries
            are merged into the generator's ``state_dict`` before loading.

    Returns:
        A ``(mean_psnr, mean_ssim)`` tuple; both values are also printed.
    """
    with open('config/config.yaml', 'r') as f:
        # yaml.load() without an explicit Loader is deprecated and raises
        # TypeError on PyYAML >= 6. safe_load assumes the config uses only
        # standard YAML tags.
        config = yaml.safe_load(f)

    batch_size = config.pop('batch_size')
    # Only the validation split is needed, so the training set is not built.
    # shuffle=False keeps the evaluated batches (and hence the reported
    # numbers) reproducible when val_batches_per_epoch limits the run.
    val = DataLoader(
        PairedDataset.from_config(config.pop('val')),
        batch_size=batch_size,
        num_workers=cpu_count(),
        shuffle=False,
        drop_last=False,
    )

    model = get_model(config['model'])
    netG = get_nets(config['model'])

    # Merge checkpoint tensors into the freshly initialised state dict, so keys
    # missing from the checkpoint keep their initial values. map_location='cpu'
    # lets CUDA-saved checkpoints load on any host; load_state_dict copies the
    # tensors onto the parameters' own device.
    training_state = torch.load(path, map_location='cpu')
    new_weight = netG.state_dict()
    new_weight.update(training_state)
    netG.load_state_dict(new_weight)
    # Switch layers such as BatchNorm/Dropout to inference behaviour; the
    # network would otherwise be evaluated in training mode.
    netG.eval()

    epoch_size = config.get('val_batches_per_epoch') or len(val)
    r_psnr = []
    r_ssim = []
    tq = tqdm.tqdm(val, total=epoch_size)
    tq.set_description('Validation')
    try:
        with torch.no_grad():
            for i, data in enumerate(tq, start=1):
                inputs, targets = model.get_input(data)
                outputs = netG(inputs)
                curr_psnr, curr_ssim, _ = model.get_images_and_metrics(inputs, outputs, targets)
                r_psnr.append(curr_psnr)
                r_ssim.append(curr_ssim)
                # Stop once exactly epoch_size batches have been evaluated.
                if i >= epoch_size:
                    break
    finally:
        tq.close()

    mean_psnr = np.mean(r_psnr)
    mean_ssim = np.mean(r_ssim)
    print('PSNR: {:.3f}'.format(mean_psnr))
    print('SSIM: {:.3f}'.format(mean_ssim))
    return mean_psnr, mean_ssim

if __name__ == "__main__":
    # Evaluate the released GoPro checkpoint only when run as a script, so
    # importing this module does not trigger a full evaluation pass.
    path = 'checkpoints/BANet_GoPro.pth'

    validate(path)

It would be very helpful if you could share your evaluation code!

Thanks

shawnricecake commented 2 years ago

I solved this problem by evaluating on the original GoPro dataset.