JiahuiYu / wdsr_ntire2018

Code of our winning entry to NTIRE super-resolution challenge, CVPR 2018
http://www.vision.ee.ethz.ch/ntire18/

in forward of wdsr.py, where does 127.5 come from? #44

Closed ALLinLLM closed 4 years ago

ALLinLLM commented 4 years ago

In #40, you said you are doing data normalization, but I have no idea where the 127.5 comes from.

As far as I know, data normalization is something like y = (x - μ)/σ, where μ is the mean and σ is the standard deviation. I guess the [r, g, b] means you used here are the average values over DIV2K_train_LR_bicubic:

R - 0.4488, G - 0.4371, B - 0.4040,

However, in forward you use x = (x - self.rgb_mean.cuda()*255)/127.5, so I am confused about where σ = 127.5 comes from. Shouldn't there be a per-channel [r, g, b] σ instead of a single σ?
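
To make the two schemes concrete, here is a minimal sketch of the difference (the per-channel std values below are placeholders I made up, not numbers from the repo):

import torch

rgb_mean = torch.tensor([0.4488, 0.4371, 0.4040]).view(1, 3, 1, 1)

def normalize_single_scale(x):
    # what wdsr.py does: per-channel mean, one global scale of 127.5
    return (x - rgb_mean * 255) / 127.5

def normalize_per_channel(x):
    # what I expected: per-channel mean AND per-channel std
    rgb_std = torch.tensor([70.0, 66.0, 71.0]).view(1, 3, 1, 1)  # placeholder stds
    return (x - rgb_mean * 255) / rgb_std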

And another question: which dataset should be used when calculating the mean RGB, HR or LR x2/x3/x4? In the Experimental Results section of your paper you said you subtracted the mean RGB of the DIV2K training set, but I have tested this and the mean values differ between HR and LR x2/x3/x4.

ALLinLLM commented 4 years ago
import time

import torch

import data              # EDSR-PyTorch data package
from option import args  # command-line options (dataset paths, scale, etc.)

# first pass: accumulate per-channel sums to compute the means
loader = data.Data(args).loader_train
lr_r_mean = 0
lr_g_mean = 0
lr_b_mean = 0
lr_pixels = 0
hr_r_mean = 0
hr_g_mean = 0
hr_b_mean = 0
hr_pixels = 0

print(len(loader))
start = time.time()

for batch, (lr, hr, _, filename) in enumerate(loader):
    lr = lr.squeeze()
    hr = hr.squeeze()
    lr_pixels += lr.shape[1] * lr.shape[2]
    hr_pixels += hr.shape[1] * hr.shape[2]
    # lr
    r, g, b = lr[0], lr[1], lr[2]
    lr_r_mean += torch.sum(r)
    lr_g_mean += torch.sum(g)
    lr_b_mean += torch.sum(b)
    # hr
    r, g, b = hr[0], hr[1], hr[2]
    hr_r_mean += torch.sum(r)
    hr_g_mean += torch.sum(g)
    hr_b_mean += torch.sum(b)
print(time.time() - start)
lr_r_mean /= lr_pixels
lr_g_mean /= lr_pixels
lr_b_mean /= lr_pixels
hr_r_mean /= hr_pixels
hr_g_mean /= hr_pixels
hr_b_mean /= hr_pixels

# second pass: accumulate squared deviations to compute the stds
lr_r_std = 0
lr_g_std = 0
lr_b_std = 0
lr_pixels = 0
hr_r_std = 0
hr_g_std = 0
hr_b_std = 0
hr_pixels = 0
start = time.time()
for batch, (lr, hr, _, filename) in enumerate(loader):
    lr = lr.squeeze()
    hr = hr.squeeze()
    lr_pixels += lr.shape[1] * lr.shape[2]
    hr_pixels += hr.shape[1] * hr.shape[2]
    # lr
    r, g, b = lr[0], lr[1], lr[2]
    lr_r_std += torch.sum((r-lr_r_mean)**2)
    lr_g_std += torch.sum((g-lr_g_mean)**2)
    lr_b_std += torch.sum((b-lr_b_mean)**2)
    # hr
    r, g, b = hr[0], hr[1], hr[2]
    hr_r_std += torch.sum((r-hr_r_mean)**2)
    hr_g_std += torch.sum((g-hr_g_mean)**2)
    hr_b_std += torch.sum((b-hr_b_mean)**2)
print(time.time() - start)
# per-channel std over the LR training images
lr_r_std = (lr_r_std / lr_pixels) ** 0.5
lr_g_std = (lr_g_std / lr_pixels) ** 0.5
lr_b_std = (lr_b_std / lr_pixels) ** 0.5
# per-channel std over the HR training images
hr_r_std = (hr_r_std / hr_pixels) ** 0.5
hr_g_std = (hr_g_std / hr_pixels) ** 0.5
hr_b_std = (hr_b_std / hr_pixels) ** 0.5
print('lr r,g,b mean std:', filename, lr_r_mean, lr_g_mean, lr_b_mean, lr_r_std, lr_g_std, lr_b_std)
print('hr r,g,b mean std:', filename, hr_r_mean, hr_g_mean, hr_b_mean, hr_r_std, hr_g_std, hr_b_std)
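
By the way, the two passes could be merged into one by accumulating per-channel sums and sums of squares and using Var(x) = E[x²] - E[x]²; a sketch assuming the same loader interface as above:

import torch

def channel_stats(loader):
    # single pass over the loader, per-channel sum and sum of squares
    s = torch.zeros(3)
    s2 = torch.zeros(3)
    n = 0
    for lr, hr, _, filename in loader:
        x = lr.squeeze()                 # (3, H, W), values in [0, 255]
        s += x.sum(dim=(1, 2))
        s2 += (x ** 2).sum(dim=(1, 2))
        n += x.shape[1] * x.shape[2]
    mean = s / n
    std = (s2 / n - mean ** 2).sqrt()    # Var(x) = E[x^2] - E[x]^2
    return mean, std
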
ALLinLLM commented 4 years ago

I ran the script above with /workdir/EDSR-PyTorch/src/div2k_statistic.py --n_threads 0 --epochs 1 --batch_size 1 --dir_data /sr_data/datasets/super_resolution/ --data_train DIV2K --data_test "" --data_range 1-800/801-900 --scale 2 and the results are as follows:

lr r,g,b mean std: ['0564x2'] tensor(116.5574) tensor(110.7386) tensor(101.0412) tensor(70.3525) tensor(66.0828) tensor(71.0180)
hr r,g,b mean std: ['0564x2'] tensor(116.5522) tensor(110.7315) tensor(101.0332) tensor(71.3605) tensor(67.1300) tensor(71.9023)

I used only images 1-800 of DIV2K, and the averages are R_mean = 116.5574, G_mean = 110.7386, B_mean = 101.0412, which differ from the values in your WDSR_a.py: R_mean = 0.4488 x 255 = 114.444, G_mean = 0.4371 x 255 = 111.4605, B_mean = 0.4040 x 255 = 103.02.

For the std the difference is more obvious: 127.5 vs. 70.3525 (R_std), 66.0828 (G_std), 71.0180 (B_std).

rongleiji commented 4 years ago

Hi, I think it is a single overall std for the 0-255 range, not the std of each channel.
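
If so, 127.5 is just 255/2, which re-centers pixel values into roughly [-1, 1] rather than whitening each channel. A quick check with the R mean from the repo:

# 127.5 = 255 / 2; with R mean = 0.4488 * 255 = 114.444
lo = (0 - 114.444) / 127.5    # ~ -0.898
hi = (255 - 114.444) / 127.5  # ~  1.102
print(lo, hi)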

ALLinLLM commented 4 years ago

Hi, I think it is a single overall std for the 0-255 range, not the std of each channel.

Got it. I will try both and compare the results of using the total std vs. the per-channel std.