zheng-yuwei / RankIQA.PyTorch

基于PyTorch实现的图像质量评估模型RankIQA
138 stars 28 forks source link

基于PyTorch实现的RankIQA

功能说明:

可用的骨干网络包括:


包含特性


文件结构说明


使用说明

数据准备

在文件夹data下放数据,分成三个文件夹: train/test/val,对应 训练/测试/验证 数据文件夹; 每个子文件夹下,需要根据训练任务的不同放置训练图像(若需要修改数据读取方式等,可查看data/my_dataloader)。

数据准备完毕后,使用utils/check_images.py脚本,检查图像数据的有效性,防止在训练过程中遇到无效图片中止训练。

排序对比损失任务

在每个数据集文件夹根目录下放置标签文件:如训练集data/train/label.txt,每行两张图像的项目相对路径,其中第一个图像质量>第二个图像质量, 内容举例如下:

data/train/refimgs/carnivaldolls.bmp,data/train/gaussian_noise/gaussian_noise5/carnivaldolls.bmp
data/train/refimgs/carnivaldolls.bmp,data/train/gaussian_noise/gaussian_noise7/carnivaldolls.bmp
data/train/refimgs/carnivaldolls.bmp,data/train/gaussian_noise/gaussian_noise11/carnivaldolls.bmp
data/train/refimgs/carnivaldolls.bmp,data/train/gaussian_noise/gaussian_noise15/carnivaldolls.bmp
data/train/refimgs/carnivaldolls.bmp,data/train/gaussian_noise/gaussian_noise21/carnivaldolls.bmp
data/train/refimgs/carnivaldolls.bmp,data/train/white_noise/white_noise0/carnivaldolls.bmp
data/train/refimgs/carnivaldolls.bmp,data/train/white_noise/white_noise3/carnivaldolls.bmp
data/train/refimgs/carnivaldolls.bmp,data/train/white_noise/white_noise5/carnivaldolls.bmp
data/train/refimgs/carnivaldolls.bmp,data/train/white_noise/white_noise7/carnivaldolls.bmp
data/train/refimgs/carnivaldolls.bmp,data/train/white_noise/white_noise9/carnivaldolls.bmp
...

分布的推土距离损失任务

在每个数据集文件夹根目录下放置标签文件:如训练集data/train/label.txt,每行:一张图像的项目相对路径,这张图的MOS概率分布, 内容举例如下:

data/train/refimgs/carnivaldolls.bmp,0.,0.,0.,0.2,0.8
data/train/gaussian_noise/gaussian_noise7/carnivaldolls.bmp,0.,0.,0.,0.2,0.8
...

表示 图像data/train/refimgs/carnivaldolls.bmp的评分[1,2,3,4,5]对应的分布为[0.,0.,0.,0.2,0.8]。 注意评分是1:N+1,N=num_classes,这与我损失函数criterions/emd_loss.py的编写相关。

回归损失任务

在每个数据集文件夹根目录下放置标签文件:如训练集data/train/label.txt,每行:一张图像的项目相对路径,这张图的MOS分值 内容举例如下:

data/train/refimgs/carnivaldolls.bmp,7.6
data/train/gaussian_noise/gaussian_noise7/carnivaldolls.bmp,3.2
...

表示 图像data/train/refimgs/carnivaldolls.bmp的评分为3.2。 在训练时,如果配置num_classes=1,则损失函数计算为:||output - 3.2||; 配置num_classes=N,则损失函数计算为:||(output.softmax() * [1:N+1]).sum() - 3.2||; 详细可参考损失函数criterions/regress_loss.py

部分重要配置参数说明

针对config.py里的部分重要参数说明如下:


快速使用 —— 使用公开数据集AVA(aesthetic visual analysis)进行训练、测试、部署

下载数据mtobeiyf/ava_downloader


使用说明

可参考对应的z_task_shell/*.sh文件

模型信息打印

打印分支数为6、输入图像分辨率400x224efficientnet-b0网络的基本信息:

python main.py --arch efficientnet_b0 --num_classes 6 --image_size 244 224

模型训练

基于data/目录下的train数据集,使用分支数为1、输入图像分辨率244x224、损失函数为排序对比损失ranking loss(+ margin=0.1) 的efficientnet-b0网络,同时加载预训练模型、训练学习率warmup 5个epoch,batch size为384, 数据加载worker为16个,训练65个epoch:

python main.py --data data/ --train --arch efficientnet_b0 --num_classes 1 \
--criterion=rank --margin 0.1 -- --use_margin --image_size 244 224 \
--pretrained --warmup 5 --epochs 65 -b 384 -j 16 --gpus 1 --nodes 1

参数的详细说明可查看config.py文件。

模型评估

基于data/目录下的test数据集,评估checkpoints/model_best_efficientnet-b0.pth目录下的模型 (需要指定模型输入、输出、损失函数、模型结构,数据加载的worker,推理时的batch size):

python main.py --data data -e --arch efficientnet_b0 --num_classes 1 --criterion=rank --margin 0.1 \
--image_size 244 224 --batch_size 5 --workers 0 --resume checkpoints/model_best_efficientnet-b0.pth -g 1 -n 1

参数的详细说明可查看config.py文件。

模型转换

转为jit格式模型文件:

python main.py --jit --arch efficientnet_b0 --num_classes 1 --resume checkpoints/model_best_efficientnet-b0.pth -g 0

模型部署demo

训练好模型后,想用该模型对图像数据进行预测,可使用demos目录下的脚本image_assessment.py

cd demos
python image_assessment.py

数据集

数据集 介绍 备注 网址
TID2013 质量评价 25张参考图像,24个失真类型 TID2013
LIVE 质量评价 29张参考图像,5个失真类型 LIVE
MLIVE 质量评价 15张参考图像,15个失真类型 MLIVE
WATERLOO 质量评价 4744张参考图像,20个失真类型 Waterloo
photo.net 美观评价 20,278张图像,打分[0,10] photo.net
DPChallenge.com 美观评价 16,509张图像,打分[0,10] DPChallenge
AVA 美观评价 255,500张图像,打分[0,10] ava_downloader

Reference

d-li14/mobilenetv3.pytorch

lukemelas/EfficientNet-PyTorch

zhanghang1989/ResNeSt

titu1994/neural-image-assessment

TODO