Status: Open. lucasjinreal opened this issue 3 years ago.
I fixed the problem with loading my trained model, and the weights now seem to load successfully.
However, I get an out-of-memory (OOM) error during inference. Here are my config and inference code:
""" Generate random images by trained StyleGANV2 """
from stylegan2_pytorch.stylegan2_pytorch import Generator, StyleGAN2
import torch
import numpy as np
from alfred.dl.torch.common import device
import cv2


def tile(a, dim, n_tile):
    """Repeat every slice of ``a`` along ``dim`` ``n_tile`` times in a row.

    For example, ``tile(x, 0, 2)`` turns rows ``[r0, r1]`` into
    ``[r0, r0, r1, r1]``. This is exactly ``torch.repeat_interleave``.
    Delegating to it keeps the work on ``a``'s own device. The previous
    version built a NumPy index and moved it to the global ``device``, and
    ``index_select`` then required ``a`` to live on that same device.
    """
    return torch.repeat_interleave(a, n_tile, dim=dim)


def evaluate_in_chunks(max_batch_size, model, *args):
    """Call ``model`` on ``args`` in batches of at most ``max_batch_size``.

    Each tensor in ``args`` is split along dim 0. ``model`` is called once
    per group of matching chunks, and the outputs are concatenated back
    along dim 0. All ``args`` should share the same leading size, because
    ``zip`` silently stops at the shortest one.
    """
    chunks_per_arg = [x.split(max_batch_size, dim=0) for x in args]
    chunked_outputs = [model(*chunk) for chunk in zip(*chunks_per_arg)]
    if len(chunked_outputs) == 1:
        return chunked_outputs[0]
    return torch.cat(chunked_outputs, dim=0)


def styles_def_to_tensor(styles_def):
    """Expand ``[(w, n_layers), ...]`` into one per-layer style tensor.

    Each 2-D ``w`` of shape ``(batch, latent_dim)`` is broadcast to
    ``n_layers`` copies along a new dim 1. The pieces are concatenated,
    giving ``(batch, sum(n_layers), latent_dim)``.
    """
    return torch.cat([t[:, None, :].expand(-1, n, -1) for t, n in styles_def], dim=1)


def _mean_style(g, batch_size=64, num_samples=2000):
    """Estimate the mean output of ``g.S`` over ``num_samples`` random latents."""
    latent_dim = g.G.latent_dim
    # Pure inference: without no_grad every chunk's graph would be kept
    # alive until the mean is taken.
    with torch.no_grad():
        z = torch.randn([num_samples, latent_dim], device=device)
        samples = evaluate_in_chunks(batch_size, g.S, z)
        return samples.mean(dim=0, keepdim=True)


def truncate_style(g, tensor, trunc_psi=0.75, av=None):
    """Apply the truncation trick, pulling ``tensor`` toward the mean style.

    Args:
        g: the StyleGAN2 wrapper; ``g.S`` and ``g.G.latent_dim`` are read
            when ``av`` has to be estimated.
        tensor: style vectors to truncate.
        trunc_psi: interpolation factor; 1.0 leaves ``tensor`` unchanged and
            0.0 collapses it onto the mean.
        av: optional precomputed mean style (tensor or NumPy array,
            broadcastable against ``tensor``). If omitted, it is estimated
            from 2000 random latents, which is expensive, so pass it when
            truncating several tensors.

    Returns:
        ``trunc_psi * (tensor - av) + av``.
    """
    if av is None:
        av = _mean_style(g)
    av = torch.as_tensor(av, device=tensor.device)
    return trunc_psi * (tensor - av) + av


def truncate_style_defs(g, w, trunc_psi=0.75, av=None):
    """Truncate every ``(style, num_layers)`` pair in ``w``.

    The mean style is estimated once and shared by all pairs, instead of
    being re-sampled for each one.
    """
    if av is None:
        av = _mean_style(g)
    return [
        (truncate_style(g, tensor, trunc_psi=trunc_psi, av=av), num_layers)
        for tensor, num_layers in w
    ]


@torch.no_grad()  # inference only; keeping autograd graphs wastes GPU memory
def generate_truncated(g, S, G, style, noi, trunc_psi=0.75, num_image_tiles=8):
    """Generate images from mixed latents with truncation applied.

    Args:
        g: the StyleGAN2 wrapper, used to estimate the mean style.
        S: mapping network applied to each latent in ``style``.
        G: synthesis network, called as ``G(w_styles_chunk, noise_chunk)``.
        style: ``[(latent, num_layers), ...]`` describing the style mix.
        noi: per-image noise tensor whose dim 0 matches the latents' batch
            size, since it is split alongside the styles. NOTE(review): in
            stylegan2_pytorch this is ``(batch, image_size, image_size, 1)``
            uniform noise; confirm against the installed version.
        trunc_psi: truncation factor passed to ``truncate_style_defs``.
        num_image_tiles: unused; kept for signature compatibility.

    Returns:
        Generated images clamped in place to ``[0, 1]``.

    Raises:
        TypeError: if ``noi`` is not a tensor.
    """
    if not torch.is_tensor(noi):
        raise TypeError(
            f"noi must be a torch.Tensor of per-image noise, got {type(noi).__name__}"
        )
    w = [(S(latent), num_layers) for latent, num_layers in style]
    w_truncated = truncate_style_defs(g, w, trunc_psi=trunc_psi)
    w_styles = styles_def_to_tensor(w_truncated)
    generated_images = evaluate_in_chunks(1, G, w_styles, noi)
    return generated_images.clamp_(0., 1.)
if __name__ == "__main__":
    weight_path = 'models/default/model_50.pt'
    # These hyperparameters must match the ones used for training.
    image_size = 128
    latent_dim = 512
    network_capacity = 16
    fq_dict_size = 256
    transparent = False
    attn_layers = []
    no_const = False

    model = StyleGAN2(image_size=image_size, latent_dim=latent_dim,
                      network_capacity=network_capacity, transparent=transparent,
                      fq_dict_size=fq_dict_size, attn_layers=attn_layers,
                      no_const=no_const)
    # map_location lets a checkpoint saved on another device load here.
    ckpt = torch.load(weight_path, map_location=device)
    model.load_state_dict(ckpt['GAN'])
    model.to(device)
    model.eval()
    print('StyleGAN2 loaded.')

    num_layers = model.G.num_layers
    latent = torch.randn([1, latent_dim]).to(device)
    tmp1 = tile(latent, 0, 1)
    tmp2 = latent.repeat(1, 1)
    tt = num_layers // 2
    mixed_latents = [(tmp1, tt), (tmp2, num_layers - tt)]
    print(mixed_latents)

    # The generator needs a per-image noise TENSOR, not the int 1 the
    # original passed. NOTE(review): this mirrors stylegan2_pytorch's
    # image_noise() layout (batch, H, W, 1), uniform in [0, 1); confirm
    # against the installed library version.
    noise = torch.rand(latent.shape[0], image_size, image_size, 1, device=device)

    # Inference only: disabling autograd avoids retaining graphs (OOM).
    with torch.no_grad():
        generated_images = generate_truncated(model, model.SE, model.GE,
                                              mixed_latents, noise)
I get an OOM error. What could be causing it? An image size of 128 should fit in 12 GB of GPU memory, shouldn't it?
You need to allocate more virtual memory (paging file).
I fixed the problem with loading my trained model, and the weights now seem to load successfully.
However, I get an out-of-memory (OOM) error during inference. Here are my config and inference code:
I get an OOM error. What could be causing it? An image size of 128 should fit in 12 GB of GPU memory, shouldn't it?