Closed: thisisqiaoqiao closed this issue 7 months ago.
import argparse
import os
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
from torch.utils.data import DataLoader
from torchvision.models import vgg16
from IQA_pytorch import SSIM

from dataloader.data_loader_exposure import exposure_loader
from net.model import Dynamic_CS_1D
from utils import PSNR, validation, LossNetwork
# Command-line configuration for exposure-correction training.
parser = argparse.ArgumentParser()

# Exported verbatim as CUDA_VISIBLE_DEVICES further down, hence type=str.
# The default is a string as well, so config.gpu_id has the same type whether
# or not the flag is passed (it used to be the int 1).
# NOTE: on a machine with a single GPU (index 0), the default '1' typically
# leaves no device visible and the later .cuda() calls fail; pass --gpu_id 0.
parser.add_argument('--gpu_id', type=str, default='1')
parser.add_argument('--img_path', type=str,
                    default="/data/unagi0/cui_data/light_dataset/Exposure_CVPR21/train/INPUT_IMAGES")
parser.add_argument('--img_val_path', type=str,
                    default="/data/unagi0/cui_data/light_dataset/Exposure_CVPR21/validation/INPUT_IMAGES")
parser.add_argument("--normalize", action="store_true",
                    help="Default not Normalize in exposure training.")

# Optimisation settings. The lr default was chosen for an effective batch of
# 4x2=8 (presumably 4 per GPU on 2 GPUs - confirm); if you change the batch
# size, scale the learning rate accordingly.
parser.add_argument('--batch_size', type=int, default=4)
parser.add_argument('--lr', type=float, default=1e-5)  # for batch size 4x2=8
parser.add_argument('--weight_decay', type=float, default=0.0002)
parser.add_argument('--pretrain_dir', type=str, default=None)

# Forwarded to Dynamic_CS_1D(t_range) when the model is built.
parser.add_argument('--t_range', type=int, default=500)

# Schedule / logging / output location.
parser.add_argument('--num_epochs', type=int, default=50)
parser.add_argument('--display_iter', type=int, default=50)
parser.add_argument('--snapshots_folder', type=str, default="workdirs/snapshots_folder_exposure_1d")

config = parser.parse_args()
print(config)
# Restrict the process to the requested GPU(s). This only takes effect if no
# CUDA call has happened yet; torch initialises CUDA lazily, so setting it
# after `import torch` but before any .cuda() is sufficient.
os.environ['CUDA_VISIBLE_DEVICES'] = str(config.gpu_id)
# exist_ok replaces the exists()-then-makedirs() check, which could race.
os.makedirs(config.snapshots_folder, exist_ok=True)
# Build the enhancement network (t_range forwarded from the CLI) on the GPU.
model = Dynamic_CS_1D(config.t_range).cuda()
if config.pretrain_dir is not None:
    # Load onto the CPU first. A checkpoint whose tensors were saved on a GPU
    # index that is not visible in this process (e.g. a single-GPU run) would
    # otherwise fail to deserialize. load_state_dict copies each tensor onto
    # the model's current (CUDA) parameters.
    model.load_state_dict(torch.load(config.pretrain_dir, map_location='cpu'))
# Datasets: the training split uses the loader's default mode; the validation
# split is read with mode='val'. Both honour the --normalize flag.
train_dataset = exposure_loader(
    images_path=config.img_path,
    normalize=config.normalize,
)
val_dataset = exposure_loader(
    images_path=config.img_val_path,
    mode='val',
    normalize=config.normalize,
)

# Loaders: shuffled mini-batches for training; one image at a time, in a
# fixed order, for validation. Pinned memory speeds host-to-GPU copies.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
)
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=8,
    pin_memory=True,
)
# First 16 layers of VGG-16's `features`, moved to the GPU and frozen. The
# module is wrapped by LossNetwork for the VGG-based loss term; none of its
# weights receive gradients.
vgg_model = vgg16(pretrained=True).features[:16].cuda()
vgg_model.requires_grad_(False)
# Adam with L2 weight decay over all model parameters.
optimizer = torch.optim.Adam(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)
# The training loop calls scheduler.step() once per *iteration*, so T_max must
# be measured in iterations. With T_max=num_epochs, CosineAnnealingLR would
# cycle the LR down to zero and back up every 2*num_epochs iterations for the
# whole run, instead of annealing once from lr to 0 across all epochs.
# max(1, ...) guards against an empty training loader.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=max(1, config.num_epochs * len(train_loader)))
# Device holding the model's parameters (set by the earlier .cuda()).
device = next(model.parameters()).device

# Pixel losses. The training objective uses L1_smooth_loss; L1_loss is not
# referenced elsewhere in this script.
L1_smooth_loss = F.smooth_l1_loss
L1_loss = nn.L1Loss()

# VGG-based loss term, kept in eval mode.
loss_network = LossNetwork(vgg_model)
loss_network.eval()

# Metric objects and best-score trackers (only ssim_high is updated later).
ssim, psnr = SSIM(), PSNR()
ssim_high = psnr_high = 0
# Main loop: train for one epoch, validate, log, keep the best-SSIM checkpoint.
for epoch in range(config.num_epochs):
    print('the epoch is:', epoch)
    # Re-enter training mode at the start of every epoch; the validation pass
    # at the end of the previous epoch switched the model to eval().
    model.train()

    for iteration, imgs in enumerate(train_loader):
        # non_blocking pairs with pin_memory=True on the loader.
        low_img = imgs[0].cuda(non_blocking=True)
        high_img = imgs[1].cuda(non_blocking=True)

        optimizer.zero_grad()
        enhance_img = model(low_img)
        # Smooth-L1 reconstruction term plus a 0.04-weighted VGG loss term.
        loss = L1_smooth_loss(enhance_img, high_img) + 0.04 * loss_network(enhance_img, high_img)
        loss.backward()
        optimizer.step()
        scheduler.step()  # the LR schedule advances once per iteration

        if (iteration + 1) % config.display_iter == 0:
            print("Loss at iteration", iteration + 1, ":", loss.item())

    # Evaluate once per epoch. no_grad avoids building autograd graphs here;
    # it is harmless if `validation` already disables gradients internally.
    model.eval()
    with torch.no_grad():
        PSNR_mean, SSIM_mean = validation(model, val_loader)

    # The `with` statement closes the file; no explicit f.close() is needed.
    # str() is kept deliberately: it can differ from f-string formatting for
    # tensor values.
    with open(os.path.join(config.snapshots_folder, 'log.txt'), 'a+') as f:
        f.write('epoch' + str(epoch) + ':' + 'the SSIM is' + str(SSIM_mean) + 'the PSNR is' + str(PSNR_mean) + '\n')

    # Overwrite a single checkpoint whenever validation SSIM improves.
    if SSIM_mean > ssim_high:
        ssim_high = SSIM_mean
        print('the highest SSIM value is:', str(ssim_high))
        torch.save(model.state_dict(), os.path.join(config.snapshots_folder, "best_Epoch" + '.pth'))
Thank you for your reply. I have an additional question: what configuration did you use for training? I followed the solution in #35 and found that training was slow. Can I increase the batch size, or will that have an impact on performance?
Sure, you can increase the batch size if you have enough GPU memory; in that case, don't forget to adjust the learning rate as follows:
Thanks for your reply, I will experiment.
[Exposure Correction] May I ask how to train using a single GPU?