fqnchina / IntrinsicImage

Implementation code for the CVPR 2018 oral paper "Revisiting Deep Intrinsic Image Decompositions".
61 stars 16 forks source link

PyTorch WHDRHingeLossPara backpropagation question #3

Open FangYang970206 opened 4 years ago

FangYang970206 commented 4 years ago

Hi, Fan! I implemented a PyTorch version of the WHDRHingeLossPara class; the code is as follows:

import torch
from torch.autograd import Function
import math

class WHDRHingeLossPara(Function):
    """WHDR metric in ``forward`` with a hand-written hinge-loss gradient in ``backward``.

    This mirrors the authors' Lua criterion and is used as a plain object, not
    through ``Function.apply``. Call ``forward`` to get the WHDR value, then
    call ``backward`` to get the gradient with respect to the prediction.
    PyTorch has no ``model.backward(input, grad)``. To train, push the returned
    gradient through the network from the prediction tensor itself::

        whdr = criterion.forward(pred, target_file)
        grad = criterion.backward()
        pred.backward(grad)

    Each non-blank line of the target file is a comma-separated comparison
    ``weight, label, x1, y1, x2, y2``. The coordinates are fractions of the
    image width/height. With ``ratio = p1 / (p2 + 1e-7)``:

    - label 0: the points are judged equal (ratio within ``[1/(1+delta), 1+delta]``);
    - label 1: point 1 is judged smaller;
    - any other label (2): point 1 is judged larger.
    """

    def __init__(self, delta, epsilon, device):
        # delta: ratio margin that separates "equal" from "different" in forward.
        # epsilon: extra hinge margin, used only when computing the gradient.
        # device: device the gradient tensor is moved to in backward.
        self.delta = delta
        self.epsilon = epsilon
        self.device = device

    @staticmethod
    def _read_comparisons(target_file, height, width):
        """Parse ``target_file`` into ``(weight, label, y1, x1, y2, x2)`` tuples.

        Blank lines are skipped. Pixel indices are clamped to the image, so a
        normalized coordinate of exactly 1.0 maps to the last row/column
        instead of raising ``IndexError``, and negative values do not silently
        wrap around to the other side of the image.
        """
        def to_index(frac, size):
            return min(max(math.floor(size * frac), 0), size - 1)

        comparisons = []
        with open(target_file) as f:
            for line in f:
                if not line.strip():
                    continue
                strs = line.split(',')
                comparisons.append((
                    float(strs[0]),
                    int(strs[1]),
                    to_index(float(strs[3]), height),
                    to_index(float(strs[2]), width),
                    to_index(float(strs[5]), height),
                    to_index(float(strs[4]), width),
                ))
        return comparisons

    def forward(self, inp, targetFile):
        """Return the WHDR of ``inp`` against the comparisons in ``targetFile``.

        Args:
            inp: 4-D tensor unpacked as ``(N, C, H, W)``. Only ``inp[0, 0]`` is read.
            targetFile: path to the comparison file described on the class.

        Returns:
            Python float: the summed weight of mispredicted comparisons divided
            by the summed weight of all comparisons.

        Raises:
            ZeroDivisionError: if the file holds no comparisons (total weight 0).
        """
        self.inp = inp
        self.targetFile = targetFile
        _, _, height, width = inp.size()
        # Parsed once here and reused by backward().
        self._comparisons = self._read_comparisons(targetFile, height, width)
        # The metric is piecewise constant, so there is nothing to track.
        values = inp.detach()[0, 0]

        self.whdr = 0
        self.weight = 0
        for weight, label, y1, x1, y2, x2 in self._comparisons:
            self.weight += weight
            ratio = values[y1, x1] / (values[y2, x2] + 1e-7)
            if ratio > (1 + self.delta):
                predict_j = 2
            elif ratio < 1 / (1 + self.delta):
                predict_j = 1
            else:
                predict_j = 0
            if label != predict_j:
                self.whdr += weight

        self.whdr = self.whdr / self.weight
        return self.whdr

    def backward(self):
        """Return the hinge-loss gradient w.r.t. the ``inp`` given to ``forward``.

        A comparison contributes only when its ratio violates the hinge margin
        for its label (``delta - epsilon`` for "equal", ``delta + epsilon``
        otherwise). The result is normalized by the same total weight as the
        WHDR. It is computed from a detached copy of ``inp``, so it does not
        carry an autograd graph and can be passed straight to
        ``pred.backward(grad)``.

        Must be called after ``forward``.
        """
        inp = self.inp.detach()
        values = inp[0, 0]
        d, e = self.delta, self.epsilon
        grad_input = torch.zeros_like(inp).to(self.device)
        for weight, label, y1, x1, y2, x2 in self._comparisons:
            p1 = values[y1, x1]
            p2 = values[y2, x2]
            ratio = p1 / (p2 + 1e-7)
            # direction = +1 pushes the ratio down, -1 pushes it up.
            if label == 0:
                if ratio < 1 / (1 + d - e):
                    direction = -1.0
                elif ratio > (1 + d - e):
                    direction = 1.0
                else:
                    continue
            elif label == 1:
                if not ratio > 1 / (1 + d + e):
                    continue
                direction = 1.0
            else:
                if not ratio < (1 + d + e):
                    continue
                direction = -1.0
            scale = direction * weight
            # NOTE(review): the loss reads only batch item 0 / channel 0, but the
            # gradient is written to every batch item and channel, as in the
            # original code. Confirm this is intended for N > 1 or C > 1.
            # d(ratio)/d(p1) and -d(ratio)/d(p2). The 1e-7 placement in the
            # second term is kept exactly as in the original implementation.
            grad_input[:, :, y1, x1] += scale * (1 / (p2 + 1e-7))
            grad_input[:, :, y2, x2] -= scale * (p1 / (p2 * p2 + 1e-7))
        self.grad_input = grad_input / self.weight
        return self.grad_input

The new version is the same as your Lua version. However, when I use similar training code in PyTorch, I find there is no `model.backward` method, so the model's gradients never get updated.

loss = self.whdr_loss.forward(pred, label_txt)  # WHDR value, for logging
grad = self.whdr_loss.backward()  # gradient of the hinge loss w.r.t. pred
# PyTorch has no model.backward(inp, grad) as in Torch7. Seed autograd at the
# network output with the hand-computed gradient so it flows into the model's
# parameters; follow with optimizer.step() as usual.
pred.backward(grad)

Do you have any idea? Thank you very much!

Lee-abcde commented 3 years ago

https://pytorch.org/tutorials/advanced/numpy_extensions_tutorial.html

Hi, have a look at this tutorial and follow the way it defines a custom function for the loss; I think it will work that way.