linfengWen98 / CAP-VSTNet

[CVPR 2023] CAP-VSTNet: Content Affinity Preserved Versatile Style Transfer
MIT License
130 stars 8 forks source link

Gram-loss #16

Open jly0810 opened 9 months ago

jly0810 commented 9 months ago

Hello, I am very interested in your work! I see that one of your quantitative indicators is Gram loss. Could you share your calculation code? I referred to https://github.com/ProGamerGov/neural-style-pt/tree/master, but the value I calculated is particularly large. I hope you can help me — thank you!

linfengWen98 commented 8 months ago

A demo

import torch
import torch.nn as nn

class VGG19(nn.Module):
    def __init__(self):
        super(VGG19, self).__init__()
        self.conv0 = nn.Conv2d(3, 3, 1, 1, 0)
        self.pad1_1 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv1_1 = nn.Conv2d(3, 64, 3, 1, 0)
        self.relu1_1 = nn.ReLU(inplace=True)
        # 224 x 224

        self.pad1_2 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv1_2 = nn.Conv2d(64, 64, 3, 1, 0)
        self.relu1_2 = nn.ReLU(inplace=True)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
        # 112 x 112

        self.pad2_1 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv2_1 = nn.Conv2d(64, 128, 3, 1, 0)
        self.relu2_1 = nn.ReLU(inplace=True)
        self.pad2_2 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv2_2 = nn.Conv2d(128, 128, 3, 1, 0)
        self.relu2_2 = nn.ReLU(inplace=True)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
        # 56 x 56

        self.pad3_1 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv3_1 = nn.Conv2d(128, 256, 3, 1, 0)
        self.relu3_1 = nn.ReLU(inplace=True)
        self.pad3_2 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv3_2 = nn.Conv2d(256, 256, 3, 1, 0)
        self.relu3_2 = nn.ReLU(inplace=True)
        self.pad3_3 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv3_3 = nn.Conv2d(256, 256, 3, 1, 0)
        self.relu3_3 = nn.ReLU(inplace=True)
        self.pad3_4 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv3_4 = nn.Conv2d(256, 256, 3, 1, 0)
        self.relu3_4 = nn.ReLU(inplace=True)
        self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
        # 28 x 28

        self.pad4_1 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv4_1 = nn.Conv2d(256, 512, 3, 1, 0)
        self.relu4_1 = nn.ReLU(inplace=True)
        self.pad4_2 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv4_2 = nn.Conv2d(512, 512, 3, 1, 0)
        self.relu4_2 = nn.ReLU(inplace=True)
        self.pad4_3 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv4_3 = nn.Conv2d(512, 512, 3, 1, 0)
        self.relu4_3 = nn.ReLU(inplace=True)
        self.pad4_4 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv4_4 = nn.Conv2d(512, 512, 3, 1, 0)
        self.relu4_4 = nn.ReLU(inplace=True)
        self.maxPool4 = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
        # 14 x 14

        self.pad5_1 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv5_1 = nn.Conv2d(512, 512, 3, 1, 0)
        self.relu5_1 = nn.ReLU(inplace=True)
        # 14 x 14

    def forward(self, x):
        out = self.conv0(x)

        out = self.pad1_1(out)
        out = self.conv1_1(out)
        out = self.relu1_1(out)

        out1 = out

        out = self.pad1_2(out)
        out = self.conv1_2(out)
        pool1 = self.relu1_2(out)

        out, _ = self.maxpool1(pool1)

        out = self.pad2_1(out)
        out = self.conv2_1(out)
        out = self.relu2_1(out)

        out2 = out

        out = self.pad2_2(out)
        out = self.conv2_2(out)
        pool2 = self.relu2_2(out)

        out, _ = self.maxpool2(pool2)

        out = self.pad3_1(out)
        out = self.conv3_1(out)
        out = self.relu3_1(out)

        out3 = out

        out = self.pad3_2(out)
        out = self.conv3_2(out)
        out = self.relu3_2(out)

        out = self.pad3_3(out)
        out = self.conv3_3(out)
        out = self.relu3_3(out)

        out = self.pad3_4(out)
        out = self.conv3_4(out)
        pool3 = self.relu3_4(out)

        out, _ = self.maxpool3(pool3)

        out = self.pad4_1(out)
        out = self.conv4_1(out)
        out = self.relu4_1(out)

        out4 = out

        out = self.pad4_2(out)
        out = self.conv4_2(out)
        out = self.relu4_2(out)

        out = self.pad4_3(out)
        out = self.conv4_3(out)
        out = self.relu4_3(out)

        out = self.pad4_4(out)
        out = self.conv4_4(out)
        pool4 = self.relu4_4(out)

        out, _ = self.maxPool4(pool4)

        out = self.pad5_1(out)
        out = self.conv5_1(out)
        out = self.relu5_1(out)

        return [out1, out2, out3, out4, out]

def compute_style_loss(model, img, style):
    """Mean Gram-matrix MSE ("Gram loss") between two images' features.

    Args:
        model: callable mapping an image batch to a list of feature maps,
            each of shape (B, C, H, W).
        img: stylised/content image batch fed to ``model``.
        style: style image batch fed to ``model``. Spatial sizes may differ
            from ``img``'s, but per-layer channel counts must match so the
            Gram matrices have the same shape.

    Returns:
        Scalar tensor: per-layer Gram MSE averaged over all feature layers.
    """
    img_fea = model(img)
    sty_fea = model(style)

    def _gram(fea):
        # (B, C, H*W) @ (B, H*W, C) -> (B, C, C), normalised by the number
        # of spatial positions so the statistic is resolution-independent.
        b, c, h, w = fea.size()
        flat = fea.reshape(b, c, -1)
        return torch.matmul(flat, flat.transpose(1, 2)) / (h * w)

    # Generalised from the original hard-coded `range(5)`: works for any
    # number of feature layers the model returns.
    layer_losses = [torch.mean((_gram(fa) - _gram(fb)) ** 2)
                    for fa, fb in zip(img_fea, sty_fea)]
    return sum(layer_losses) / len(layer_losses)

# --- Demo: Gram loss between two random images -------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

vgg19 = VGG19()
# part of vgg_normalised.pth
# Link: https://drive.google.com/file/d/1yXZwbbJCEamrPFSTKz-2PND3uSayhNAA/view?usp=drive_link
state = torch.load("vgg_normalised_conv5_1.pth")
vgg19.load_state_dict(state)
vgg19.to(device)
vgg19.eval()
# Feature extractor only — freeze every parameter.
vgg19.requires_grad_(False)

x = torch.rand(1, 3, 256, 256).to(device)
y = torch.rand(1, 3, 128, 128).to(device)
ls = compute_style_loss(vgg19, x, y)
print(ls)