chaofengc / IQA-PyTorch

👁️ 🖼️ 🔥PyTorch Toolbox for Image Quality Assessment, including LPIPS, FID, NIQE, NRQM(Ma), MUSIQ, TOPIQ, NIMA, DBCNN, BRISQUE, PI and more...
https://iqa-pytorch.readthedocs.io/

Instability of NIQE #23

Closed wyf0912 closed 2 years ago

wyf0912 commented 2 years ago

Hi Chaofeng,

Thanks for your great work. I ran into some problems when using the NIQE metric. The NIQE results for the JPG and PNG versions of the same image differ considerably, and both differ from the MATLAB results.

In addition, the NIQE values of the same image evaluated on CPU and GPU also differ considerably.

import pyiqa
from PIL import Image
import torchvision.transforms.functional as TF
import torch
from torchvision.utils import save_image

niqe_metric = pyiqa.create_metric('niqe')

# img_jpg = TF.to_tensor(Image.open("37.jpg"))

img_tensor = torch.load("output_img.pt",map_location=torch.device('cpu'))
save_image(img_tensor, "37.png")
save_image(img_tensor, "37.jpg")

img_png = TF.to_tensor(Image.open("37.png"))
img_jpg = TF.to_tensor(Image.open("37.jpg"))

print(niqe_metric(img_tensor))
print(niqe_metric(img_jpg.unsqueeze(0)))
print(niqe_metric(img_png.unsqueeze(0)))

The output is

tensor(3.1253)
tensor(2.2823)
tensor(3.1208)

I also tested the saved JPG and PNG images using MATLAB, and the NIQE values are 8.3841 and 8.6096.

[Attached images: JPG image, PNG image]

The results of evaluating the same image on CPU and GPU also differ widely:

>>> niqe_metric(img_jpg.unsqueeze(0).cuda())
tensor(4.9318, device='cuda:0')
>>> niqe_metric(img_jpg.unsqueeze(0))
tensor(2.9255)

>>> niqe_metric(img_png.unsqueeze(0))
tensor(3.8962)
>>> niqe_metric(img_png.unsqueeze(0).cuda())
tensor(11.9320, device='cuda:0')

Do you have any advice on it? Thank you!

chaofengc commented 2 years ago

Thank you so much for helping us find these issues. I have fixed most of them. Here are the explanations:

MATLAB NIQE vs. the fixed results

Why were the previous version's results so different?

The image you provided is a corner case in which large regions (>96x96) are completely white. The MATLAB code simply ignores such regions. Our previous version did not ignore them and occasionally treated them as good regions, which made our results much better (lower values) than the MATLAB results.

The previous version works well for images without such plain regions. The latest version follows MATLAB and ignores these plain regions.

https://github.com/chaofengc/IQA-PyTorch/blob/18756e657f0427f2affd5ff4920408520eff399a/pyiqa/archs/niqe_arch.py#L143-L146
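To illustrate the idea of ignoring plain regions, here is a minimal pure-Python sketch. It is not the code at the linked lines: the helper names `is_flat` and `non_flat_blocks` are hypothetical, and the rule "a block is plain when all its pixels are equal" is an assumption made for illustration. In NIQE the block size would be 96; the example uses 2 to stay small.

```python
def is_flat(block):
    """Return True when every pixel in the block has the same value."""
    first = block[0][0]
    return all(v == first for row in block for v in row)


def non_flat_blocks(image, size):
    """Return (top, left) corners of non-overlapping blocks that are not flat.

    `image` is a 2D list of pixel values. Flat blocks (e.g. pure white areas)
    are skipped so they do not contribute to the quality statistics.
    """
    height, width = len(image), len(image[0])
    corners = []
    for top in range(0, height - size + 1, size):
        for left in range(0, width - size + 1, size):
            block = [row[left:left + size] for row in image[top:top + size]]
            if not is_flat(block):
                corners.append((top, left))
    return corners


# Toy 4x4 image: the top-left and bottom-right 2x2 blocks are pure white (1.0).
image = [
    [1.0, 1.0, 0.2, 0.4],
    [1.0, 1.0, 0.6, 0.8],
    [0.1, 0.3, 1.0, 1.0],
    [0.5, 0.7, 1.0, 1.0],
]
kept = non_flat_blocks(image, size=2)
print(kept)
```

Only the two textured blocks are kept, so the white areas can no longer be mistaken for good regions.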

Why were the results different on CPU and GPU?

This is because of the data type. NIQE requires double precision (i.e. float64), but the previous version used float64 on CPU and float32 on GPU, which is the default type of CUDA tensors in PyTorch.

All data is now cast to float64, so GPU and CPU produce the same results.

https://github.com/chaofengc/IQA-PyTorch/blob/18756e657f0427f2affd5ff4920408520eff399a/pyiqa/archs/niqe_arch.py#L189-L190
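As a rough illustration of why the data type matters, the sketch below shows how much precision a single value loses when stored as float32 instead of float64. It uses only the standard library; `as_float32` is a hypothetical helper written for this example and is not part of pyiqa or PyTorch.

```python
import struct


def as_float32(x):
    """Round a Python float (which is float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


value = 0.1                   # stored as float64 by Python
rounded = as_float32(value)   # same number after a float32 round trip
error = abs(rounded - value)  # rounding error introduced by float32

print(value, rounded, error)
```

The per-value error is tiny, but NIQE's computation combines many such values, so small float32 errors can plausibly grow into visible score differences. In PyTorch, a tensor can be cast with `tensor.double()` or `tensor.to(torch.float64)` before such computations.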

Difference between JPG and PNG

As you know, JPG compression is lossy, so the JPG and PNG files contain different pixel values, which leads to different results. The same applies to MATLAB.

wyf0912 commented 2 years ago

Thanks a lot for your prompt reply! It seems to work well now.