Closed nnmaitian closed 3 months ago
是否有pytorch版本的gram loss
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import models
from torch.nn.parallel import DistributedDataParallel as DDP
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU so that
# code moving modules/tensors to `device` also runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class MeanShift(nn.Conv2d):
    """Frozen 1x1 convolution applying a per-channel affine transform.

    With ``norm=True`` the layer computes ``(x - data_range * mean) / std``
    for every channel; with ``norm=False`` it computes
    ``x * std + data_range * mean``.
    """

    def __init__(self, data_mean, data_std, data_range=1, norm=True):
        """Build the layer from per-channel statistics.

        Args:
            data_mean: Sequence of per-channel means. Its length defines the
                number of input and output channels.
            data_std: Sequence of per-channel standard deviations, one entry
                per channel of ``data_mean``.
            data_range: Scalar multiplied into the mean before it is used as
                the bias term.
            norm (bool): normalize (``True``) or denormalize (``False``) the
                input with the given stats.
        """
        c = len(data_mean)
        super(MeanShift, self).__init__(c, c, kernel_size=1)
        mean = torch.as_tensor(data_mean, dtype=torch.float32)
        std = torch.as_tensor(data_std, dtype=torch.float32)
        # Start from an identity mapping between channels, then scale the
        # diagonal by 1/std (normalize) or std (denormalize).
        self.weight.data = torch.eye(c).view(c, c, 1, 1)
        if norm:
            self.weight.data.div_(std.view(c, 1, 1, 1))
            self.bias.data = -1 * data_range * mean
            self.bias.data.div_(std)
        else:
            self.weight.data.mul_(std.view(c, 1, 1, 1))
            self.bias.data = data_range * mean
        # Assigning `self.requires_grad = False` would only create a plain
        # attribute on the module and leave weight/bias trainable; the flag
        # has to be cleared on the parameters themselves.
        for param in self.parameters():
            param.requires_grad = False
class VGGPerceptualLoss(torch.nn.Module):
    """L1 feature loss and Gram-matrix loss computed on VGG-19 features.

    Both inputs are normalized with a ``MeanShift`` layer and pushed through
    the ``features`` part of a pretrained VGG-19. After selected layers the
    activations (L1 loss) and their Gram matrices (squared-error loss) are
    compared and accumulated with per-layer weights.
    """

    # 1-based positions in ``vgg19().features``: features are compared right
    # after layer ``index - 1`` has been applied.
    DEFAULT_INDICES = (2, 7, 12, 21, 30)
    # One weight per entry of DEFAULT_INDICES; the deepest layer weighs most.
    DEFAULT_WEIGHTS = (1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0)

    def __init__(self):
        super(VGGPerceptualLoss, self).__init__()
        self.vgg_pretrained_features = models.vgg19(pretrained=True).features
        # No hard-coded `.cuda()` here: the normalizer is a registered
        # submodule, so `.to(device)` / `.cuda()` on this module moves it
        # together with the VGG layers, and construction works without a GPU.
        self.normalize = MeanShift(
            [0.485, 0.456, 0.406], [0.229, 0.224, 0.225], norm=True
        )
        # The network is a fixed feature extractor; nothing here is trained.
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, X, Y, indices=None, weights=None):
        """Return ``(feature_loss, gram_loss)`` between ``X`` and ``Y``.

        Args:
            X: Prediction batch; gradients flow through this input only.
            Y: Target batch; its features are detached before comparison.
            indices: Optional iterable of 1-based layer positions at which
                features are compared. Defaults to ``DEFAULT_INDICES``.
            weights: Optional iterable of per-layer weights, paired
                element-wise with ``indices``. Defaults to
                ``DEFAULT_WEIGHTS``.

        Returns:
            Tuple ``(loss0, loss1)``: the weighted sum of mean absolute
            feature differences and the weighted sum of mean squared
            Gram-matrix differences.

        Raises:
            ValueError: If ``indices`` and ``weights`` differ in length.
        """
        indices = list(self.DEFAULT_INDICES if indices is None else indices)
        weights = list(self.DEFAULT_WEIGHTS if weights is None else weights)
        if len(indices) != len(weights):
            raise ValueError(
                "indices and weights must have the same length, got "
                f"{len(indices)} and {len(weights)}"
            )
        # Look weights up by layer position so the pairing does not depend
        # on the order in which the indices were supplied.
        weight_at = dict(zip(indices, weights))
        last = max(indices, default=0)

        X = self.normalize(X)
        Y = self.normalize(Y)
        loss0 = 0
        loss1 = 0
        for i in range(last):
            layer = self.vgg_pretrained_features[i]
            X = layer(X)
            Y = layer(Y)
            weight = weight_at.get(i + 1)
            if weight is None:
                continue
            n, c, h, w = X.shape
            # Flatten the spatial dims so each channel is one row vector.
            pX = X.reshape(n, c, h * w)
            pY = Y.reshape(n, c, h * w)
            # Channel-by-channel Gram matrices, scaled by the spatial size.
            gX = torch.matmul(pX, pX.transpose(-2, -1)) / (h * w)
            gY = torch.matmul(pY, pY.transpose(-2, -1)) / (h * w)
            loss0 += weight * (X - Y.detach()).abs().mean()
            loss1 += weight * ((gX - gY.detach()) ** 2).mean()
        return loss0, loss1
万分感谢,发现您4.22版本去掉了该loss,新增了基于sobel计算的smooth loss,想请教一下这个是什么原因?
万分感谢,发现您4.22版本去掉了该loss,新增了基于sobel计算的smooth loss,想请教一下这个是什么原因?
你好,推理代码不是训练代码,推理代码一直以来很少改动以免引入问题
您好,再想请教一下关于gram loss的内部参数,原始代码为weights = [1.0 / 2.6, 1.0 / 4.8, 1.0 / 3.7, 1.0 / 5.6, 10.0 / 1.5],您改成了weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0],这个是有什么深意吗?以及想请教下,你是如何设置不同loss的权重超参,这个您方便公开具体参数吗?
@nnmaitian [1.0 / 2.6, 1.0 / 4.8, 1.0 / 3.7, 1.0 / 5.6, 10.0 / 1.5] 是从网上别人那里拿来的祖传代码:https://github.com/ceciliavision/perceptual-reflection-removal/blob/92e28441922a27697d57456bfef96dc2bc9e056a/main.py#L240 我试过,不加权效果会比较差。
我改成了一个看起来比较顺眼的样子,这部分我没有做过细致实验。我记得只要最深层的权重比较大、其它层小一点,效果都还可以。
OK 感谢解答
你好,可以麻烦提供一下4.17版本中,gram loss相关的loss代码吗