cuiziteng / Illumination-Adaptive-Transformer

[BMVC 2022] You Only Need 90K Parameters to Adapt Light: A Light Weight Transformer for Image Enhancement and Exposure Correction. SOTA for low-light enhancement, runs in 0.004 seconds; try this for pre-processing.
Apache License 2.0

Single gpu training #64

Closed thisisqiaoqiao closed 7 months ago

thisisqiaoqiao commented 7 months ago

[Exposure Correction] May I ask how to train using a single GPU?

cuiziteng commented 7 months ago

# Single-GPU training script for the exposure-correction setting.
import torch
import torch.nn as nn
import torch.optim
import torch.nn.functional as F

import os
import argparse
import numpy as np
import random
from torchvision.models import vgg16
from torch.utils.data import DataLoader

from dataloader.data_loader_exposure import exposure_loader
from net.model import Dynamic_CS_1D

from IQA_pytorch import SSIM
from utils import PSNR, validation, LossNetwork

parser = argparse.ArgumentParser()

parser.add_argument('--gpu_id', type=str, default=1)
parser.add_argument('--img_path', type=str, default="/data/unagi0/cui_data/light_dataset/Exposure_CVPR21/train/INPUT_IMAGES")
parser.add_argument('--img_val_path', type=str, default="/data/unagi0/cui_data/light_dataset/Exposure_CVPR21/validation/INPUT_IMAGES")
parser.add_argument("--normalize", action="store_true", help="Default not Normalize in exposure training.")

parser.add_argument('--batch_size', type=int, default=4)
parser.add_argument('--lr', type=float, default=1e-5)  # for batch size 4x2=8
parser.add_argument('--weight_decay', type=float, default=0.0002)
parser.add_argument('--pretrain_dir', type=str, default=None)

parser.add_argument('--t_range', type=int, default=500)

parser.add_argument('--num_epochs', type=int, default=50)
parser.add_argument('--display_iter', type=int, default=50)
parser.add_argument('--snapshots_folder', type=str, default="workdirs/snapshots_folder_exposure_1d")

config = parser.parse_args()

print(config)
os.environ['CUDA_VISIBLE_DEVICES'] = str(config.gpu_id)
if not os.path.exists(config.snapshots_folder):
    os.makedirs(config.snapshots_folder)

# Model (optionally initialised from a pretrained checkpoint)
model = Dynamic_CS_1D(config.t_range).cuda()
if config.pretrain_dir is not None:
    model.load_state_dict(torch.load(config.pretrain_dir))

# Data loaders
train_dataset = exposure_loader(images_path=config.img_path, normalize=config.normalize)
train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=8, pin_memory=True)

val_dataset = exposure_loader(images_path=config.img_val_path, mode='val', normalize=config.normalize)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=8, pin_memory=True)

# Frozen VGG-16 features for the perceptual loss
vgg_model = vgg16(pretrained=True).features[:16]
vgg_model = vgg_model.cuda()
for param in vgg_model.parameters():
    param.requires_grad = False

optimizer = torch.optim.Adam(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config.num_epochs)

device = next(model.parameters()).device

L1_loss = nn.L1Loss()
L1_smooth_loss = F.smooth_l1_loss

loss_network = LossNetwork(vgg_model)
loss_network.eval()

ssim = SSIM()
psnr = PSNR()
ssim_high = 0
psnr_high = 0

model.train()

for epoch in range(config.num_epochs):
    print('the epoch is:', epoch)

    for iteration, imgs in enumerate(train_loader):
        low_img, high_img = imgs[0].cuda(), imgs[1].cuda()

        optimizer.zero_grad()
        model.train()
        enhance_img = model(low_img)

        # Smooth-L1 reconstruction loss plus weighted VGG perceptual loss
        loss = L1_smooth_loss(enhance_img, high_img) + 0.04 * loss_network(enhance_img, high_img)

        loss.backward()
        optimizer.step()
        scheduler.step()

        if ((iteration + 1) % config.display_iter) == 0:
            print("Loss at iteration", iteration + 1, ":", loss.item())

    # Validate at the end of every epoch and keep the best-SSIM checkpoint
    model.eval()
    PSNR_mean, SSIM_mean = validation(model, val_loader)

    with open(config.snapshots_folder + '/log.txt', 'a+') as f:
        f.write('epoch' + str(epoch) + ':' + 'the SSIM is' + str(SSIM_mean) + 'the PSNR is' + str(PSNR_mean) + '\n')

    if SSIM_mean > ssim_high:
        ssim_high = SSIM_mean
        print('the highest SSIM value is:', str(ssim_high))
        torch.save(model.state_dict(), os.path.join(config.snapshots_folder, "best_Epoch" + '.pth'))
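If you save the script above as, say, train_exposure_single.py (the file name and the data paths below are only placeholders), a single-GPU run would look roughly like:

    python train_exposure_single.py --gpu_id 0 --img_path /path/to/train/INPUT_IMAGES --img_val_path /path/to/validation/INPUT_IMAGES --batch_size 4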
thisisqiaoqiao commented 7 months ago

Thank you for your reply. I have an additional question: what configuration did you use for training? I followed the solution in #35 and found that training was slow. Can I increase the batch size, and will that have an impact on performance?

cuiziteng commented 7 months ago

Sure, you can increase the batch size if your GPU memory is large enough; meanwhile, do not forget to adjust the learning rate accordingly, as follows:

[screenshot: suggested learning-rate adjustment for a larger batch size]
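In case the screenshot does not load, the idea implied by the comment in the script (lr = 1e-5 for an effective batch size of 4x2 = 8) is presumably to scale the learning rate linearly with the batch size. A minimal sketch, assuming linear scaling is what is meant here (the helper name and base values are illustrative, not from the repo):

    # Assumption: learning rate scales linearly with batch size,
    # anchored at lr = 1e-5 for an effective batch size of 8
    # (taken from the '--lr' comment in the training script above).
    def scaled_lr(batch_size, base_lr=1e-5, base_batch_size=8):
        return base_lr * batch_size / base_batch_size

    print(scaled_lr(16))  # 2e-05
    print(scaled_lr(32))  # 4e-05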
thisisqiaoqiao commented 7 months ago

Thanks for your reply, I will experiment.