Open jly0810 opened 9 months ago
A demo
import torch
import torch.nn as nn
class VGG19(nn.Module):
    """VGG-19 style encoder truncated after ``relu5_1``, used as a feature extractor.

    The module is built from a stage table instead of being spelled out line by
    line, but every sub-module keeps its original attribute name
    (``conv0``, ``pad1_1``, ``conv1_1``, ``relu1_1``, ..., ``maxPool4``,
    ``conv5_1``, ``relu5_1``). Those names are the ``state_dict`` keys, so a
    checkpoint saved for the hand-written version still loads. Sub-modules are
    registered in the same order as before, so parameter iteration order and
    default weight initialisation under a fixed seed are unchanged.

    ``forward`` returns a list of five feature maps: the output of the first
    ReLU of each stage (``relu1_1``, ``relu2_1``, ``relu3_1``, ``relu4_1`` and
    ``relu5_1``).
    """

    # One row per stage:
    # (stage index, number of 3x3 convs, input channels, output channels,
    #  attribute name of the 2x2 max-pool closing the stage, or None).
    # NOTE: "maxPool4" keeps the original capitalisation so the attribute name
    # does not change for external code.
    _STAGES = (
        (1, 2, 3, 64, "maxpool1"),
        (2, 2, 64, 128, "maxpool2"),
        (3, 4, 128, 256, "maxpool3"),
        (4, 4, 256, 512, "maxPool4"),
        (5, 1, 512, 512, None),
    )

    def __init__(self):
        super().__init__()
        # 1x1 convolution applied to the raw 3-channel input before stage 1.
        self.conv0 = nn.Conv2d(3, 3, 1, 1, 0)
        for stage, depth, first_in, width, pool_attr in self._STAGES:
            for idx in range(1, depth + 1):
                tag = f"{stage}_{idx}"
                in_channels = first_in if idx == 1 else width
                # Reflection padding of 1 followed by an unpadded 3x3 conv
                # preserves the spatial size.
                setattr(self, "pad" + tag, nn.ReflectionPad2d((1, 1, 1, 1)))
                setattr(self, "conv" + tag, nn.Conv2d(in_channels, width, 3, 1, 0))
                setattr(self, "relu" + tag, nn.ReLU(inplace=True))
            if pool_attr is not None:
                # Halves H and W; the returned pooling indices are discarded in forward.
                setattr(
                    self,
                    pool_attr,
                    nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True),
                )

    def forward(self, x):
        features = []
        h = self.conv0(x)
        for stage, depth, _first_in, _width, pool_attr in self._STAGES:
            for idx in range(1, depth + 1):
                tag = f"{stage}_{idx}"
                h = getattr(self, "pad" + tag)(h)
                h = getattr(self, "conv" + tag)(h)
                h = getattr(self, "relu" + tag)(h)
                if idx == 1:
                    # Tap the first ReLU of every stage. The next layer starts
                    # with a padding op that allocates a new tensor, so the
                    # in-place ReLUs never overwrite a stored tap.
                    features.append(h)
            if pool_attr is not None:
                h, _indices = getattr(self, pool_attr)(h)
        return features
def _gram_matrix(feat):
    """Return the ``(B, C, C)`` Gram matrix of a ``(B, C, H, W)`` feature map, divided by ``H * W``."""
    b, c, h, w = feat.size()
    flat = feat.reshape(b, c, h * w)
    return torch.matmul(flat, flat.transpose(1, 2)) / (h * w)


def compute_style_loss(model, img, style):
    """Return the Gram-matrix (style) loss between ``img`` and ``style``.

    ``model`` is called once on each image and must return a sequence of 4-D
    feature maps ``(B, C, H, W)``. Both calls must return the same number of
    feature levels. For each level:

    1. The Gram matrix ``F @ F^T / (H * W)`` is built from the flattened
       features ``F`` of shape ``(B, C, H * W)``.
    2. The mean squared difference between the two Gram matrices is taken.

    The per-level losses are then averaged with equal weight.

    Every level returned by ``model`` is used. With the five-level ``VGG19``
    encoder this matches the previous hard-coded ``range(5)`` loop.

    Note on scale: each Gram matrix is divided by ``H * W`` only, not by the
    channel count ``C``. Conventions that divide by ``C * H * W`` give values
    smaller by a factor of ``C ** 2`` per level. Absolute numbers are therefore
    only comparable between implementations that use the same normalisation.

    Args:
        model: callable mapping an image tensor to a sequence of feature maps.
        img: image tensor passed to ``model``.
        style: style image tensor passed to ``model``. Its spatial size may
            differ from ``img`` because the Gram matrices are ``C x C``.

    Returns:
        A 0-dim tensor: the style loss averaged over all feature levels.

    Raises:
        ValueError: if ``model`` returns no feature maps, or a different number
            of levels for ``img`` and ``style``.
    """
    img_feats = model(img)
    style_feats = model(style)
    if len(img_feats) != len(style_feats):
        raise ValueError(
            f"model returned {len(img_feats)} feature levels for img but "
            f"{len(style_feats)} for style"
        )
    if len(img_feats) == 0:
        raise ValueError("model returned no feature maps")
    per_level = [
        torch.mean((_gram_matrix(a) - _gram_matrix(b)) ** 2)  # MSE between Gram matrices
        for a, b in zip(img_feats, style_feats)
    ]
    return sum(per_level) / len(per_level)
# Demo: compute the style loss between two random images with the pretrained encoder.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vgg19 = VGG19()
# Checkpoint is part of vgg_normalised.pth. Link: https://drive.google.com/file/d/1yXZwbbJCEamrPFSTKz-2PND3uSayhNAA/view?usp=drive_link
# Load the tensors onto the CPU first, so a checkpoint saved from a GPU also
# loads on CPU-only machines; the model is moved to `device` below.
# NOTE(review): torch.load unpickles the file. Only load checkpoints from a
# trusted source, and consider weights_only=True on PyTorch versions that support it.
vgg19.load_state_dict(torch.load("vgg_normalised_conv5_1.pth", map_location="cpu"))
vgg19.to(device)
vgg19.eval()
# Freeze the encoder: it is only used as a fixed feature extractor.
for param in vgg19.parameters():
    param.requires_grad = False
# The two images deliberately have different sizes; Gram matrices are C x C,
# so the loss is still defined.
x = torch.rand(1, 3, 256, 256).to(device)
y = torch.rand(1, 3, 128, 128).to(device)
ls = compute_style_loss(vgg19, x, y)
print(ls)
Hello, I am very interested in your work! I see that one of your quantitative metrics is the Gram (style) loss. Could you share the code you used to calculate it? I based my implementation on https://github.com/ProGamerGov/neural-style-pt/tree/master, but the values I get are very large. I hope you can help me. Thank you!