songzijiang / LGAN

Source codes for LGAN
Apache License 2.0
8 stars 0 forks source link

About test ELAN #2

Open sunsunshark opened 8 months ago

sunsunshark commented 8 months ago

Hi, thank you for your excellent work! I am using the normal ELAN (not the light one) to test my images, and I used your test_custom_code (thank you for that). However, the yellow tone in the results is very heavy. I wonder whether you have obtained the correct inference code for the ordinary ELAN network. I would appreciate it if you could help me.

import math
import argparse
import yaml
import utils
import os
from tqdm import tqdm
import imageio
import torch
import torch.nn as nn
from multiprocessing import Process, Queue
from utils import ndarray2tensor
import time

class save_img():
    """Write result images to disk from background worker processes.

    Images are handed to the workers through a multiprocessing queue so the
    caller does not wait on file I/O.

    NOTE(review): a second ``save_img`` class is defined later in this file
    and rebinds the name; confirm which variant is meant to be used.
    """

    def __init__(self):
        # Number of writer processes started by begin_background().
        self.n_processes = 2

    @staticmethod
    def bg_target(queue):
        """Worker loop: write queued images until a ``(None, None)`` sentinel.

        Static because it is passed as ``self.bg_target`` but only accepts
        the queue; as a plain method the bound call would receive an extra
        ``self`` argument.
        """
        while True:
            # Blocking get() instead of polling empty(): no busy-waiting.
            filename, tensor = queue.get()
            if filename is None:
                break
            imageio.imwrite(filename, tensor.numpy())

    def begin_background(self):
        """Create the queue and start ``n_processes`` writer processes."""
        self.queue = Queue()
        self.process = [
            Process(target=self.bg_target, args=(self.queue,))
            for _ in range(self.n_processes)
        ]
        for p in self.process:
            p.start()

    def end_background(self):
        """Queue one stop sentinel per worker, wait for the drain, then join."""
        for _ in range(self.n_processes):
            self.queue.put((None, None))
        # Only keep waiting while some worker can still consume items;
        # otherwise dead workers would leave this loop spinning forever.
        while not self.queue.empty() and any(p.is_alive() for p in self.process):
            time.sleep(1)
        for p in self.process:
            p.join()

    def save_results(self, filename, img):
        """Queue the first element of ``img`` to be written to ``filename``.

        The element is cast to uint8, its axes are permuted with (1, 2, 0)
        and it is moved to the CPU before being queued.
        Presumably ``img`` is an (N, C, H, W) batch -- TODO confirm.
        """
        tensor_cpu = img[0].byte().permute(1, 2, 0).cpu()
        self.queue.put((filename, tensor_cpu))

def bg_target(queue):
    """Worker loop: write queued images until a ``(None, None)`` sentinel.

    Each queue item is a ``(filename, tensor)`` pair; the tensor is converted
    with ``.numpy()`` and written with ``imageio.imwrite``. A ``filename`` of
    ``None`` ends the loop.
    """
    while True:
        # Blocking get() instead of polling empty(): no busy-waiting.
        filename, tensor = queue.get()
        if filename is None:
            break
        imageio.imwrite(filename, tensor.numpy())

class save_img():
    """Write result images to disk from background worker processes.

    The queue and the (not yet started) worker processes are created in the
    constructor; the workers run the module-level ``bg_target`` function.
    """

    def __init__(self):
        # Number of writer processes.
        self.n_processes = 32
        self.queue = Queue()
        self.process = [
            Process(target=bg_target, args=(self.queue,))
            for _ in range(self.n_processes)
        ]

    def begin_background(self):
        """Start every writer process."""
        for p in self.process:
            p.start()

    def end_background(self):
        """Queue one stop sentinel per worker, wait for the drain, then join."""
        for _ in range(self.n_processes):
            self.queue.put((None, None))
        # Only keep waiting while some worker can still consume items;
        # otherwise dead workers would leave this loop spinning forever.
        while not self.queue.empty() and any(p.is_alive() for p in self.process):
            time.sleep(1)
        for p in self.process:
            p.join()

    def save_results(self, filename, img):
        """Queue the first element of ``img`` to be written to ``filename``.

        The element is cast to uint8, its axes are permuted with (1, 2, 0)
        and it is moved to the CPU before being queued.
        Presumably ``img`` is an (N, C, H, W) batch -- TODO confirm.
        """
        tensor_cpu = img[0].byte().permute(1, 2, 0).cpu()
        self.queue.put((filename, tensor_cpu))

