yulunzhang / RCAN

PyTorch code for our ECCV 2018 paper "Image Super-Resolution Using Very Deep Residual Channel Attention Networks"

what's the meaning of the mean_shift #60

Open ninesun127 opened 5 years ago

ninesun127 commented 5 years ago

1. I computed the mean and std of the DIV2K HR training set:

mean=[0.4485, 0.4375, 0.4045], std=[0.2436, 0.2330, 0.2424]. The mean is similar to yours, but the std is totally different. How do you calculate the std?

2. My own training set has mean=[0.5164, 0.5179, 0.4987] and std=[0.2256, 0.2194, 0.2282].

I set the mean to my own training set's mean and the std to [1.0, 1.0, 1.0].

When I add the mean_shift layers, the network does not work well.

When I remove the mean_shift layers, the network works well.

So what is the meaning of mean_shift, and why does the network give such bad results when I add it?

My code for calculating the mean and std is below:

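import glob

import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import ToTensor
from tqdm import tqdm

# Placeholder: set this to the folder that holds the *H.png images.
img_dir = 'path/to/images/'

# glob already returns paths that include img_dir, so no os.path.join is needed.
img_list = sorted(glob.glob(img_dir + '*H.png'))
print(len(img_list))


class MyDataset(Dataset):
    """Loads each image as a float tensor in [0, 1] with shape (3, H, W)."""

    def __init__(self, img_list):
        self.data = img_list

    def __getitem__(self, index):
        # Force 3 channels so grayscale or RGBA files do not change the shape.
        return ToTensor()(Image.open(self.data[index]).convert('RGB'))

    def __len__(self):
        return len(self.data)


dataset = MyDataset(img_list)
# batch_size=1 because the images may have different sizes.
loader = DataLoader(dataset, batch_size=1, num_workers=1, shuffle=False)

mean = torch.zeros(3)
std = torch.zeros(3)
nb_samples = 0
for data in tqdm(loader):
    batch_samples = data.size(0)
    # Flatten H and W: (B, C, H, W) -> (B, C, H*W).
    data = data.view(batch_samples, data.size(1), -1)
    mean += data.mean(2).sum(0)
    # Note: this accumulates per-image std, so the final value is the average
    # per-image std, not the std over all pixels in the dataset.
    std += data.std(2).sum(0)
    nb_samples += batch_samples
mean /= nb_samples
std /= nb_samples

print(nb_samples, mean, std)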
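
One thing worth noting about the loop above: it averages the per-image standard deviations, which in general is not the same as the standard deviation over all pixels of the dataset taken together. The toy sketch below uses plain Python with made-up single-channel pixel values to show how far apart the two can be; the function names are hypothetical and only for illustration.

```python
import statistics


def mean_of_per_image_std(images):
    """Average of each image's own sample std (what the loop above computes)."""
    return sum(statistics.stdev(img) for img in images) / len(images)


def pooled_std(images):
    """Sample std over all pixels of all images pooled together."""
    all_pixels = [p for img in images for p in img]
    return statistics.stdev(all_pixels)


# Two toy single-channel "images": one dark and flat, one bright and flat.
# These values are invented for the example.
images = [
    [0.10, 0.12, 0.11, 0.09],
    [0.90, 0.88, 0.91, 0.89],
]

avg_std = mean_of_per_image_std(images)
glob_std = pooled_std(images)
print("average of per-image std:", avg_std)
print("std over all pixels:     ", glob_std)
```

Each image is nearly flat, so the averaged per-image std is small, while the pooled std is large because the two images differ a lot in brightness. Which of the two is wanted depends on how the statistic is going to be used.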
HenryJunW commented 5 years ago

I don't know why that happens, but I think you can refer to this issue: https://github.com/thstkdgus35/EDSR-PyTorch/issues/94
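
For reference, my understanding is that a mean_shift layer in EDSR-style code is a fixed, non-trainable per-channel affine transform: it subtracts the dataset mean (scaled to the pixel range) at the input and adds it back at the output. The sketch below reproduces that arithmetic in plain Python on a single RGB pixel. It is not the repository's layer: the function name `mean_shift`, the `rgb_range` scaling, and the sample pixel are assumptions for illustration. It also shows one possible, unconfirmed way things can go wrong, namely when the data scale and `rgb_range` do not match.

```python
def mean_shift(pixel, rgb_range, rgb_mean, rgb_std=(1.0, 1.0, 1.0), sign=-1):
    """Per-channel affine shift: x / std + sign * rgb_range * mean / std."""
    return tuple(
        x / s + sign * rgb_range * m / s
        for x, m, s in zip(pixel, rgb_mean, rgb_std)
    )


# Mean quoted in the question, measured on images scaled to [0, 1].
rgb_mean = (0.5164, 0.5179, 0.4987)

# Case 1: pixel and rgb_range are both on the [0, 255] scale.
# The shifted values end up close to zero, as intended.
pixel_255 = (132.0, 132.0, 127.0)  # invented sample pixel
centred = mean_shift(pixel_255, rgb_range=255, rgb_mean=rgb_mean)

# Case 2: pixel is on the [0, 1] scale but rgb_range is still 255.
# Every channel is pushed to roughly -131, far outside the data range.
pixel_01 = tuple(v / 255 for v in pixel_255)
mismatched = mean_shift(pixel_01, rgb_range=255, rgb_mean=rgb_mean)

# With rgb_std of 1, applying the shift again with sign=+1 undoes it.
restored = mean_shift(centred, rgb_range=255, rgb_mean=rgb_mean, sign=1)

print("centred:   ", centred)
print("mismatched:", mismatched)
print("restored:  ", restored)
```

If the network works without the layer but not with it, checking that the range of the input tensors agrees with the range the mean is scaled to might be a reasonable first step.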