hzwer / Practical-RIFE

More practical frame interpolation approach.

gram loss #88

Closed: nnmaitian closed 3 months ago

nnmaitian commented 3 months ago

Hello, could you please share the gram-loss code used in version 4.17?

nnmaitian commented 3 months ago

Is there a PyTorch version of the gram loss?

hzwer commented 3 months ago
import torch
import torch.nn as nn
from torchvision import models

class MeanShift(nn.Conv2d):
    """1x1 convolution that normalizes (or denormalizes) images with fixed channel statistics."""
    def __init__(self, data_mean, data_std, data_range=1, norm=True):
        c = len(data_mean)
        super(MeanShift, self).__init__(c, c, kernel_size=1)
        std = torch.Tensor(data_std)
        self.weight.data = torch.eye(c).view(c, c, 1, 1)
        if norm:
            # normalize: (x - data_range * mean) / std
            self.weight.data.div_(std.view(c, 1, 1, 1))
            self.bias.data = -1 * data_range * torch.Tensor(data_mean)
            self.bias.data.div_(std)
        else:
            # denormalize: x * std + data_range * mean
            self.weight.data.mul_(std.view(c, 1, 1, 1))
            self.bias.data = data_range * torch.Tensor(data_mean)
        for p in self.parameters():
            p.requires_grad = False  # fixed transform, never trained

class VGGPerceptualLoss(torch.nn.Module):
    def __init__(self):
        super(VGGPerceptualLoss, self).__init__()
        self.vgg_pretrained_features = models.vgg19(pretrained=True).features
        # ImageNet statistics, applied before feeding images to VGG
        self.normalize = MeanShift([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], norm=True).cuda()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, X, Y, indices=None):
        X = self.normalize(X)
        Y = self.normalize(Y)
        if indices is None:
            # feature maps taken after the relu1_1 ... relu5_1 layers of VGG19
            indices = [2, 7, 12, 21, 30]
        # deeper layers get larger weights
        weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]
        k = 0
        loss0 = 0  # L1 feature (perceptual) loss
        loss1 = 0  # Gram-matrix (style) loss
        for i in range(indices[-1]):
            X = self.vgg_pretrained_features[i](X)
            Y = self.vgg_pretrained_features[i](Y)
            if (i + 1) in indices:
                n, c, h, w = X.shape
                pX = X.reshape(n, c, h * w)
                pY = Y.reshape(n, c, h * w)
                # Gram matrices of shape (n, c, c): G = F @ F^T / (h * w)
                gX = torch.matmul(pX, pX.transpose(-2, -1)) / (h * w)
                gY = torch.matmul(pY, pY.transpose(-2, -1)) / (h * w)
                loss0 += weights[k] * (X - Y.detach()).abs().mean()
                loss1 += weights[k] * ((gX - gY.detach()) ** 2).mean()
                k += 1
        return loss0, loss1
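
A quick usage sketch (the shapes, dummy tensors, and the final weighting below are illustrative assumptions, not from the repo): both inputs are RGB batches in [0, 1] on the GPU, and the two returned terms are the L1 feature loss and the Gram (style) loss.

# Illustrative usage; shapes and the final weighting are assumptions
vgg_loss = VGGPerceptualLoss().cuda()
pred = torch.rand(4, 3, 224, 224, device="cuda")  # e.g. interpolated frames
gt = torch.rand(4, 3, 224, 224, device="cuda")    # ground-truth frames
loss0, loss1 = vgg_loss(pred, gt)  # L1 feature loss, Gram (style) loss
total = loss0 + loss1  # the relative weighting is a training hyperparameter
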
nnmaitian commented 3 months ago

Thank you very much! I noticed that version 4.22 removed this loss and added a smoothness loss computed with Sobel filters. May I ask the reason for this change?

hzwer commented 3 months ago

> Thank you very much! I noticed that version 4.22 removed this loss and added a smoothness loss computed with Sobel filters. May I ask the reason for this change?

Hi, the inference code is not the training code; the inference code has always been changed as little as possible, to avoid introducing problems.
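
For reference, a minimal sketch of what a Sobel-based smoothness loss typically looks like in PyTorch (the repo's actual implementation is not shown in this thread; the function name and the choice of penalizing the mean absolute gradient are assumptions):

import torch
import torch.nn.functional as F

# Hypothetical sketch, not the repo's code: convolve each channel with
# Sobel kernels and penalize the mean gradient magnitude.
def sobel_smooth_loss(x):
    kx = torch.tensor([[-1., 0., 1.], [-2., 0., 2.], [-1., 0., 1.]], device=x.device)
    ky = kx.t().contiguous()  # vertical-gradient kernel is the transpose
    c = x.shape[1]
    kx = kx.view(1, 1, 3, 3).repeat(c, 1, 1, 1)  # one kernel per channel
    ky = ky.view(1, 1, 3, 3).repeat(c, 1, 1, 1)
    gx = F.conv2d(x, kx, padding=1, groups=c)  # horizontal gradients
    gy = F.conv2d(x, ky, padding=1, groups=c)  # vertical gradients
    return (gx.abs() + gy.abs()).mean()
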

nnmaitian commented 3 months ago

Hello, one more question about the gram loss internals: the original code used weights = [1.0 / 2.6, 1.0 / 4.8, 1.0 / 3.7, 1.0 / 5.6, 10.0 / 1.5], but you changed it to weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]. Is there a deeper reason behind this? Also, how do you set the weight hyperparameters for the different losses? Could you share the exact values?

hzwer commented 3 months ago

@nnmaitian [1.0 / 2.6, 1.0 / 4.8, 1.0 / 3.7, 1.0 / 5.6, 10.0 / 1.5] is legacy code picked up from someone else online: https://github.com/ceciliavision/perceptual-reflection-removal/blob/92e28441922a27697d57456bfef96dc2bc9e056a/main.py#L240 I did find that leaving out the weighting makes the results noticeably worse.

I changed it to something that looks cleaner; I haven't run careful experiments on this part. As I recall, as long as the deepest layer gets a relatively large weight and the other layers get smaller ones, the results are all acceptable.
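
A quick numeric check of that pattern (just arithmetic on the two weight lists quoted above, ordered shallow to deep): in both schemes the deepest layer carries most of the total weight.

legacy = [1.0 / 2.6, 1.0 / 4.8, 1.0 / 3.7, 1.0 / 5.6, 10.0 / 1.5]
geometric = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]
for name, w in (("legacy", legacy), ("geometric", geometric)):
    # fraction of the total weight carried by the deepest layer
    print(name, round(w[-1] / sum(w), 2))  # legacy: 0.86, geometric: 0.68
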

nnmaitian commented 3 months ago

OK, thanks for the explanation.