Yijunmaverick / UniversalStyleTransfer

The source code of NIPS17 'Universal Style Transfer via Feature Transforms'.

Questions about the VGG16 net #13

Open · FantasyJXF opened this issue 5 years ago

FantasyJXF commented 5 years ago

I used the torch.utils.serialization.load_lua Python package to load the vgg_normalised_*.t7 models; for example, vgg_normalised_conv1_1.t7 loads as the following net:

In [7]: vgg1 = load_lua('models/vgg_normalised_conv1_1.t7')

In [8]: vgg1
Out[8]:
nn.Sequential {
  [input -> (0) -> (1) -> (2) -> (3) -> output]
  (0): nn.SpatialConvolution(3 -> 3, 1x1)
  (1): nn.SpatialReflectionPadding(1, 1, 1, 1)
  (2): nn.SpatialConvolution(3 -> 64, 3x3)
  (3): nn.ReLU
}
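
Side note: torch.utils.serialization was removed in PyTorch 1.0, so on newer versions the .t7 file has to be read another way. A minimal sketch using the third-party torchfile package (an assumption about your environment; its objects expose a .modules list rather than .get()):

import torchfile

# Hypothetical alternative loader for newer PyTorch versions;
# torchfile returns a plain object tree, not nn modules.
vgg1 = torchfile.load('models/vgg_normalised_conv1_1.t7')
first_conv = vgg1.modules[0]      # the 3 -> 3, 1x1 convolution
print(first_conv.weight.shape)    # raw weights as a numpy array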

The encoder code is as follows:

import torch
import torch.nn as nn

class encoder1(nn.Module):
    def __init__(self, vgg1):
        super(encoder1, self).__init__()
        # disassemble vgg1 layer by layer,
        # then reassemble it as a new encoder
        # input: 224 x 224
        self.conv1 = nn.Conv2d(3, 3, 1, 1, 0)
        self.conv1.weight = torch.nn.Parameter(vgg1.get(0).weight.float())
        self.conv1.bias = torch.nn.Parameter(vgg1.get(0).bias.float())
        # 224 x 224
        self.reflecPad1 = nn.ReflectionPad2d((1, 1, 1, 1))
        # 226 x 226
        self.conv2 = nn.Conv2d(3, 64, 3, 1, 0)
        self.conv2.weight = torch.nn.Parameter(vgg1.get(2).weight.float())
        self.conv2.bias = torch.nn.Parameter(vgg1.get(2).bias.float())
        self.relu = nn.ReLU(inplace=True)
        # back to 224 x 224

    def forward(self, x):
        out = self.conv1(x)
        out = self.reflecPad1(out)
        out = self.conv2(out)
        out = self.relu(out)
        return out
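
A quick shape check (a sketch, assuming vgg1 is loaded as above) shows that the reflection padding and the 3x3 convolution cancel out, so the spatial size is preserved:

enc1 = encoder1(vgg1)
x = torch.randn(1, 3, 224, 224)   # dummy RGB image
feat = enc1(x)
print(feat.shape)                 # torch.Size([1, 64, 224, 224]): pad to 226, conv back to 224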

My question is: the original VGG16 has no 1x1 convolution and no padding-to-226x226 step, but this code appears to do both. Did I misunderstand the net?

One more thing: as you have said, the decoders are not perfect, so the generated picture is not very precise. As for the artistic quality, I think using decoders 1-4 is better than using all 5.

This is with decoders 1-5: [image]

This is with decoders 1-4: [image]

I think the bottom one looks better.

Looking forward to your reply.

FantasyJXF commented 5 years ago

Another question is about the encoders & decoders.

The deeper the conv layer, the harder it is to reconstruct the input image.

I conducted the following test:

def identity_map(img):
    en5 = wct.e5(img)
    de5 = wct.d5(en5)
    vutils.save_image(de5.data.cpu().float(),os.path.join(args.outf,'de5.jpg'))

    en4 = wct.e4(img)
    de4 = wct.d4(en4)
    vutils.save_image(de4.data.cpu().float(),os.path.join(args.outf,'de4.jpg'))

    en3 = wct.e3(img)
    de3 = wct.d3(en3)
    vutils.save_image(de3.data.cpu().float(), os.path.join(args.outf,'de3.jpg'))

    en2 = wct.e2(img)
    de2 = wct.d2(en2)
    vutils.save_image(de2.data.cpu().float(), os.path.join(args.outf, 'de2.jpg'))

    en1 = wct.e1(img)
    de1 = wct.d1(en1)
    vutils.save_image(de1.data.cpu().float(), os.path.join(args.outf, 'de1.jpg'))

The result is:

[Screenshot 2019-05-17, 6:10 PM: per-level reconstruction results]
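
To put numbers on this rather than comparing screenshots, the per-level reconstruction error could be measured directly (a sketch assuming the same wct object and its e1..e5 / d1..d5 attributes):

import torch.nn.functional as F

def reconstruction_errors(img):
    # MSE between the input and its reconstruction at each level;
    # if the claim holds, the error should grow with encoder depth
    errors = {}
    for level in range(1, 6):
        enc = getattr(wct, 'e%d' % level)
        dec = getattr(wct, 'd%d' % level)
        errors[level] = F.mse_loss(dec(enc(img)), img).item()
    return errors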

I also tried chaining the levels, feeding each level's reconstruction into the next encoder:

def identity_map(img):
    en5 = wct.e5(img)
    de5 = wct.d5(en5)
    vutils.save_image(de5.data.cpu().float(),os.path.join(args.outf,'de5.jpg'))

    #de5 = de5.data.cpu().squeeze(0)
    en4 = wct.e4(de5)
    de4 = wct.d4(en4)
    vutils.save_image(de4.data.cpu().float(),os.path.join(args.outf,'de4.jpg'))

    #de4 = de4.data.cpu().squeeze(0)
    en3 = wct.e3(de4)
    de3 = wct.d3(en3)
    vutils.save_image(de3.data.cpu().float(), os.path.join(args.outf,'de3.jpg'))

    #de3 = de3.data.cpu().squeeze()
    en2 = wct.e2(de3)
    de2 = wct.d2(en2)
    vutils.save_image(de2.data.cpu().float(), os.path.join(args.outf, 'de2.jpg'))

    #de2 = de2.data.cpu().squeeze()
    en1 = wct.e1(de2)
    de1 = wct.d1(en1)
    vutils.save_image(de1.data.cpu().float(), os.path.join(args.outf, 'de1.jpg'))

[Screenshot 2019-05-17, 6:24 PM: chained reconstruction results]

Maybe that's one of the reasons leading to this result. Do you have any idea how to train the decoders?
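
For reference, the paper trains each decoder to invert a fixed VGG encoder using a pixel reconstruction loss plus a feature reconstruction loss. A minimal training sketch under that objective (encoder, decoder, loader, and lam are illustrative names, not the repo's API):

import torch
import torch.nn.functional as F

def train_decoder(encoder, decoder, loader, lam=1.0, lr=1e-4, epochs=1):
    for p in encoder.parameters():      # the VGG encoder stays frozen
        p.requires_grad_(False)
    opt = torch.optim.Adam(decoder.parameters(), lr=lr)
    for _ in range(epochs):
        for img in loader:              # batches of content images
            feat = encoder(img)
            recon = decoder(feat)
            # pixel reconstruction + feature reconstruction, as in the paper
            loss = F.mse_loss(recon, img) + lam * F.mse_loss(encoder(recon), feat)
            opt.zero_grad()
            loss.backward()
            opt.step()
    return decoder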