yuval-alaluf / hyperstyle

Official Implementation for "HyperStyle: StyleGAN Inversion with HyperNetworks for Real Image Editing" (CVPR 2022) https://arxiv.org/abs/2111.15666
https://yuval-alaluf.github.io/hyperstyle/
MIT License

Training on custom dataset. #72

Closed. tsshubhamv closed this issue 1 year ago.

tsshubhamv commented 1 year ago

Hi @yuval-alaluf, first of all thank you for this repo.

I have been trying to train HyperStyle on a custom dataset, but haven't been able to get it working, and there are a few things I would like cleared up. I have already pretrained StyleGAN2-ADA on this custom dataset and am using that generator here, but I'm running into errors during training, for example: RuntimeError: mat1 and mat2 shapes cannot be multiplied (6144x256 and 512x512)

I am using this script to run training.

dataset_type="sg2-ada-smile_256"
encoder_type="SharedWeightsHyperNetResNetSeparable"
exp_dir="experiments/smile_256"
workers=8
batch_size=8
test_batch_size=4
test_workers=4
save_interval=2000
lpips_lambda=0.8
l2_lambda=0.8
id_lambda=0
n_iters_per_batch=5
max_val_batches=150
output_size=256
layers_to_tune="0,2,3,5,6,8,9,11,12,14,15,17,18,20,21,23,24"

cmd = f'''python scripts/train.py \
  --dataset_type={dataset_type} \
  --encoder_type={encoder_type} \
  --exp_dir={exp_dir} \
  --workers={workers} \
  --batch_size={batch_size} \
  --test_batch_size={test_batch_size} \
  --test_workers={test_workers} \
  --save_interval={save_interval} \
  --lpips_lambda={lpips_lambda} \
  --l2_lambda={l2_lambda} \
  --id_lambda={id_lambda} \
  --n_iters_per_batch={n_iters_per_batch} \
  --max_val_batches={max_val_batches} \
  --stylegan_weights={"artifacts/sg2-ada-smile_256.pt"} \
  --output_size={output_size} \
  # --load_w_encoder \
  # --w_encoder_checkpoint_path={"artifacts/e2e.pth"} \
  --layers_to_tune={layers_to_tune}'''

!{cmd}

I'd appreciate any kind of help or input. Thanks.

yuval-alaluf commented 1 year ago

Where do you get this error? I would assume it has something to do with your choice of layers_to_tune: you may have provided layers that do not exist in a generator with an output size of 256. However, if you provide a full stack trace, I can try to better guide you as to where this error is coming from.
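As a quick sanity check, something along these lines (a rough sketch, assuming the converted checkpoint stores the generator weights under the 'g_ema' key, as the training code expects) can list which conv / to_rgb layers your 256px generator actually contains, so you can compare them against the indices you pass in layers_to_tune:

import torch

# Inspect the converted generator checkpoint used in the training command above.
ckpt = torch.load("artifacts/sg2-ada-smile_256.pt", map_location="cpu")
g_ema = ckpt["g_ema"]

# Collect the layer indices that actually exist in this generator.
convs = sorted({int(k.split(".")[1]) for k in g_ema if k.startswith("convs.")})
to_rgbs = sorted({int(k.split(".")[1]) for k in g_ema if k.startswith("to_rgbs.")})
print("convs:", convs)      # a 256px generator has fewer conv layers than a 1024px one
print("to_rgbs:", to_rgbs)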

tsshubhamv commented 1 year ago

Error 1:

dataset_type="sg2-ada-smile_256"
encoder_type="SharedWeightsHyperNetResNetSeparable"
exp_dir="experiments/smile_256"
workers=8
batch_size=8
test_batch_size=4
test_workers=4
save_interval=2000
lpips_lambda=0.8
l2_lambda=0.8
id_lambda=0
n_iters_per_batch=5
max_val_batches=150
output_size=256
layers_to_tune="0,2,3,5,6,8,9,11,12,14,15,17,18,20,21,23,24"

cmd = f'''python scripts/train.py \
  --dataset_type={dataset_type} \
  --encoder_type={encoder_type} \
  --exp_dir={exp_dir} \
  --workers={workers} \
  --batch_size={batch_size} \
  --test_batch_size={test_batch_size} \
  --test_workers={test_workers} \
  --save_interval={save_interval} \
  --lpips_lambda={lpips_lambda} \
  --l2_lambda={l2_lambda} \
  --id_lambda={id_lambda} \
  --n_iters_per_batch={n_iters_per_batch} \
  --max_val_batches={max_val_batches} \
  --stylegan_weights={"artifacts/sg2-ada-smile_256.pt"} \
  --output_size={output_size} \
  # --load_w_encoder \
  # --w_encoder_checkpoint_path={"artifacts/e2e.pth"} \
  --layers_to_tune={layers_to_tune}'''

!{cmd}

When using the above options, the output is a size-mismatch error:

{'batch_size': 8,
 'board_interval': 50,
 'checkpoint_path': None,
 'dataset_type': 'sg2-ada-smile_256',
 'encoder_type': 'SharedWeightsHyperNetResNetSeparable',
 'exp_dir': 'experiments/smile_256',
 'id_lambda': 0.0,
 'image_interval': 100,
 'input_nc': 6,
 'l2_lambda': 0.8,
 'layers_to_tune': '0,2,3,5,6,8,9,11,12,14,15,17,18,20,21,23,24',
 'learning_rate': 0.0001,
 'load_w_encoder': False,
 'lpips_lambda': 0.8,
 'max_steps': 500000,
 'max_val_batches': 150,
 'moco_lambda': 0,
 'n_iters_per_batch': 5,
 'optim_name': 'ranger',
 'output_size': 256,
 'save_interval': 2000,
 'stylegan_weights': 'artifacts/sg2-ada-smile_256.pt',
 'test_batch_size': 4,
 'test_workers': 4,
 'train_decoder': False,
 'val_interval': 1000,
 'w_encoder_checkpoint_path': None,
 'w_encoder_type': 'WEncoder',
 'workers': 8}
/usr/local/lib/python3.9/dist-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
  warnings.warn(
/usr/local/lib/python3.9/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet34_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet34_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
Downloading: "https://download.pytorch.org/models/resnet34-b627a593.pth" to /root/.cache/torch/hub/checkpoints/resnet34-b627a593.pth
100% 83.3M/83.3M [00:00<00:00, 263MB/s]
Loading hypernet weights from resnet34!
Loading decoder weights from pretrained path: artifacts/sg2-ada-smile_256.pt
/usr/local/lib/python3.9/dist-packages/torchvision/models/_utils.py:135: UserWarning: Using 'weights' as positional parameter(s) is deprecated since 0.13 and may be removed in the future. Please use keyword parameter(s) instead.
  warnings.warn(
/usr/local/lib/python3.9/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=AlexNet_Weights.IMAGENET1K_V1`. You can also use `weights=AlexNet_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
Downloading: "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth" to /root/.cache/torch/hub/checkpoints/alexnet-owt-7be5be79.pth
100% 233M/233M [00:02<00:00, 95.1MB/s]
Downloading: "https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/master/lpips/weights/v0.1/alex.pth" to /root/.cache/torch/hub/checkpoints/alex.pth
100% 5.87k/5.87k [00:00<00:00, 15.2MB/s]
Loading dataset for sg2-ada-smile_256
Number of training samples: 2014
Number of test samples: 502
/usr/local/lib/python3.9/dist-packages/torch/utils/data/dataloader.py:561: UserWarning: This DataLoader will create 8 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
  warnings.warn(_create_warning_msg(
/usr/local/lib/python3.9/dist-packages/torch/utils/data/dataloader.py:561: UserWarning: This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
  warnings.warn(_create_warning_msg(
Traceback (most recent call last):
  File "/content/hyperstyle/scripts/train.py", line 32, in <module>
    main()
  File "/content/hyperstyle/scripts/train.py", line 20, in main
    coach.train()
  File "/content/hyperstyle/./training/coach_hyperstyle.py", line 135, in train
    x, y, y_hat, loss_dict, id_logs, w_inversion = self.perform_forward_on_batch(batch, train=True)
  File "/content/hyperstyle/./training/coach_hyperstyle.py", line 103, in perform_forward_on_batch
    y_hat, latent, weights_deltas, codes, w_inversion = self.net.forward(x,
  File "/content/hyperstyle/./models/hyperstyle.py", line 86, in forward
    images, result_latent = self.decoder([codes],
  File "/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/content/hyperstyle/./models/stylegan2/model.py", line 492, in forward
    styles = [self.style(s) for s in styles]
  File "/content/hyperstyle/./models/stylegan2/model.py", line 492, in <listcomp>
    styles = [self.style(s) for s in styles]
  File "/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
  File "/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/content/hyperstyle/./models/stylegan2/model.py", line 150, in forward
    out = F.linear(input, weight * self.scale)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (6144x256 and 512x512)

tsshubhamv commented 1 year ago

When instead using mostly default options:

a = f'''python scripts/train.py \
  --dataset_type={dataset_type} \
  --encoder_type={encoder_type} \
  --exp_dir={exp_dir}\
  --stylegan_weights={"artifacts/sg2-ada-smile_256.pt"}
  '''
!{a}

I get the following error:

Loading hypernet weights from resnet34!
Loading decoder weights from pretrained path: artifacts/sg2-ada-smile_256.pt
Traceback (most recent call last):
  File "/content/hyperstyle/scripts/train.py", line 32, in <module>
    main()
  File "/content/hyperstyle/scripts/train.py", line 19, in main
    coach = Coach(opts)
  File "/content/hyperstyle/./training/coach_hyperstyle.py", line 35, in __init__
    self.net = HyperStyle(self.opts).to(self.device)
  File "/content/hyperstyle/./models/hyperstyle.py", line 26, in __init__
    self.load_weights()
  File "/content/hyperstyle/./models/hyperstyle.py", line 59, in load_weights
    self.decoder.load_state_dict(ckpt['g_ema'], strict=True)
  File "/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py", line 1671, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for Generator:
    Missing key(s) in state_dict: "convs.12.conv.weight", "convs.12.conv.blur.kernel", "convs.12.conv.modulation.weight", "convs.12.conv.modulation.bias", "convs.12.noise.weight", "convs.12.activate.bias", "convs.13.conv.weight", "convs.13.conv.modulation.weight", "convs.13.conv.modulation.bias", "convs.13.noise.weight", "convs.13.activate.bias", "convs.14.conv.weight", "convs.14.conv.blur.kernel", "convs.14.conv.modulation.weight", "convs.14.conv.modulation.bias", "convs.14.noise.weight", "convs.14.activate.bias", "convs.15.conv.weight", "convs.15.conv.modulation.weight", "convs.15.conv.modulation.bias", "convs.15.noise.weight", "convs.15.activate.bias", "to_rgbs.6.bias", "to_rgbs.6.upsample.kernel", "to_rgbs.6.conv.weight", "to_rgbs.6.conv.modulation.weight", "to_rgbs.6.conv.modulation.bias", "to_rgbs.7.bias", "to_rgbs.7.upsample.kernel", "to_rgbs.7.conv.weight", "to_rgbs.7.conv.modulation.weight", "to_rgbs.7.conv.modulation.bias", "noises.noise_13", "noises.noise_14", "noises.noise_15", "noises.noise_16".
tsshubhamv commented 1 year ago

@yuval-alaluf I have tried both with my own options and with the default options. It would be great if you could enlighten me as to what I'm doing wrong.

yuval-alaluf commented 1 year ago

For the first error, RuntimeError: mat1 and mat2 shapes cannot be multiplied (6144x256 and 512x512), it seems like your images are not resized to the correct size. Try to make sure that all your input images are square-shaped.

For the second error, you need to make sure that you set --output_size=256.

Try making these changes to see if this helps solve your issues.
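If it helps, here is a rough sketch for double-checking the image sizes before training (the data path and extensions below are placeholders; point it at wherever your custom dataset actually lives):

from pathlib import Path
from PIL import Image

data_root = Path("data/smile_256/train")  # hypothetical path -- replace with your dataset folder
bad = []
for p in sorted(data_root.iterdir()):
    if p.suffix.lower() not in {".png", ".jpg", ".jpeg"}:
        continue
    w, h = Image.open(p).size
    if w != h or w != 256:  # expect square 256x256 images for a 256px generator
        bad.append((p.name, (w, h)))
print(f"{len(bad)} images are not 256x256 squares")
print(bad[:10])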

tsshubhamv commented 1 year ago

I'm pretty sure all the images are square and of size 256x256, and output_size is set to 256. I have already tried with these parameters once; I'll try again.