diyiiyiii / StyTR-2

StyTr2: Image Style Transfer with Transformers
352 stars 64 forks source link

About the metric scores of StyTr2: Image Style Transfer with Transformers #26

Open Meimeiainaonao opened 1 year ago

Meimeiainaonao commented 1 year ago

Thanks for sharing your code. I think it's wonderful work!

I have a question about the content loss score. I applied StyTr2, using your pre-trained model, to a dataset of 800 images. To match your test settings, I resized all images to 256x256 before calculating the content loss. However, the content loss values I get differ significantly from those reported in your paper.

I understand that scores are expected to vary when different images are used. Nonetheless, while my style loss scores follow a trend similar to yours, my content loss scores show noticeable discrepancies. Could you explain how you calculate the content loss? Would it be possible to share your metric code, or to point out where I am going wrong?


    #!/usr/bin/env python3
    import argparse
    import os
    import torch
    import torch.nn as nn
    from tqdm import tqdm
    import cv2

    # Command-line options for the metric script.
    parser = argparse.ArgumentParser()
    parser.add_argument("--resize", type=int, default=256,
                        help="side length in pixels that stylized and content images are resized to")
    parser.add_argument("--content_dir", default=r'\input\content', help="the directory of content images")
    parser.add_argument("--style_dir", default=r'\input\style', help="the directory of style images")
    parser.add_argument("--stylized_dir", default=r'\StyTR-2-main\output', required=False,
                        help="the directory of stylized images")
    # NOTE(review): this default looks truncated — confirm the intended log directory.
    parser.add_argument("--log_path", default=r't\metrics', required=False,
                        help="the directory for the metric log file")
    # Restrict to the documented modes so a typo fails loudly instead of computing nothing.
    parser.add_argument('--mode', type=int, default=1, choices=(0, 1, 2),
                        help="0 for style loss, 1 for content loss, 2 for both")
    args = parser.parse_args()

    # Use CUDA when it is available, otherwise run on the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # VGG feature extractor (VGG-19 conv layout: 2-2-4-4-4 convs, each preceded by
    # reflection padding) whose weights are loaded from vgg_normalised.pth below.
    # Do not reorder, insert, or remove layers: nn.Sequential names its children by
    # index, so the state_dict keys ("0.weight", "2.weight", ...) depend on this order.
    vgg = nn.Sequential(
        # NOTE(review): 1x1 3->3 conv with no padding; presumably applies the input
        # normalisation baked into the "normalised" checkpoint — confirm.
        nn.Conv2d(3, 3, (1, 1)),
        nn.ReflectionPad2d((1, 1, 1, 1)),
        nn.Conv2d(3, 64, (3, 3)),
        nn.ReLU(),  # relu1-1
        nn.ReflectionPad2d((1, 1, 1, 1)),
        nn.Conv2d(64, 64, (3, 3)),
        nn.ReLU(),  # relu1-2
        nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
        nn.ReflectionPad2d((1, 1, 1, 1)),
        nn.Conv2d(64, 128, (3, 3)),
        nn.ReLU(),  # relu2-1
        nn.ReflectionPad2d((1, 1, 1, 1)),
        nn.Conv2d(128, 128, (3, 3)),
        nn.ReLU(),  # relu2-2
        nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
        nn.ReflectionPad2d((1, 1, 1, 1)),
        nn.Conv2d(128, 256, (3, 3)),
        nn.ReLU(),  # relu3-1
        nn.ReflectionPad2d((1, 1, 1, 1)),
        nn.Conv2d(256, 256, (3, 3)),
        nn.ReLU(),  # relu3-2
        nn.ReflectionPad2d((1, 1, 1, 1)),
        nn.Conv2d(256, 256, (3, 3)),
        nn.ReLU(),  # relu3-3
        nn.ReflectionPad2d((1, 1, 1, 1)),
        nn.Conv2d(256, 256, (3, 3)),
        nn.ReLU(),  # relu3-4
        nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
        nn.ReflectionPad2d((1, 1, 1, 1)),
        nn.Conv2d(256, 512, (3, 3)),
        nn.ReLU(),  # relu4-1 (end of enc_4; enc_5 continues up to relu5-1)
        nn.ReflectionPad2d((1, 1, 1, 1)),
        nn.Conv2d(512, 512, (3, 3)),
        nn.ReLU(),  # relu4-2
        nn.ReflectionPad2d((1, 1, 1, 1)),
        nn.Conv2d(512, 512, (3, 3)),
        nn.ReLU(),  # relu4-3
        nn.ReflectionPad2d((1, 1, 1, 1)),
        nn.Conv2d(512, 512, (3, 3)),
        nn.ReLU(),  # relu4-4
        nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
        nn.ReflectionPad2d((1, 1, 1, 1)),
        nn.Conv2d(512, 512, (3, 3)),
        nn.ReLU(),  # relu5-1 (index 43, the last layer any encoder slice uses)
        # The layers below are not used by the encoders, but they must exist so that
        # load_state_dict (strict by default) finds every key in the checkpoint.
        nn.ReflectionPad2d((1, 1, 1, 1)),
        nn.Conv2d(512, 512, (3, 3)),
        nn.ReLU(),  # relu5-2
        nn.ReflectionPad2d((1, 1, 1, 1)),
        nn.Conv2d(512, 512, (3, 3)),
        nn.ReLU(),  # relu5-3
        nn.ReflectionPad2d((1, 1, 1, 1)),
        nn.Conv2d(512, 512, (3, 3)),
        nn.ReLU()  # relu5-4
    )

    vgg.eval()
    # Relative path: resolved against the current working directory, not this script's location.
    vgg.load_state_dict(torch.load("../models/vgg_normalised.pth"))

    # Cut the VGG trunk into five consecutive stages, each ending at a reluN_1
    # activation: chaining enc_1..enc_k yields the relu{k}_1 features.
    #   enc_1: input   -> relu1_1    enc_2: relu1_1 -> relu2_1
    #   enc_3: relu2_1 -> relu3_1    enc_4: relu3_1 -> relu4_1
    #   enc_5: relu4_1 -> relu5_1
    # The stages share their layers with `vgg`. Module.to() moves the parameters
    # in place and returns the module itself, so each stage lands on `device`.
    enc_1, enc_2, enc_3, enc_4, enc_5 = (
        nn.Sequential(*list(vgg.children())[start:stop]).to(device)
        for start, stop in ((0, 4), (4, 11), (11, 18), (18, 31), (31, 44))
    )

    def calc_content_loss(input, target):
        """Return the mean squared error between two equally shaped tensors.

        Args:
            input: tensor to compare (the stylized image's features in this script).
            target: tensor of the same shape (the content image's features).

        Returns:
            A 0-dim tensor: the squared element-wise difference averaged over all entries.

        Raises:
            AssertionError: if the two tensors do not have the same shape.
        """
        assert input.shape == target.shape
        return nn.functional.mse_loss(input, target, reduction="mean")

    content_dir = args.content_dir
    style_dir = args.style_dir
    stylized_dir = args.stylized_dir
    log_dir = args.log_path

    # os.listdir returns entries in arbitrary order; sort so images are processed
    # and reported deterministically.
    stylized_files = sorted(os.listdir(stylized_dir))
    # The log file is named after the parent of the stylized-image directory.
    # abspath() drops a trailing separator (which would otherwise make the last
    # component empty) and always yields at least two components, so [-2] cannot
    # raise IndexError for a bare directory name such as "output".
    folder_components = os.path.abspath(stylized_dir).split(os.path.sep)
    name = folder_components[-2]
    sub_name = folder_components[-1]
    log_path = os.path.join(args.log_path, name + '_log.txt')

    def _load_image_tensor(path, size):
        """Load an image file as a (1, 3, size, size) float tensor in [0, 1] on ``device``.

        cv2.imread decodes to BGR channel order, so the array is converted to RGB
        before the tensor is built.
        NOTE(review): assumes the VGG weights expect RGB input (as PIL-loaded images
        would be) — confirm against the pipeline that produced the reference numbers.

        Returns:
            The image tensor, or None (after printing a message) if the file cannot be read.
        """
        img = cv2.imread(path)
        if img is None or img.size == 0:
            print('Failed to load the image:', path)
            return None
        img = cv2.resize(img, (size, size))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        tensor = torch.tensor(img, dtype=torch.float) / 255
        # HWC -> CHW, then add a batch dimension of 1.
        return tensor.permute(2, 0, 1).unsqueeze(0).to(device)

    with torch.no_grad():
        if args.mode in (1, 2):
            loss_c_sum = 0.
            count = 0

            for stylized in tqdm(stylized_files):
                # The part of the file name before "_stylized_" is the content image's stem.
                content_stem = stylized.split("_stylized_")[0]
                stylized_img = _load_image_tensor(stylized_dir + os.sep + stylized, args.resize)
                content_img = _load_image_tensor(content_dir + os.sep + content_stem + '.jpg', args.resize)
                if stylized_img is None or content_img is None:
                    continue  # skip the pair rather than crashing on a missing image

                # relu4_1 features, then relu5_1 features computed on top of them.
                o1 = enc_4(enc_3(enc_2(enc_1(stylized_img))))
                c1 = enc_4(enc_3(enc_2(enc_1(content_img))))
                o2 = enc_5(o1)
                c2 = enc_5(c1)
                # NOTE(review): features are compared raw (no normalisation); if the
                # reference metric normalises features first, the scores will differ.
                loss_c = calc_content_loss(o1, c1) + calc_content_loss(o2, c2)

                loss = float(loss_c) / 2  # average over the two feature levels
                print("Content Loss: {}".format(loss))
                loss_c_sum += loss
                count += 1

            print("Total num: {}".format(count))
            if count:
                print("Average Content Loss: {}".format(loss_c_sum / count))
            else:
                print("No stylized/content image pairs could be evaluated.")