XPixelGroup / HAT

CVPR2023 - Activating More Pixels in Image Super-Resolution Transformer Arxiv - HAT: Hybrid Attention Transformer for Image Restoration
Apache License 2.0

NameError: name 'RealESRGANModel' is not defined in nondist_validation method #101

Open kargibora opened 9 months ago

kargibora commented 9 months ago

Description:

While running the train.py script, I hit a NameError during validation, specifically in the nondist_validation method. The error is raised when training with the config train_Real_HAT_GAN_SRx4_finetune_from_mse_model.yml as soon as a validation step is reached, though it is probably not limited to that config. The traceback is shown below:

Traceback (most recent call last):
  File "/home/borakargi/dbsr/hat/hat/train.py", line 19, in <module>
    train_pipeline(root_path)
  ...
  File "/home/borakargi/dbsr/hat/hat/models/realhatgan_model.py", line 188, in nondist_validation
    super(RealESRGANModel, self).nondist_validation(dataloader, current_iter, tb_logger, save_img)
NameError: name 'RealESRGANModel' is not defined
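
The traceback shows the script getting through import and model construction and failing only once validation runs. This is consistent with how Python resolves names: a name inside a method body is looked up when that line executes, not when the class is defined. A minimal sketch of that behavior follows; Base, Child, and UndefinedModel are hypothetical stand-ins, not classes from HAT or BasicSR:

```python
class Base:
    def validate(self):
        return "base validation"


class Child(Base):
    def validate(self):
        # UndefinedModel is never defined or imported anywhere.
        # Python only looks the name up when this line actually runs.
        return super(UndefinedModel, self).validate()


# Defining the class and creating an instance both succeed.
model = Child()

# The failure appears only when the method is called.
try:
    model.validate()
except NameError as exc:
    error_message = str(exc)
    print(error_message)
```

In the training run, the equivalent of the `model.validate()` call is the first validation step, which is why the crash shows up partway through training rather than at startup.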

Notice that there is no import referring to RealESRGANModel in the script hat/models/realhatgan_model.py:

import numpy as np
import random
import torch
from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
from basicsr.data.transforms import paired_random_crop
from basicsr.models.srgan_model import SRGANModel
from basicsr.utils import DiffJPEG, USMSharp
from basicsr.utils.img_process_util import filter2D
from basicsr.utils.registry import MODEL_REGISTRY
from collections import OrderedDict
from torch.nn import functional as F

Steps to Reproduce:

Run the train.py script with the config named above and let training reach a validation step. The error is raised during validation.

Expected Behavior:

The script should proceed through the validation process without any issues.

Actual Behavior:

The script crashes and throws a NameError.

Suggested Fix:

Replace RealESRGANModel with RealHATGANModel on line 188 in realhatgan_model.py to resolve the issue.

        # super(RealESRGANModel, self).nondist_validation(dataloader, current_iter, tb_logger, save_img)
        super(RealHATGANModel, self).nondist_validation(dataloader, current_iter, tb_logger, save_img)
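
For illustration, here is a minimal sketch of the pattern behind this fix. BaseModel and DerivedModel are hypothetical stand-ins for the parent class and RealHATGANModel, and the method bodies are placeholders rather than the real validation logic. Passing the class's own name to super() dispatches to the parent implementation; the zero-argument super() form behaves the same way in Python 3 and does not repeat the class name, so it cannot go stale when code is copied between classes. Whether the maintainers chose that form is not stated in this thread.

```python
class BaseModel:
    def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
        # Placeholder for the parent's validation logic.
        return f"validated at iter {current_iter}"


class DerivedModel(BaseModel):
    def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
        # Explicit form: the first argument must be this class's own name.
        return super(DerivedModel, self).nondist_validation(
            dataloader, current_iter, tb_logger, save_img)

    def nondist_validation_zero_arg(self, dataloader, current_iter, tb_logger, save_img):
        # Zero-argument form: Python fills in the class and instance itself.
        return super().nondist_validation(
            dataloader, current_iter, tb_logger, save_img)


model = DerivedModel()
explicit_result = model.nondist_validation(None, 5, None, False)
zero_arg_result = model.nondist_validation_zero_arg(None, 5, None, False)
print(explicit_result == zero_arg_result)
```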

chxy95 commented 9 months ago

The error has been fixed.