Closed Susango closed 1 year ago
您好,请参照一下 IAT_enhance/train_lol_v2.py ,把tran_exposure里面的DistributedSampler这些更换去掉即可。
您好,请参照一下 IAT_enhance/train_lol_v2.py ,把tran_exposure里面的DistributedSampler这些更换去掉即可。
我修改了代码如下,之后无论怎么输入数据都能训练,初步猜测是数据没有输入进去。请问可以帮我看一下修改的代码是否正确吗?感恩的心❤
`import torch import torch.nn as nn import torch.optim import torch.nn.functional as F
import os import argparse import numpy as np import random from torchvision.models import vgg16 from torch.utils.data import DataLoader
from torch import distributed as dist
from data_loaders.exposure import exposure_loader from model.IAT_main import IAT
from IQA_pytorch import SSIM from utils import PSNR, validation, LossNetwork, visualization, get_dist_info
print(torch.cuda.device_count())
parser = argparse.ArgumentParser() parser.add_argument('--gpu_id', type=str, default=1)
parser.add_argument('--img_path', type=str, default="/home/s3090/zzh/Illumination-Adaptive-Transformer-main/IAT_enhance/Your_Path/train/INPUT_IMAGES") parser.add_argument('--img_val_path', type=str, default="/home/s3090/zzh/Illumination-Adaptive-Transformer-main/IAT_enhance/Your_Path/validation/INPUT_IMAGES") parser.add_argument("--normalize", action="store_true", help="Default not Normalize in exposure training.")
parser.add_argument('--batch_size', type=int, default=2) parser.add_argument('--lr', type=float, default=2e-4) # for batch size 4x2=8 parser.add_argument('--weight_decay', type=float, default=0.0002) parser.add_argument('--pretrain_dir', type=str, default=None)
parser.add_argument('--num_epochs', type=int, default=50) parser.add_argument('--display_iter', type=int, default=100) parser.add_argument('--snapshots_folder', type=str, default="workdirs/snapshots_folder_exposure")
config = parser.parse_args()
print(config) os.environ['CUDA_VISIBLE_DEVICES'] = str(config.gpu_id)
if not os.path.exists(config.snapshots_folder): os.makedirs(config.snapshots_folder)
model = IAT(type='exp').cuda() if config.pretrain_dir is not None: model.load_state_dict(torch.load(config.pretrain_dir))
train_dataset = exposure_loader(images_path=config.img_path, normalize=config.normalize)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=config.batch_size, shuffle=False, num_workers=8, pin_memory=True)
val_dataset = exposure_loader(images_path=config.img_val_path, mode='val', normalize=config.normalize) val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=8, pin_memory=True)
vgg_model = vgg16(pretrained=True).features[:16] vgg_model = vgg_model.cuda()
for param in vgg_model.parameters(): param.requires_grad = False
optimizer = torch.optim.Adam(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config.num_epochs)
device = next(model.parameters()).device print('the device is:', device)
L1_loss = nn.L1Loss() L1_smooth_loss = F.smooth_l1_loss
loss_network = LossNetwork(vgg_model) loss_network.eval()
ssim = SSIM() psnr = PSNR() ssim_high = 0 psnrhigh = 0 rank, = get_dist_info()
model.train() print('######## Start IAT Training #########') for epoch in range(config.num_epochs):
print('the epoch is:', epoch)
#train_sampler.set_epoch(epoch)
for iteration, imgs in enumerate(train_loader):
low_img, high_img = imgs[0].cuda(), imgs[1].cuda()
# Checking!
#visualization(low_img, 'show/low', iteration)
#visualization(high_img, 'show/high', iteration)
optimizer.zero_grad()
model.train()
mul, add, enhance_img = model(low_img)
loss = L1_loss(enhance_img, high_img)
#loss = L1_smooth_loss(enhance_img, high_img)+0.04*loss_network(enhance_img, high_img)
loss.backward()
optimizer.step()
scheduler.step()
if ((iteration + 1) % config.display_iter) == 0:
print("Loss at iteration", iteration + 1, ":", loss.item())
#print(rank)
# Evaluation Model
if rank == 0:
#print('111')
model.eval()
PSNR_mean, SSIM_mean = validation(model, val_loader)
print(SSIM_mean)
with open(config.snapshots_folder + '/log.txt', 'a+') as f:
f.write('epoch' + str(epoch) + ':' + 'the SSIM is' + str(SSIM_mean) + 'the PSNR is' + str(PSNR_mean) + '\n')
if SSIM_mean > ssim_high:
ssim_high = SSIM_mean
print('the highest SSIM value is:', str(ssim_high))
torch.save(model.state_dict(), os.path.join(config.snapshots_folder, "best_Epoch" + '.pth'))
f.close()
您检查下读到图像了吗,是不是img_path这里后面没有加'/'
请问您解决了吗
我只有单张gpu训练 Exposure Correction的效果时,一直报RuntimeError:ProcessGroupNCCL is only supported with GPUs, no GPUs found. 请问怎么解决。十分感谢!