# Script entry point: run the configured model over every file in
# --custom_image_path and write the results under <log_path>/custom/.
# The guard is also what lets multiprocessing re-import this module safely
# when worker processes are started with the "spawn" method.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='config')

    parser.add_argument('--config', type=str, default='E:/pycharmcode2/ELAN-main/configs/elan_x4.yml',
                        help='pre-config file for training')
    parser.add_argument('--resume', type=str, default=None, help='resume training or not')
    parser.add_argument('--custom', type=str, default=None, help='use custom block')
    parser.add_argument('--cloudlog', type=str, default=None, help='use cloudlog')
    parser.add_argument('--custom_image_path', type=str, default=None, help='path of the custom image')

    args = parser.parse_args()

    if args.config:
        # vars() returns the namespace's own dict, so updating it makes every
        # YAML key readable as args.<key> (gpu_ids, model, pretrain, ...).
        # YAML values override command-line values with the same name.
        opt = vars(args)
        # NOTE(review): prefer yaml.safe_load if the config file can come
        # from an untrusted source.
        with open(args.config) as cfg_file:
            yaml_args = yaml.load(cfg_file, Loader=yaml.FullLoader)
        opt.update(yaml_args)

    # Fail early with a clear message instead of a TypeError further down.
    if not args.custom_image_path:
        parser.error('--custom_image_path is required (command line or config file)')

    ## set visible gpu
    gpu_ids_str = str(args.gpu_ids).replace('[', '').replace(']', '')
    os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
    os.environ['CUDA_VISIBLE_DEVICES'] = '{}'.format(gpu_ids_str)

    ## select active gpu devices
    if len(args.gpu_ids) > 0 and torch.cuda.is_available():
        print('use cuda & cudnn for acceleration!')
        print('the gpu id is: {}'.format(args.gpu_ids))
        device = torch.device('cuda')
        torch.backends.cudnn.benchmark = True
    else:
        print('use cpu for training!')
        device = torch.device('cpu')

    ## definitions of model
    try:
        model = utils.import_module('models.{}_network'.format(args.model)).create_model(args)
    except Exception as err:
        # Chain the original error so the real cause stays in the traceback.
        raise ValueError('not supported model type! or something') from err

    ## load pretrain
    if args.pretrain is not None:
        print('load pretrained model: {}!'.format(args.pretrain))
        ckpt = torch.load(args.pretrain, map_location=device)
        print(ckpt['model_state_dict'].keys())

        # strict=False skips keys that do not match without raising, so a
        # checkpoint for a different architecture can load "successfully"
        # while leaving layers uninitialised. Report what was skipped.
        load_result = model.load_state_dict(ckpt['model_state_dict'], strict=False)
        missing_keys = getattr(load_result, 'missing_keys', None)
        unexpected_keys = getattr(load_result, 'unexpected_keys', None)
        if missing_keys:
            print('keys missing from checkpoint: {}'.format(missing_keys))
        if unexpected_keys:
            print('unexpected keys in checkpoint: {}'.format(unexpected_keys))

    model = nn.DataParallel(model).to(device)

    model = model.eval()
    torch.set_grad_enabled(False)
    save_path = args.log_path
    si = save_img()
    si.begin_background()

    filePath = args.custom_image_path
    # The output directory does not depend on the file, so create it once.
    out_dir = os.path.join(save_path, 'custom')
    os.makedirs(out_dir, exist_ok=True)
    try:
        for filename in tqdm(os.listdir(filePath), ncols=80):
            lr = imageio.imread(os.path.join(filePath, filename))
            lr = ndarray2tensor(lr)
            lr = torch.unsqueeze(lr, 0)
            lr = lr.to(device)
            sr = model(lr)
            # quantize output to [0, 255]
            sr = sr.clamp(0, 255).round()
            # splitext keeps the leading dot in ext.
            stem, ext = os.path.splitext(filename)
            out_path = os.path.join(out_dir, '{}_x{}_SR{}'.format(stem, args.scale, ext))
            si.save_results(out_path, sr)
    finally:
        # Always stop the writer processes, even if inference fails, so the
        # script does not hang on still-running workers.
        si.end_background()

I'm sorry if I bothered you

songzijiang commented 8 months ago

I have tested my images, including aerial and medical imagery, and nothing seems wrong. The bias value is determined by MeanShift, which is defined in 'm_block.py' in my project. You could check that file, or train your model using your own mean and std.

sunsunshark commented 8 months ago

我已经测试了我的图像,包括航空图像和医学图像,似乎没有错。偏差值是在 MeanShift 中确定的,MeanShift 在我的项目中的“m_block.py”中定义。也许你可以检查文件或使用你自己的平均值和标准来训练你的模型。

Thanks! I will try it again. It's very nice of you.