lucidrains / stylegan2-pytorch

Simplest working implementation of Stylegan2, state of the art generative adversarial network, in Pytorch. Enabling everyone to experience disentanglement
https://thispersondoesnotexist.com

Can you teach me what your code is doing? #169

Open lucasjinreal opened 3 years ago

lucasjinreal commented 3 years ago

Hi, I am new to StyleGAN. I want to write a simple inference demo that generates a single image from a random noise latent vector.

This is what I currently have:


"""

Generate random images by trained StyleGANV2
"""

from stylegan2_pytorch.stylegan2_simple import Generator, StyleGAN2
import torch
import numpy as np
from alfred.dl.torch.common import device
import cv2

# repeat each slice of `a` n_tile times along dimension `dim`
def tile(a, dim, n_tile):
    init_dim = a.size(dim)
    repeat_idx = [1] * a.dim()
    repeat_idx[dim] = n_tile
    a = a.repeat(*(repeat_idx))
    order_index = torch.LongTensor(np.concatenate(
        [init_dim * np.arange(n_tile) + i for i in range(init_dim)])).to(device)
    return torch.index_select(a, dim, order_index)

# run the model over `args` in chunks of at most max_batch_size and concatenate the outputs
def evaluate_in_chunks(max_batch_size, model, *args):
    split_args = list(
        zip(*list(map(lambda x: x.split(max_batch_size, dim=0), args))))
    chunked_outputs = [model(*i) for i in split_args]
    if len(chunked_outputs) == 1:
        return chunked_outputs[0]
    return torch.cat(chunked_outputs, dim=0)

# expand a list of (style, num_layers) pairs into a (batch, total_layers, latent_dim) tensor
def styles_def_to_tensor(styles_def):
    return torch.cat([t[:, None, :].expand(-1, n, -1) for t, n in styles_def], dim=1)

# truncation trick: pull the style vector toward the mean style of many random samples
def truncate_style(g, tensor, trunc_psi=0.75):
    S = g.S
    latent_dim = g.G.latent_dim

    z = torch.randn([2000, latent_dim]).to(device)
    samples = evaluate_in_chunks(1, S, z).cpu().numpy()
    av = np.mean(samples, axis=0)
    av = np.expand_dims(av, axis=0)

    av_torch = torch.from_numpy(av).to(device)
    tensor = trunc_psi * (tensor - av_torch) + av_torch
    return tensor

def truncate_style_defs(g, w, trunc_psi=0.75):
    w_space = []
    for tensor, num_layers in w:
        tensor = truncate_style(g, tensor, trunc_psi=trunc_psi)
        w_space.append((tensor, num_layers))
    return w_space

# map noise to styles with S, apply truncation, then run the generator G on the style tensor
def generate_truncated(g, S, G, style, noi, trunc_psi=0.75, num_image_tiles=8):
    w = map(lambda t: (S(t[0]), t[1]), style)
    w_truncated = truncate_style_defs(g, w, trunc_psi=trunc_psi)
    w_styles = styles_def_to_tensor(w_truncated)
    generated_images = evaluate_in_chunks(1, G, w_styles, noi)
    return generated_images.clamp_(0., 1.)

if __name__ == "__main__":
    weight_path = 'models/default/model_50.pt'

    image_size = 512
    latent_dim = 256
    network_capacity = 16
    transparent = False
    attn_layers = []
    no_const = False
    model = StyleGAN2(image_size, latent_dim=latent_dim, network_capacity=network_capacity,
                      transparent=transparent, attn_layers=attn_layers, no_const=no_const)
    model.eval()
    ckpt = torch.load(weight_path)
    model.load_state_dict(ckpt['GAN'])
    model.to(device)

    print('generator loaded.')

    num_layers = model.GAN.num_layers
    nn = torch.randn([1, latent_dim]).to(device)
    tmp1 = tile(nn, 0, 1)
    tmp2 = nn.repeat(1, 1)

    tt = int(num_layers / 2)
    mixed_latents = [(tmp1, tt), (tmp2, num_layers - tt)]

    generated_images = generate_truncated(model,
                                          model.SE, model.GAN.GE, mixed_latents, 1)

I am running the default model with image size 512, but I got a shape mismatch error:

Error: Error(s) in loading state_dict for StyleGAN2:
    Missing key(s) in state_dict: "G.blocks.5.to_rgb.upsample.1.f", "G.blocks.6.to_style1.weight", "G.blocks.6.to_style1.bias", "G.blocks.6.to_noise1.weight", "G.blocks.6.to_noise1.bias", "G.blocks.6.conv1.weight", "G.blocks.6.to_style2.weight", "G.blocks.6.to_style2.bias", "G.blocks.6.to_noise2.weight", "G.blocks.6.to_noise2.bias", "G.blocks.6.conv2.weight", "G.blocks.6.to_rgb.to_style.weight", "G.blocks.6.to_rgb.to_style.bias", "G.blocks.6.to_rgb.conv.weight", "D.blocks.6.downsample.0.f", "D.blocks.6.downsample.1.weight", "D.blocks.6.downsample.1.bias", "D.blocks.7.conv_res.weight", "D.blocks.7.conv_res.bias", "D.blocks.7.net.0.weight", "D.blocks.7.net.0.bias", "D.blocks.7.net.2.weight", "D.blocks.7.net.2.bias", "GE.blocks.5.to_rgb.upsample.1.f", "GE.blocks.6.to_style1.weight", "GE.blocks.6.to_style1.bias", "GE.blocks.6.to_noise1.weight", "GE.blocks.6.to_noise1.bias", "GE.blocks.6.conv1.weight", "GE.blocks.6.to_style2.weight", "GE.blocks.6.to_style2.bias", "GE.blocks.6.to_noise2.weight", "GE.blocks.6.to_noise2.bias", "GE.blocks.6.conv2.weight", "GE.blocks.6.to_rgb.to_style.weight", "GE.blocks.6.to_rgb.to_style.bias", "GE.blocks.6.to_rgb.conv.weight", "D_aug.D.blocks.6.downsample.0.f", "D_aug.D.blocks.6.downsample.1.weight", "D_aug.D.blocks.6.downsample.1.bias", "D_aug.D.blocks.7.conv_res.weight", "D_aug.D.blocks.7.conv_res.bias", "D_aug.D.blocks.7.net.0.weight", "D_aug.D.blocks.7.net.0.bias", "D_aug.D.blocks.7.net.2.weight", "D_aug.D.blocks.7.net.2.bias". 
    size mismatch for S.net.0.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([256, 256]).
    size mismatch for S.net.0.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
    size mismatch for S.net.2.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([256, 256]).
    size mismatch for S.net.2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
    size mismatch for S.net.4.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([256, 256]).
    size mismatch for S.net.4.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
    size mismatch for S.net.6.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([256, 256]).
    size mismatch for S.net.6.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
    size mismatch for S.net.8.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([256, 256]).
    size mismatch for S.net.8.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
    size mismatch for S.net.10.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([256, 256]).

I tried setting the latent dim to both 512 and 256, but neither configuration would load the trained model.

If you could have a look at my code, it would be very much appreciated!

lucidrains commented 3 years ago

It's just https://github.com/lucidrains/stylegan2-pytorch#coding
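
That section boils down to roughly the following (paraphrased from the linked Coding section of the README; argument names may differ slightly):

# load a trained project by name with ModelLoader and sample from it
import torch
from torchvision.utils import save_image
from stylegan2_pytorch import ModelLoader

loader = ModelLoader(
    base_dir = '.',       # directory you ran the training command from
    name = 'default'      # the project name used during training
)

noise  = torch.randn(1, 512).cuda()                       # z ~ N(0, 1)
styles = loader.noise_to_styles(noise, trunc_psi = 0.75)  # map z -> w through the style network, with truncation
images = loader.styles_to_images(styles)                  # run the generator on the style vectors

save_image(images, './sample.jpg')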

lucasjinreal commented 3 years ago

@lucidrains Hi, thanks for your reply. That code is very clean and simple, but it hides too much. I want to understand what happens inside.

I found a weird issue on my side: it looks like I somehow trained a model with a latent dim of 256, because the error message told me:

    size mismatch for G.blocks.2.to_noise1.weight: copying a param with shape torch.Size([256, 1]) from checkpoint, the shape in current model is torch.Size([512, 1]).

The G here should be this model:

self.G = Generator(image_size, latent_dim, network_capacity, transparent=transparent,
                   attn_layers=attn_layers, no_const=no_const, fmap_max=fmap_max)

But when I set latent_dim to 256, I got a shape mismatch in the S model, which is the StyleVectorizer:

self.S = StyleVectorizer(latent_dim, style_depth, lr_mul=lr_mlp)

With latent_dim set to 256 I got:

    size mismatch for S.net.0.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([256, 256]).

This time the mismatch comes from the S model.

So, here is the problem:

The Generator and the StyleVectorizer were saved with inconsistent shapes: the Generator side looks like 256 while S looks like 512. How is that possible?

From the code you defined:

self.S = StyleVectorizer(latent_dim, style_depth, lr_mul=lr_mlp)
self.G = Generator(image_size, latent_dim, network_capacity, transparent=transparent,
                   attn_layers=attn_layers, no_const=no_const, fmap_max=fmap_max)

They should both be built from the same latent_dim, shouldn't they?
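
One way to settle this is to print every shape stored in the checkpoint and read the configuration off it directly. A minimal sketch along those lines, reusing the same checkpoint path and 'GAN' key as in my script above:

# dump every parameter name and shape saved in the checkpoint
import torch

ckpt = torch.load('models/default/model_50.pt', map_location='cpu')
for name, param in ckpt['GAN'].items():
    print(name, tuple(param.shape))

# the first mapping-network layer S.net.0.weight is (latent_dim, latent_dim),
# so its stored shape should reveal which latent_dim the checkpoint was trained with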