davidbau / rewriting

Rewriting a Deep Generative Model, ECCV 2020 (oral). Interactive tool to directly edit the rules of a GAN to synthesize scenes with objects added, removed, or altered. Change StyleGANv2 to make extravagant eyebrows, or horses wearing hats.
https://rewriting.csail.mit.edu/
MIT License

Unable to run custom model on Colab #2

Closed: cyrilzakka closed this issue 3 years ago

cyrilzakka commented 3 years ago

When I attempt to load a custom model, the following errors are raised:

Error(s) in loading state_dict for SeqStyleGAN2:
    Missing key(s) in state_dict: "noises.noise_0", "noises.noise_1", "noises.noise_2", "noises.noise_3", "noises.noise_4", "noises.noise_5", "noises.noise_6", "noises.noise_7", "noises.noise_8", "noises.noise_9", "noises.noise_10", "noises.noise_11", "noises.noise_12", "noises.noise_13", "noises.noise_14". 

I've tried setting the strict argument of load_state_dict to False, but to no avail. After commenting out the noise layers in SeqStyleGAN2, everything runs until the line gw = ganrewrite.SeqStyleGanRewriter(g, zds, layernum, cachedir='experiments'), which then complains that the selected layer (e.g. layer 8) is not a Sequential layer. Any ideas?
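For reference, PyTorch's nn.Module.load_state_dict(..., strict=False) does not raise on missing keys; it returns a named tuple listing missing_keys and unexpected_keys. If strict=False seemed to have no effect, one possibility (an assumption, not confirmed in this thread) is that a subclass overriding load_state_dict does not forward the flag. The ToyGenerator below is hypothetical and only mimics the noises.noise_N buffer names from the error message.

```python
import torch
import torch.nn as nn

class ToyGenerator(nn.Module):
    """Hypothetical stand-in: one conv plus noise buffers named like the error message."""
    def __init__(self, n_noise=3):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 3, padding=1)
        self.noises = nn.Module()
        for i in range(n_noise):
            self.noises.register_buffer(f"noise_{i}", torch.randn(1, 1, 4, 4))

g = ToyGenerator()

# A checkpoint saved without the noise buffers.
ckpt = {k: v for k, v in g.state_dict().items() if not k.startswith("noises.")}

# strict=True (the default) would raise a RuntimeError listing the missing keys;
# strict=False loads what it can and reports the rest instead.
result = g.load_state_dict(ckpt, strict=False)
print(result.missing_keys)
print(result.unexpected_keys)
```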

davidbau commented 3 years ago

Hello! There are a couple of issues here that I think I can help with.

First, I just pushed a change (a0f020d383396cd6fdde24b7a9a71f2dd888a687) to our styleganv2 port, to make the noises (and latent_avg) optional.
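The commit itself is not reproduced here. As an illustration only, one generic way to tolerate a checkpoint that lacks some buffers is to fill the missing entries from the model's own freshly initialized state and then load strictly. The helper load_filling_missing and the "noises." prefix are hypothetical names for this sketch, not the repository's code.

```python
import torch
import torch.nn as nn

def load_filling_missing(model, ckpt, optional_prefixes=("noises.",)):
    """Hypothetical helper: load strictly, but keep the model's own values
    for optional keys the checkpoint lacks. Returns the keys that were filled."""
    own = model.state_dict()
    merged = dict(ckpt)
    filled = []
    for key, value in own.items():
        if key not in merged and key.startswith(optional_prefixes):
            merged[key] = value  # keep the model's current (e.g. random) buffer
            filled.append(key)
    # Any other missing key, or any unexpected key, still raises as usual.
    model.load_state_dict(merged, strict=True)
    return filled

# Tiny example model with one optional noise buffer.
net = nn.Module()
net.conv = nn.Conv2d(3, 3, 3)
net.noises = nn.Module()
net.noises.register_buffer("noise_0", torch.zeros(1, 1, 4, 4))

ckpt = {"conv.weight": torch.ones(3, 3, 3, 3), "conv.bias": torch.zeros(3)}
filled = load_filling_missing(net, ckpt)
```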

Second, my guess is that you need to construct the SeqStyleGAN2 with the parameter mconv='seq', which splits the modulated convolution layer into a sequence of steps so that the rewriting algorithm can edit the underlying convolution directly.
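To see why a modulated convolution can be split into steps: StyleGAN2's modulation and demodulation only rescale input and output channels. The sketch below is illustrative and not the repository's SeqStyleGAN2 code (the function names are hypothetical). It checks numerically that the usual fused form, which builds per-sample weights and runs a grouped convolution, equals three steps: modulate, an ordinary convolution with the shared weight, then demodulate. In the sequential form the shared weight appears as a plain conv, the kind of layer a rewriting method can edit directly.

```python
import torch
import torch.nn.functional as F

EPS = 1e-8

def fused_modconv(x, weight, style):
    """Modulated conv with per-sample weights, computed via a grouped convolution."""
    b, cin, hgt, wid = x.shape
    cout, _, k, _ = weight.shape
    w = weight[None] * style[:, None, :, None, None]            # (B, Cout, Cin, K, K)
    demod = torch.rsqrt(w.pow(2).sum((2, 3, 4)) + EPS)          # (B, Cout)
    w = (w * demod[:, :, None, None, None]).reshape(b * cout, cin, k, k)
    out = F.conv2d(x.reshape(1, b * cin, hgt, wid), w, padding=k // 2, groups=b)
    return out.reshape(b, cout, hgt, wid)

def seq_modconv(x, weight, style):
    """Same computation as three separate steps around a shared, unmodulated conv."""
    k = weight.shape[-1]
    demod = torch.rsqrt(
        (weight[None] * style[:, None, :, None, None]).pow(2).sum((2, 3, 4)) + EPS
    )
    h = x * style[:, :, None, None]            # 1. modulate: scale input channels
    h = F.conv2d(h, weight, padding=k // 2)    # 2. ordinary convolution, shared weight
    return h * demod[:, :, None, None]         # 3. demodulate: rescale output channels

torch.manual_seed(0)
x = torch.randn(2, 4, 8, 8)
weight = torch.randn(3, 4, 3, 3)
style = torch.randn(2, 4) + 1.0
same = torch.allclose(fused_modconv(x, weight, style), seq_modconv(x, weight, style), atol=1e-4)
```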

I have updated the README (594272e93911f6ab1f734d1d9ee8fb3e99bdd745) to explain this.

cyrilzakka commented 3 years ago

Thank you! That fixed the first few errors I was having, but now I'm getting "input must be a CUDA tensor" at the line gw = ganrewrite.SeqStyleGanRewriter(g, zds, layernum, cachedir='experiments').

Here is my full code:

from utils import zdataset, show, labwidget
from rewrite import ganrewrite, rewriteapp
import torch, copy, os, json
from utils.stylegan2.models import SeqStyleGAN2
from torchvision.utils import save_image
import utils.stylegan2

g = SeqStyleGAN2(512, style_dim=512, n_mlp=8, truncation=0.65, mconv='seq')
state_dict = torch.load('/content/tutorial_code/stylegan2.pt')
g.load_state_dict(state_dict['g_ema'], latent_avg=state_dict['latent_avg'])

layernum = 8 # or which ever layer you wish to modify
sample_size = 1000 # a small sample of images for computing statistics
zds = zdataset.z_dataset_for_model(g, size=sample_size)
gw = ganrewrite.SeqStyleGanRewriter(g, zds, layernum, cachedir='experiments')
davidbau commented 3 years ago

Right after you call load_state_dict, try calling g.cuda(). Thanks for working through this despite the missing details. I've updated the README to mention this too (c95aa0b23dea200be514f827abd7498d612c7309).
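The "input must be a CUDA tensor" message most likely comes from a custom CUDA kernel receiving a CPU tensor, so every weight and buffer has to be on the GPU before the rewriter runs the model. The hedged sketch below uses a hypothetical helper, model_devices, to confirm that a move took effect; the small Sequential only stands in for the generator, and the code falls back to CPU so it runs without a GPU.

```python
import torch
import torch.nn as nn

def model_devices(module):
    """Return the set of devices holding the module's parameters and buffers."""
    tensors = list(module.parameters()) + list(module.buffers())
    return {t.device for t in tensors}

# Hypothetical stand-in for the generator right after load_state_dict.
g = nn.Sequential(nn.Linear(8, 8), nn.BatchNorm1d(8))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
g = g.to(device)  # same effect as g.cuda() when a GPU is present

# Compare device types: GPU tensors report "cuda:0", not plain "cuda".
device_types = {d.type for d in model_devices(g)}

# Inputs must be created on the same device as the model.
z = torch.randn(2, 8, device=device)
out = g(z)
```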