lucidrains / stylegan2-pytorch

Simplest working implementation of StyleGAN2, a state-of-the-art generative adversarial network, in PyTorch. Enabling everyone to experience disentanglement.
https://thispersondoesnotexist.com
MIT License

OOM on GTX1080ti with 128 image size #170

Open · lucasjinreal opened this issue 3 years ago

lucasjinreal commented 3 years ago

I fixed my trained-model loading problem and the weights now seem to load successfully.

But when I run inference I hit an OOM error. Here are my config and inference code:


"""

Generate random images by trained StyleGANV2
"""

from stylegan2_pytorch.stylegan2_pytorch import Generator, StyleGAN2
import torch
import numpy as np
from alfred.dl.torch.common import device
import cv2

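# Repeat `a` n_tile times along `dim`, keeping the copies of each element
# adjacent (behaves like np.repeat rather than torch.Tensor.repeat).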
def tile(a, dim, n_tile):
    init_dim = a.size(dim)
    repeat_idx = [1] * a.dim()
    repeat_idx[dim] = n_tile
    a = a.repeat(*(repeat_idx))
    order_index = torch.LongTensor(np.concatenate(
        [init_dim * np.arange(n_tile) + i for i in range(init_dim)])).to(device)
    return torch.index_select(a, dim, order_index)

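# Run `model` on the arguments in chunks of at most max_batch_size along dim 0
# and concatenate the chunked outputs.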
def evaluate_in_chunks(max_batch_size, model, *args):
    split_args = list(
        zip(*list(map(lambda x: x.split(max_batch_size, dim=0), args))))
    chunked_outputs = [model(*i) for i in split_args]
    if len(chunked_outputs) == 1:
        return chunked_outputs[0]
    return torch.cat(chunked_outputs, dim=0)

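# Flatten a list of (style_vector, num_layers) pairs into a single
# (batch, total_layers, latent_dim) tensor.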
def styles_def_to_tensor(styles_def):
    return torch.cat([t[:, None, :].expand(-1, n, -1) for t, n in styles_def], dim=1)

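# Pull a w-space style vector towards the mean style (estimated from 2000
# random latents) by a factor of trunc_psi.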
def truncate_style(g, tensor, trunc_psi=0.75):
    S = g.S
    latent_dim = g.G.latent_dim

    z = torch.randn([2000, latent_dim]).to(device)
    samples = evaluate_in_chunks(1, S, z).cpu().numpy()
    av = np.mean(samples, axis=0)
    av = np.expand_dims(av, axis=0)

    av_torch = torch.from_numpy(av).to(device)
    tensor = trunc_psi * (tensor - av_torch) + av_torch
    return tensor

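# Apply truncation to every (style, num_layers) pair in the style definition.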
def truncate_style_defs(g, w, trunc_psi=0.75):
    w_space = []
    for tensor, num_layers in w:
        tensor = truncate_style(g, tensor, trunc_psi=trunc_psi)
        w_space.append((tensor, num_layers))
    return w_space

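# Map latents through the style network S, truncate them, and render images
# with the generator G.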
def generate_truncated(g, S, G, style, noi, trunc_psi=0.75, num_image_tiles=8):
    w = map(lambda t: (S(t[0]), t[1]), style)
    w_truncated = truncate_style_defs(g, w, trunc_psi=trunc_psi)
    w_styles = styles_def_to_tensor(w_truncated)
    generated_images = evaluate_in_chunks(1, G, w_styles, noi)
    return generated_images.clamp_(0., 1.)

if __name__ == "__main__":
    weight_path = 'models/default/model_50.pt'

    # should be the same as in your config file
    image_size = 128
    latent_dim = 512
    network_capacity = 16
    fq_dict_size = 256
    transparent = False
    attn_layers = []
    no_const = False
    model = StyleGAN2(image_size=image_size, latent_dim=latent_dim, network_capacity=network_capacity,
                      transparent=transparent, fq_dict_size=fq_dict_size, attn_layers=attn_layers, no_const=no_const)
    # print(model)
    model.eval()
    ckpt = torch.load(weight_path)
    model.load_state_dict(ckpt['GAN'])
    model.to(device)

    print('StyleGAN2 loaded.')

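    # Style-mixing definition: (latent, layer_count) pairs covering all
    # generator layers; here both halves reuse the same latent.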
    num_layers = model.G.num_layers
    nn = torch.randn([1, latent_dim]).to(device)
    tmp1 = tile(nn, 0, 1)
    tmp2 = nn.repeat(1, 1)

    tt = int(num_layers / 2)
    mixed_latents = [(tmp1, tt), (tmp2, num_layers - tt)]
    print(mixed_latents)

    generated_images = generate_truncated(model,
                                          model.SE, model.GE, mixed_latents, 1)

This gives an OOM. What could the cause be? An image size of 128 should be fine with my 12 GB of memory, shouldn't it?
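
Possibly relevant: the script above never disables gradient tracking, so autograd keeps the intermediate activations of S and G alive during generation. A minimal sketch of the same call with gradients turned off (same model and mixed_latents as above, nothing else changed) would be:

with torch.no_grad():
    # identical call to the one above, but intermediate activations are
    # freed immediately instead of being retained for a backward pass
    generated_images = generate_truncated(model,
                                          model.SE, model.GE, mixed_latents, 1)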

bogdan-ivan commented 3 years ago

You need to allocate more virtual memory (paging file).
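
To confirm whether it is really host memory (which the paging file backs) rather than the GPU's 12 GB that is running out, a quick check, assuming a single CUDA device 0, is:

import torch

# Rough GPU-side check: if the peak allocation is close to the card's total
# memory, the OOM is on the GPU and a larger paging file will not help.
print(torch.cuda.memory_allocated(0) / 1e9, 'GB currently allocated')
print(torch.cuda.max_memory_allocated(0) / 1e9, 'GB peak allocation')
print(torch.cuda.get_device_properties(0).total_memory / 1e9, 'GB total on the device')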