NVlabs / stylegan3

Official PyTorch implementation of StyleGAN3

Resuming with a conditional network pkl on a new task with different number of classes #108

Open pcicales opened 2 years ago

pcicales commented 2 years ago

Hello,

Is it possible to resume with a network pkl on a new task with a different number of conditions? This is for stylegan2.

I got a tensor shape mismatch when trying to resume from a network trained as a 10-class 2048x2048 stylegan2 model in order to train a new 13-class 2048x2048 stylegan2 model.

leona commented 2 years ago

I am also wondering about this. I faced the same issue when resuming from FFHQ 1024.

GTziolas commented 2 years ago

I am also facing this issue, trying to use transfer learning to resume training from the pretrained NVIDIA models. If you read the discussion here though, it seems that there is no way to solve this.

kirchhoffaron commented 9 months ago

I found a solution that discards the pretrained mapping networks but preserves the weights of the generator's synthesis network and of the discriminator. Feel free to try it:

import sys
import pickle
sys.path.append('../stylegan3')
import dnnlib
import legacy

# load your pretrained model with n classes
path = 'path_to_pretrained_with_n_classes'
with dnnlib.util.open_url(path) as fp:
    pretrained_model = legacy.load_network_pkl(fp)

# load a model initialized with m classes (interrupt the training after 0kimg)
path_new_classes = 'newly_initialized_model'
with dnnlib.util.open_url(path_new_classes) as fp:
    model_new_classes = legacy.load_network_pkl(fp)

# Copy the weights of the pretrained synthesis network
model_new_classes['G'].synthesis = pretrained_model['G'].synthesis
model_new_classes['G_ema'].synthesis = pretrained_model['G_ema'].synthesis

# keep the newly initialized mapping network and class count of the new discriminator
mapping_new_classes = model_new_classes['D'].mapping
c_dim = model_new_classes['D'].c_dim

# copy the weights of the pretrained discriminator
model_new_classes['D'] = pretrained_model['D']
model_new_classes['D'].mapping = mapping_new_classes
model_new_classes['D'].c_dim = c_dim

# save the combined model
with open('makeshift.pkl', 'wb') as f:
    pickle.dump(model_new_classes, f)
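
To check beforehand which parameters can be carried over and which ones will trigger the shape mismatch, the two sets of parameters can be compared by name and shape. The following is only a minimal sketch, not part of the repository: `FakeTensor`, `split_by_shape`, and the parameter names and shapes are made up for illustration, on the assumption that only the class-dependent layers differ between the 10-class and the 13-class model.

```python
from collections import namedtuple

# Stand-in for a tensor: only the .shape attribute matters for this sketch.
FakeTensor = namedtuple("FakeTensor", ["shape"])


def split_by_shape(src, dst):
    """Split the names in dst into those src can fill and those it cannot.

    A name is compatible when src has an entry with the same name and shape.
    """
    compatible, skipped = [], []
    for name, target in dst.items():
        source = src.get(name)
        if source is not None and tuple(source.shape) == tuple(target.shape):
            compatible.append(name)
        else:
            skipped.append(name)
    return compatible, skipped


# Hypothetical names and shapes: a layer that does not depend on the number
# of classes, and a class embedding that does (10 vs. 13 classes).
pretrained = {
    "synthesis.conv.weight": FakeTensor((512, 512, 3, 3)),
    "mapping.embed.weight": FakeTensor((512, 10)),
}
fresh = {
    "synthesis.conv.weight": FakeTensor((512, 512, 3, 3)),
    "mapping.embed.weight": FakeTensor((512, 13)),
}

compatible, skipped = split_by_shape(pretrained, fresh)
print("can be copied:", compatible)
print("must stay newly initialized:", skipped)
```

With real networks, `src` and `dst` could be built with `dict(module.named_parameters())` from PyTorch, since tensors expose `.shape` as well. Anything reported as skipped has to stay newly initialized, as the mapping networks do in the script above.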