Paper99 / SRFBN_CVPR19

PyTorch code for our paper "Feedback Network for Image Super-Resolution" (CVPR2019)
MIT License
551 stars, 126 forks

An error in your code. #45

Open penguinbing opened 5 years ago

penguinbing commented 5 years ago

The MeanShift Conv2d defined in your code is never actually frozen. Conv2d (like any nn.Module) has no requires_grad attribute, so self.requires_grad = False only creates an unused attribute on the module, and its weight and bias stay trainable. You should freeze the parameters themselves with self.weight.requires_grad = False and self.bias.requires_grad = False. Better still, defining MeanShift as below is clearer and more correct. Here is the code:

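import torch
import torch.nn as nn

class MeanShift(nn.Conv2d):
    # 1x1 conv computing (x + sign * rgb_range * rgb_mean) / rgb_std per channel:
    # sign=-1 subtracts the dataset mean, sign=+1 adds it back.
    def __init__(
        self, rgb_range,
        rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1):

        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        std = torch.Tensor(rgb_std)
        # Diagonal weight: output channel i = input channel i / std[i]
        self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std
        # Freeze the parameters themselves; setting self.requires_grad on the module has no effect
        for p in self.parameters():
            p.requires_grad = False
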
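A quick way to check the difference is sketched below. This is a minimal sketch, not code from the SRFBN repo: it shows that a module-level flag leaves the parameters trainable, that the MeanShift variant really is frozen, and that a sign=-1 / sign=+1 pair round-trips the input. The value rgb_range=255 and the random test tensor are illustrative assumptions.

```python
import torch
import torch.nn as nn


class MeanShift(nn.Conv2d):
    """1x1 conv computing (x + sign * rgb_range * rgb_mean) / rgb_std per channel."""

    def __init__(self, rgb_range, rgb_mean=(0.4488, 0.4371, 0.4040),
                 rgb_std=(1.0, 1.0, 1.0), sign=-1):
        super().__init__(3, 3, kernel_size=1)
        std = torch.tensor(rgb_std)
        with torch.no_grad():
            # Diagonal weight scales each channel by 1 / std[i]
            self.weight.copy_(torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1))
            self.bias.copy_(sign * rgb_range * torch.tensor(rgb_mean) / std)
        # Equivalent to looping over self.parameters() and clearing requires_grad
        self.requires_grad_(False)


# The pitfall from the issue: a plain attribute on the module is ignored by autograd.
naive = nn.Conv2d(3, 3, kernel_size=1)
naive.requires_grad = False
print(all(p.requires_grad for p in naive.parameters()))  # True: still trainable

# rgb_range=255 is an assumed pixel range for this demo.
sub_mean = MeanShift(rgb_range=255, sign=-1)
add_mean = MeanShift(rgb_range=255, sign=1)
print(any(p.requires_grad for p in sub_mean.parameters()))  # False: frozen

x = torch.rand(1, 3, 4, 4) * 255
print(torch.allclose(add_mean(sub_mean(x)), x, atol=1e-3))  # True
```

Writing the weights with copy_ under torch.no_grad() instead of assigning .data gives the same result; it is simply the more commonly recommended way to initialize parameters in place.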
Paper99 commented 5 years ago

Thank you for pointing out this mistake. I hadn't noticed it. I will correct it as soon as possible.

zinhoo commented 5 years ago

Be careful with your English: ‘a error’ should be ‘An error’.

penguinbing commented 5 years ago

Thank you. @zinhoo