Fail to run in styleGAN3 #3

Open Ding3LI opened 2 months ago

Ding3LI commented 2 months ago

Hi, thanks for sharing such a great repo for solving image degradation. I am currently trying to port the code from StyleGAN2-ADA to StyleGAN3. I only modified some imports, since StyleGAN3 uses different module names and paths; my understanding is that this image restoration method mainly relies on image generation (please correct me if I misunderstood). My changes are pasted below:

File: robust_unsupervised/prelude.py

from typing import *

import copy
import os

import pickle

import functools
import sys
import torch.optim as optim
import tqdm
import dataclasses
import numpy as np
import PIL.Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
import math

# >>>>>>>> Modified
import dnnlib
import legacy
# <<<<<<<<

import shutil
from functools import partial
import itertools
import warnings
from warnings import warn
import datetime
import torchvision.transforms.functional as TF
from torchvision.utils import save_image, make_grid

# >>>>>>>> Modified
# import training.networks as networks
import training.networks_stylegan3 as networks
# <<<<<<<<

from abc import ABC, abstractmethod, abstractstaticmethod, abstractclassmethod
from dataclasses import dataclass, field

import ssl
ssl._create_default_https_context = ssl._create_unverified_context

warnings.filterwarnings("ignore", r"Named tensors and all their associated APIs.*")
warnings.filterwarnings("ignore", r"Arguments other than a weight enum.*")
warnings.filterwarnings("ignore", r"The parameter 'pretrained' is deprecated.*")

File: robust_unsupervised/io_utils.py

from robust_unsupervised.prelude import *
from robust_unsupervised.variables import *

import shutil
import torch_utils as torch_utils
import torch_utils.misc as misc
import contextlib

import PIL.Image as Image

def open_generator(pkl_path: str, refresh=True, float=True, ema=True) -> networks.Generator:
    print(f"Loading generator from {pkl_path}...")

    # >>>>>>>> Modified
    # with dnnlib.util.open_url(pkl_path) as fp:
    #     G = legacy.load_network_pkl(fp)["G_ema" if ema else "G"].cuda().eval()
    #     if float:
    #         G = G.float()

    with open(pkl_path, 'rb') as f:
        # Honor the `ema` flag and switch to eval mode, as the original loader did.
        G = pickle.load(f)["G_ema" if ema else "G"].cuda().eval()
        if float:
            G = G.float()
    # <<<<<<<<

    if refresh:
        with torch.no_grad():
            old_G = G
            G = networks.Generator(*old_G.init_args, **old_G.init_kwargs).cuda()
            misc.copy_params_and_buffers(old_G, G, require_all=True)
            for param in G.parameters():
                param.requires_grad = False

    return G

import tyro
from dataclasses import dataclass
from typing import *

import sys
# >>>>>>>> Modified
# <<<<<<<<

@dataclass
class Config:
    name: str = "restored_samples"
    "A name used to group log files."

    pkl_path: str = "pretrained_networks/stylegan3-r-ffhq-1024x1024.pkl"
    "The location of the pretrained StyleGAN."

    dataset_path: str = "datasets/samples"
    "The location of the images to process."

    resolution: int = 1024
    "The resolution of your images. Images which are smaller or larger will be resized."

    global_lr_scale: float = 1.0
    "A global factor which scales up and down all learning rates. This may need adjustment for datasets other than faces."

    tasks: Literal["all", "single", "composed", "custom"] = "all"
    "Selects which tasks to run."

def parse_config() -> Config:
    return tyro.cli(Config)

So, I downloaded a new pre-trained StyleGAN3 model named stylegan3-r-ffhq-1024x1024.pkl. I kept using the same dataset that this repo provides (sample_1.png and sample_2.png), and all other code remains the same. The program runs the W and W+ phases successfully, but it terminates with an error when entering the W++ phase. The error message is copied below:

>$ python run.py --dataset_path datasets/samples
Loading generator from pretrained_networks/stylegan3-r-ffhq-1024x1024.pkl...
- 0000
W:   0%|     | 0/150 [00:00<?, ?it/s]
Setting up PyTorch plugin "filtered_lrelu_plugin"... Done.
W: 100%|███████| 150/150 [00:27<00:00,  5.48it/s]
W+: 100%|██████| 150/150 [00:25<00:00,  5.79it/s]
W++:   0%|     | 0/150 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/project/robust-unsupervised/run.py", line 154, in <module>
    run_phase("W++", Wpp_variable, config.global_lr_scale * 0.005)
  File "/project/robust-unsupervised/run.py", line 24, in run_phase
    x = variable.to_image()
  File "/project/robust-unsupervised/robust_unsupervised/variables.py", line 29, in to_image
    return self.render_image(self.to_input_tensor())
  File "/project/robust-unsupervised/robust_unsupervised/variables.py", line 35, in render_image
    return (self.G.synthesis(ws, noise_mode="const", force_fp32=True) + 1.0) / 2.0
  File "/project/miniconda3/envs/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/project/miniconda3/envs/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/project/robust-unsupervised/stylegan3/training/networks_stylegan3.py", line 465, in forward
    misc.assert_shape(ws, [None, self.num_ws, self.w_dim])
  File "/project/robust-unsupervised/stylegan3/torch_utils/misc.py", line 95, in assert_shape
    raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
AssertionError: Wrong size for dimension 1: got 8192, expected 16

Could you share any suggestions on solving this issue? Thanks.
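
One way to read the assertion, offered as a sketch rather than a diagnosis: the synthesis network expects `ws` shaped `[batch, num_ws, w_dim]`, and 8192 happens to equal 16 × 512. The snippet below mimics the shape check in pure Python. It is not the actual `misc.assert_shape` code; `check_shape` is a hypothetical helper, and both `W_DIM = 512` and the idea that the W++ variable carries 512 style vectors per layer are assumptions inferred from the numbers in the traceback.

```python
from typing import Optional, Sequence


def check_shape(shape: Sequence[int], ref_shape: Sequence[Optional[int]]) -> None:
    """Simplified stand-in for a shape assertion; None means 'any size'."""
    if len(shape) != len(ref_shape):
        raise AssertionError(
            f"Wrong number of dimensions: got {len(shape)}, expected {len(ref_shape)}"
        )
    for idx, (size, ref_size) in enumerate(zip(shape, ref_shape)):
        if ref_size is not None and size != ref_size:
            raise AssertionError(
                f"Wrong size for dimension {idx}: got {size}, expected {ref_size}"
            )


NUM_WS = 16   # taken from "expected 16" in the traceback
W_DIM = 512   # assumption: latent width of the FFHQ checkpoint

# W+ layout: one style vector per layer, which the check accepts.
check_shape((1, NUM_WS, W_DIM), (None, NUM_WS, W_DIM))

# Assumed W++ layout: 512 style vectors per layer, i.e. 16 * 512 = 8192 rows.
wpp_shape = (1, NUM_WS * 512, W_DIM)
try:
    check_shape(wpp_shape, (None, NUM_WS, W_DIM))
    error_message = None
except AssertionError as exc:
    error_message = str(exc)

print(error_message)
```

If that assumption holds, the W and W+ phases pass because they feed at most one style vector per layer, while the W++ tensor is rejected by the unmodified StyleGAN3 synthesis network before any rendering happens.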

My venv (TL;DR: this environment can run both StyleGAN2-ADA and StyleGAN3):

python: 3.10
cuda: 12.3
torch: 2.3.0
torchvision: 0.18.0
OS: Linux
  - _libgcc_mutex=0.1=conda_forge
  - _openmp_mutex=4.5=2_gnu
  - binutils=2.40=h4852527_0
  - binutils_impl_linux-64=2.40=ha885e6a_0
  - binutils_linux-64=2.40=hdade7a5_3
  - blas=1.0=mkl
  - brotli-python=1.1.0=py310hc6cd4ac_1
  - bzip2=1.0.8=hd590300_5
  - c-compiler=1.7.0=hd590300_0
  - ca-certificates=2024.3.11=h06a4308_0
  - certifi=2024.2.2=pyhd8ed1ab_0
  - charset-normalizer=3.3.2=pyhd8ed1ab_0
  - cuda=12.1.0=0
  - cuda-cccl=12.4.127=0
  - cuda-command-line-tools=12.1.1=0
  - cuda-compiler=12.4.1=0
  - cuda-cudart=12.1.105=0
  - cuda-cudart-dev=12.1.105=0
  - cuda-cudart-static=12.1.105=0
  - cuda-cuobjdump=12.4.127=0
  - cuda-cupti=12.1.105=0
  - cuda-cupti-static=12.1.105=0
  - cuda-cuxxfilt=12.4.127=0
  - cuda-demo-suite=12.4.127=0
  - cuda-documentation=12.4.127=0
  - cuda-driver-dev=12.4.127=0
  - cuda-gdb=12.4.127=0
  - cuda-libraries=12.1.0=0
  - cuda-libraries-dev=12.1.0=0
  - cuda-libraries-static=12.1.0=0
  - cuda-nsight=12.4.127=0
  - cuda-nsight-compute=12.4.1=0
  - cuda-nvcc=12.4.131=0
  - cuda-nvdisasm=12.4.127=0
  - cuda-nvml-dev=12.4.127=0
  - cuda-nvprof=12.4.127=0
  - cuda-nvprune=12.4.127=0
  - cuda-nvrtc=12.1.105=0
  - cuda-nvrtc-dev=12.1.105=0
  - cuda-nvrtc-static=12.1.105=0
  - cuda-nvtx=12.1.105=0
  - cuda-nvvp=12.4.127=0
  - cuda-opencl=12.4.127=0
  - cuda-opencl-dev=12.4.127=0
  - cuda-profiler-api=12.4.127=0
  - cuda-runtime=12.1.0=0
  - cuda-sanitizer-api=12.4.127=0
  - cuda-toolkit=12.1.0=0
  - cuda-tools=12.1.0=0
  - cuda-visual-tools=12.1.0=0
  - cudatoolkit=11.7.0=hd8887f6_10
  - cxx-compiler=1.7.0=h00ab1b0_0
  - ffmpeg=4.3=hf484d3e_0
  - filelock=3.13.4=pyhd8ed1ab_0
  - freetype=2.12.1=h267a509_2
  - gcc=12.3.0=h915e2ae_6
  - gcc_impl_linux-64=12.3.0=h1562d66_6
  - gcc_linux-64=12.3.0=h6477408_3
  - gds-tools=
  - gmp=6.3.0=h59595ed_1
  - gmpy2=2.1.5=py310hc3586ac_0
  - gnutls=3.6.13=h85f3911_1
  - gxx=12.3.0=h915e2ae_6
  - gxx_impl_linux-64=12.3.0=h1562d66_6
  - gxx_linux-64=12.3.0=h4a1b8e8_3
  - icu=73.2=h59595ed_0
  - idna=3.7=pyhd8ed1ab_0
  - intel-openmp=2023.1.0=hdb19cb5_46306
  - jinja2=3.1.3=pyhd8ed1ab_0
  - jpeg=9e=h166bdaf_2
  - kernel-headers_linux-64=2.6.32=he073ed8_17
  - lame=3.100=h166bdaf_1003
  - lcms2=2.15=hfd0df8a_0
  - ld_impl_linux-64=2.40=h55db66e_0
  - lerc=4.0.0=h27087fc_0
  - libblas=3.9.0=1_h86c2bf4_netlib
  - libcblas=3.9.0=5_h92ddd45_netlib
  - libcublas=
  - libcublas-dev=
  - libcublas-static=
  - libcufft=
  - libcufft-dev=
  - libcufft-static=
  - libcufile=
  - libcufile-dev=
  - libcufile-static=
  - libcurand=
  - libcurand-dev=
  - libcurand-static=
  - libcusolver=
  - libcusolver-dev=
  - libcusolver-static=
  - libcusparse=
  - libcusparse-dev=
  - libcusparse-static=
  - libdeflate=1.17=h0b41bf4_0
  - libffi=3.4.2=h7f98852_5
  - libgcc-devel_linux-64=12.3.0=h2af2641_106
  - libgcc-ng=13.2.0=hc881cc4_6
  - libgfortran-ng=13.2.0=h69a702a_6
  - libgfortran5=13.2.0=h43f5ff8_6
  - libgomp=13.2.0=hc881cc4_6
  - libhwloc=2.10.0=default_h2fb2949_1000
  - libiconv=1.17=hd590300_2
  - libjpeg-turbo=2.0.0=h9bf148f_0
  - liblapack=3.9.0=5_h92ddd45_netlib
  - libnpp=
  - libnpp-dev=
  - libnpp-static=
  - libnsl=2.0.1=hd590300_0
  - libnvjitlink=12.1.105=0
  - libnvjitlink-dev=12.1.105=0
  - libnvjpeg=
  - libnvjpeg-dev=
  - libnvjpeg-static=
  - libnvvm-samples=12.1.105=0
  - libpng=1.6.43=h2797004_0
  - libsanitizer=12.3.0=h2af2641_6
  - libsqlite=3.45.3=h2797004_0
  - libstdcxx-devel_linux-64=12.3.0=h2af2641_106
  - libstdcxx-ng=13.2.0=h95c4c6d_6
  - libtiff=4.5.0=h6adf6a1_2
  - libuuid=2.38.1=h0b41bf4_0
  - libwebp-base=1.4.0=hd590300_0
  - libxcb=1.13=h7f98852_1004
  - libxcrypt=4.4.36=hd590300_1
  - libxml2=2.12.6=h232c23b_2
  - libzlib=1.2.13=hd590300_5
  - llvm-openmp=15.0.7=h0cdce71_0
  - markupsafe=2.1.5=py310h2372a71_0
  - mkl=2023.1.0=h213fc3f_46344
  - mpc=1.3.1=hfe3b2da_0
  - mpfr=4.2.1=h9458935_1
  - mpmath=1.3.0=pyhd8ed1ab_0
  - ncurses=6.4.20240210=h59595ed_0
  - nettle=3.6=he412f7d_0
  - networkx=3.3=pyhd8ed1ab_1
  - ninja=1.12.0=h00ab1b0_0
  - nsight-compute=2024.1.1.4=0
  - numpy=1.26.4=py310hb13e2d6_0
  - openh264=2.1.1=h780b84a_0
  - openjpeg=2.5.0=hfec8fc6_2
  - openssl=3.2.1=hd590300_1
  - pillow=9.4.0=py310h023d228_1
  - pip=24.0=pyhd8ed1ab_0
  - pthread-stubs=0.4=h36c2ea0_1001
  - pysocks=1.7.1=pyha2e5f31_6
  - python=3.10.14=hd12c33a_0_cpython
  - python_abi=3.10=4_cp310
  - pytorch=2.3.0=py3.10_cuda12.1_cudnn8.9.2_0
  - pytorch-cuda=12.1=ha16c6d3_5
  - pytorch-mutex=1.0=cuda
  - pyyaml=6.0.1=py310h2372a71_1
  - readline=8.2=h8228510_1
  - requests=2.31.0=pyhd8ed1ab_0
  - setuptools=69.5.1=pyhd8ed1ab_0
  - sympy=1.12=pypyh9d50eac_103
  - sysroot_linux-64=2.12=he073ed8_17
  - tbb=2021.12.0=h00ab1b0_0
  - tk=8.6.13=noxft_h4845f30_101
  - torchaudio=2.3.0=py310_cu121
  - torchtriton=2.3.0=py310
  - torchvision=0.18.0=py310_cu121
  - typing_extensions=4.11.0=pyha770c72_0
  - tzdata=2024a=h0c530f3_0
  - urllib3=2.2.1=pyhd8ed1ab_0
  - wheel=0.43.0=pyhd8ed1ab_1
  - xorg-libxau=1.0.11=hd590300_0
  - xorg-libxdmcp=1.1.3=h7f98852_0
  - xz=5.2.6=h166bdaf_0
  - yaml=0.2.5=h7f98852_2
  - zlib=1.2.13=hd590300_5
  - zstd=1.5.5=hfc55251_0
  - pip:
      - absl-py==2.1.0
      - beautifulsoup4==4.12.3
      - cachetools==5.3.3
      - click==8.1.7
      - contourpy==1.2.1
      - cycler==0.12.1
      - fonttools==4.51.0
      - gdown==5.1.0
      - grpcio==1.62.2
      - imageio-ffmpeg==0.4.9
      - kiwisolver==1.4.5
      - markdown==3.6
      - matplotlib==3.8.4
      - nvidia-ml-py==12.535.161
      - nvitop==1.3.2
      - packaging==24.0
      - protobuf==5.26.1
      - psutil==5.9.8
      - pyparsing==3.1.2
      - pyspng==0.1.1
      - python-dateutil==2.9.0.post0
      - scipy==1.13.0
      - six==1.16.0
      - soupsieve==2.5
      - tensorboard==2.16.2
      - tensorboard-data-server==0.7.2
      - termcolor==2.4.0
      - tqdm==4.66.2
      - werkzeug==3.0.2

yohan-pg commented 2 months ago

Everything you are doing seems fine. The problem is that, in order to use StyleGAN3, you will need to modify the style injection in the same way I did for StyleGAN2. If you diff networks.py between this repo and the original stylegan2-ada code, you will see what changes were made. This can be a bit tricky, but in principle I don't see why it wouldn't work. If you change the network architecture, you may also need to adjust the learning rate to get proper results.
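
The diff mentioned above can be produced with any diff tool; as a minimal sketch, Python's standard `difflib` works too. The helper names `unified_diff_text` and `diff_files` are hypothetical and not part of either codebase, and the demo strings are made up purely to show the output format.

```python
import difflib
from pathlib import Path
from typing import List


def unified_diff_text(original: str, modified: str,
                      original_name: str = "original",
                      modified_name: str = "modified") -> List[str]:
    """Return the unified diff between two source strings as a list of lines."""
    return list(difflib.unified_diff(
        original.splitlines(),
        modified.splitlines(),
        fromfile=original_name,
        tofile=modified_name,
        lineterm="",
    ))


def diff_files(original_path: str, modified_path: str) -> List[str]:
    """Read two files from disk and return their unified diff."""
    return unified_diff_text(
        Path(original_path).read_text(),
        Path(modified_path).read_text(),
        original_path,
        modified_path,
    )


# Tiny in-memory demo with made-up content.
demo = unified_diff_text("a = 1\nb = 2\n", "a = 1\nb = 3\n")
changed = [
    line for line in demo
    if line.startswith(("+", "-")) and not line.startswith(("+++", "---"))
]
print(changed)
```

In practice one would call `diff_files` with the path to `networks.py` in a stylegan2-ada checkout and the path to the same file in this repo (both depend on where the two checkouts live), then look for the lines that touch how the style vectors reach each layer.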

Let me know if you try doing this & need help!