Kibeom-Hong / AesPA-Net

Official Implementation of AesPA-Net: Aesthetic Pattern-Aware Style Transfer Networks

PyTorch model instead of Torch7 compatibility fix #9

Open · FuouM opened 2 months ago

FuouM commented 2 months ago

I'm experimenting with your model, and I realized that the provided vgg_normalised_conv5_1.t7 can't be loaded with torchfile on Windows, and there are other compatibility issues with the latest versions of PyTorch. However, with a simple tweak and a change of model file, anyone can run the model:

We download the vgg_normalised_conv5_1.pth file from https://github.com/pietrocarbo/deep-transfer/tree/master/models/autoencoder_vgg19/vgg19_5 (pietrocarbo's deep-transfer).
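To sanity-check the download, it helps to print the checkpoint's keys and shapes and confirm the numeric layout assumed below (a minimal sketch, assuming the file is saved locally as vgg_normalised_conv5_1.pth):

import torch

# The converted checkpoint is a plain state_dict with numeric Sequential keys.
sd = torch.load("vgg_normalised_conv5_1.pth", map_location="cpu")
for key, tensor in sd.items():
    print(key, tuple(tensor.shape))
# Expect entries such as '0.weight' with shape (3, 3, 1, 1)
# and '2.weight' with shape (64, 3, 3, 3).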

We then modify the VGGEncoder class to load the weights from the state_dict directly:

import torch
import torch.nn as nn


class VGGEncoder(nn.Module):
    def __init__(self, vgg_state_dict):
        super(VGGEncoder, self).__init__()

        # Reflection padding is applied in forward(); the convs use padding=0.
        self.pad = nn.ReflectionPad2d(1)
        self.relu = nn.ReLU(inplace=True)
        self.pool = nn.AvgPool2d(2)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=False)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=False)
        self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=False)
        self.maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=False)

        # The numeric keys ('0.weight', '2.weight', ...) follow the
        # Sequential layout of the converted checkpoint.

        ###Level0###
        self.conv0 = nn.Conv2d(3, 3, 1, 1, 0)
        self.conv0.weight = nn.Parameter(vgg_state_dict['0.weight'])
        self.conv0.bias = nn.Parameter(vgg_state_dict['0.bias'])

        ###Level1###
        self.conv1_1 = nn.Conv2d(3, 64, 3, 1, 0)
        self.conv1_1.weight = nn.Parameter(vgg_state_dict['2.weight'])
        self.conv1_1.bias = nn.Parameter(vgg_state_dict['2.bias'])

        self.conv1_2 = nn.Conv2d(64, 64, 3, 1, 0)
        self.conv1_2.weight = nn.Parameter(vgg_state_dict['5.weight'])
        self.conv1_2.bias = nn.Parameter(vgg_state_dict['5.bias'])

        ###Level2###
        self.conv2_1 = nn.Conv2d(64, 128, 3, 1, 0)
        self.conv2_1.weight = nn.Parameter(vgg_state_dict['9.weight'])
        self.conv2_1.bias = nn.Parameter(vgg_state_dict['9.bias'])

        self.conv2_2 = nn.Conv2d(128, 128, 3, 1, 0)
        self.conv2_2.weight = nn.Parameter(vgg_state_dict['12.weight'])
        self.conv2_2.bias = nn.Parameter(vgg_state_dict['12.bias'])

        ###Level3###
        self.conv3_1 = nn.Conv2d(128, 256, 3, 1, 0)
        self.conv3_1.weight = nn.Parameter(vgg_state_dict['16.weight'])
        self.conv3_1.bias = nn.Parameter(vgg_state_dict['16.bias'])

        self.conv3_2 = nn.Conv2d(256, 256, 3, 1, 0)
        self.conv3_2.weight = nn.Parameter(vgg_state_dict['19.weight'])
        self.conv3_2.bias = nn.Parameter(vgg_state_dict['19.bias'])

        self.conv3_3 = nn.Conv2d(256, 256, 3, 1, 0)
        self.conv3_3.weight = nn.Parameter(vgg_state_dict['22.weight'])
        self.conv3_3.bias = nn.Parameter(vgg_state_dict['22.bias'])

        self.conv3_4 = nn.Conv2d(256, 256, 3, 1, 0)
        self.conv3_4.weight = nn.Parameter(vgg_state_dict['25.weight'])
        self.conv3_4.bias = nn.Parameter(vgg_state_dict['25.bias'])

        ###Level4###
        self.conv4_1 = nn.Conv2d(256, 512, 3, 1, 0)
        self.conv4_1.weight = nn.Parameter(vgg_state_dict['29.weight'])
        self.conv4_1.bias = nn.Parameter(vgg_state_dict['29.bias'])

        self.conv4_2 = nn.Conv2d(512, 512, 3, 1, 0)
        self.conv4_2.weight = nn.Parameter(vgg_state_dict['32.weight'])
        self.conv4_2.bias = nn.Parameter(vgg_state_dict['32.bias'])

        self.conv4_3 = nn.Conv2d(512, 512, 3, 1, 0)
        self.conv4_3.weight = nn.Parameter(vgg_state_dict['35.weight'])
        self.conv4_3.bias = nn.Parameter(vgg_state_dict['35.bias'])

        self.conv4_4 = nn.Conv2d(512, 512, 3, 1, 0)
        self.conv4_4.weight = nn.Parameter(vgg_state_dict['38.weight'])
        self.conv4_4.bias = nn.Parameter(vgg_state_dict['38.bias'])

        ###Level5###
        self.conv5_1 = nn.Conv2d(512, 512, 3, 1, 0)
        self.conv5_1.weight = nn.Parameter(vgg_state_dict['42.weight'])
        self.conv5_1.bias = nn.Parameter(vgg_state_dict['42.bias'])

    # forward() is unchanged from the original AesPA-Net implementation.
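As an aside, the repeated copy-in pattern can be collapsed with a small helper; this is only a sketch (conv_from_dict is a hypothetical name, not part of the repo), behaviorally equivalent to the explicit assignments above:

import torch.nn as nn

def conv_from_dict(sd, idx, in_ch, out_ch, kernel=3):
    # Build a Conv2d (padding=0) and overwrite its parameters with the
    # checkpoint tensors stored under the given numeric index.
    conv = nn.Conv2d(in_ch, out_ch, kernel, 1, 0)
    conv.weight = nn.Parameter(sd[f'{idx}.weight'])
    conv.bias = nn.Parameter(sd[f'{idx}.bias'])
    return conv

# e.g. inside __init__: self.conv1_1 = conv_from_dict(vgg_state_dict, 2, 3, 64)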

Now the model can be loaded with:

path = "vgg_normalised_conv5_1.pth"  # local path to the downloaded checkpoint
pretrained_vgg = torch.load(path, map_location="cpu")  # pin to CPU to avoid device-mismatch errors
encoder = VGGEncoder(pretrained_vgg)
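Since the VGG network only serves as a fixed feature extractor, it is also worth freezing it after loading (standard PyTorch practice, not something specific to this repo):

# Disable gradients and switch to inference behavior.
encoder.eval()
for p in encoder.parameters():
    p.requires_grad_(False)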