NVlabs / stylegan3

Official PyTorch implementation of StyleGAN3

Supported Tensorflow version #181

Open keenhon opened 2 years ago

keenhon commented 2 years ago

Describe the bug: It doesn't seem to run on TensorFlow 2.x.

To reproduce: Run on TensorFlow 2.x.

PDillis commented 2 years ago

Sorry, but this code is based on PyTorch. You'd have to rewrite the entire codebase to run it on TensorFlow, be it 1.x or 2.x.

fbarretto commented 2 years ago

Yes, it doesn't run on TensorFlow. But I think he meant that it throws an error on Colab when training a StyleGAN3 model; see https://github.com/googlecolab/colabtools/issues/2914

I've created a really simple Colab notebook that reproduces the error (on a small dataset): https://colab.research.google.com/drive/1u7FUYptdQvhq0FFjxV1xWBxk4bfdy59Q

When running:

!python train.py --outdir=./results --cfg=stylegan3-r --data=/content/circles-1024x1024.zip \
    --gpus=1 --batch=32 --batch-gpu=4 --gamma=6.6 --mirror=1 --kimg=1 --snap=5 --metrics=None \
    --resume=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-ffhqu-1024x1024.pkl

Output is:


Training options:
{
  "G_kwargs": {
    "class_name": "training.networks_stylegan3.Generator",
    "z_dim": 512,
    "w_dim": 512,
    "mapping_kwargs": {
      "num_layers": 2
    },
    "channel_base": 65536,
    "channel_max": 1024,
    "magnitude_ema_beta": 0.9988915792636801,
    "conv_kernel": 1,
    "use_radial_filters": true
  },
  "D_kwargs": {
    "class_name": "training.networks_stylegan2.Discriminator",
    "block_kwargs": {
      "freeze_layers": 0
    },
    "mapping_kwargs": {},
    "epilogue_kwargs": {
      "mbstd_group_size": 4
    },
    "channel_base": 32768,
    "channel_max": 512
  },
  "G_opt_kwargs": {
    "class_name": "torch.optim.Adam",
    "betas": [
      0,
      0.99
    ],
    "eps": 1e-08,
    "lr": 0.0025
  },
  "D_opt_kwargs": {
    "class_name": "torch.optim.Adam",
    "betas": [
      0,
      0.99
    ],
    "eps": 1e-08,
    "lr": 0.002
  },
  "loss_kwargs": {
    "class_name": "training.loss.StyleGAN2Loss",
    "r1_gamma": 6.6,
    "blur_init_sigma": 0,
    "blur_fade_kimg": 200.0
  },
  "data_loader_kwargs": {
    "pin_memory": true,
    "prefetch_factor": 2,
    "num_workers": 3
  },
  "training_set_kwargs": {
    "class_name": "training.dataset.ImageFolderDataset",
    "path": "/content/circles-1024x1024.zip",
    "use_labels": false,
    "max_size": 248,
    "xflip": true,
    "resolution": 1024,
    "random_seed": 0
  },
  "num_gpus": 1,
  "batch_size": 32,
  "batch_gpu": 4,
  "metrics": [],
  "total_kimg": 1,
  "kimg_per_tick": 4,
  "image_snapshot_ticks": 5,
  "network_snapshot_ticks": 5,
  "random_seed": 0,
  "ema_kimg": 10.0,
  "augment_kwargs": {
    "class_name": "training.augment.AugmentPipe",
    "xflip": 1,
    "rotate90": 1,
    "xint": 1,
    "scale": 1,
    "rotate": 1,
    "aniso": 1,
    "xfrac": 1,
    "brightness": 1,
    "contrast": 1,
    "lumaflip": 1,
    "hue": 1,
    "saturation": 1
  },
  "ada_target": 0.6,
  "resume_pkl": "https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-ffhqu-1024x1024.pkl",
  "ada_kimg": 100,
  "ema_rampup": null,
  "run_dir": "./results/00001-stylegan3-r-circles-1024x1024-gpus1-batch32-gamma6.6"
}

Output directory:    ./results/00001-stylegan3-r-circles-1024x1024-gpus1-batch32-gamma6.6
Number of GPUs:      1
Batch size:          32 images
Training duration:   1 kimg
Dataset path:        /content/circles-1024x1024.zip
Dataset size:        248 images
Dataset resolution:  1024
Dataset labels:      False
Dataset x-flips:     True

Creating output directory...
Launching processes...
Loading training set...

Num images:  496
Image shape: [3, 1024, 1024]
Label shape: [0]

Constructing networks...
Resuming from "https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-ffhqu-1024x1024.pkl"
Setting up PyTorch plugin "bias_act_plugin"... Done.
Setting up PyTorch plugin "filtered_lrelu_plugin"... Done.

Generator                      Parameters  Buffers  Output shape          Datatype
---                            ---         ---      ---                   ---     
mapping.fc0                    262656      -        [4, 512]              float32 
mapping.fc1                    262656      -        [4, 512]              float32 
mapping                        -           512      [4, 16, 512]          float32 
synthesis.input.affine         2052        -        [4, 4]                float32 
synthesis.input                1048576     3081     [4, 1024, 36, 36]     float32 
synthesis.L0_36_1024.affine    525312      -        [4, 1024]             float32 
synthesis.L0_36_1024           1049600     157      [4, 1024, 36, 36]     float32 
synthesis.L1_36_1024.affine    525312      -        [4, 1024]             float32 
synthesis.L1_36_1024           1049600     157      [4, 1024, 36, 36]     float32 
synthesis.L2_52_1024.affine    525312      -        [4, 1024]             float32 
synthesis.L2_52_1024           1049600     169      [4, 1024, 52, 52]     float32 
synthesis.L3_52_1024.affine    525312      -        [4, 1024]             float32 
synthesis.L3_52_1024           1049600     157      [4, 1024, 52, 52]     float32 
synthesis.L4_84_1024.affine    525312      -        [4, 1024]             float32 
synthesis.L4_84_1024           1049600     169      [4, 1024, 84, 84]     float32 
synthesis.L5_148_1024.affine   525312      -        [4, 1024]             float32 
synthesis.L5_148_1024          1049600     169      [4, 1024, 148, 148]   float16 
synthesis.L6_148_1024.affine   525312      -        [4, 1024]             float32 
synthesis.L6_148_1024          1049600     157      [4, 1024, 148, 148]   float16 
synthesis.L7_276_645.affine    525312      -        [4, 1024]             float32 
synthesis.L7_276_645           661125      169      [4, 645, 276, 276]    float16 
synthesis.L8_276_406.affine    330885      -        [4, 645]              float32 
synthesis.L8_276_406           262276      157      [4, 406, 276, 276]    float16 
synthesis.L9_532_256.affine    208278      -        [4, 406]              float32 
synthesis.L9_532_256           104192      169      [4, 256, 532, 532]    float16 
synthesis.L10_1044_161.affine  131328      -        [4, 256]              float32 
synthesis.L10_1044_161         41377       169      [4, 161, 1044, 1044]  float16 
synthesis.L11_1044_102.affine  82593       -        [4, 161]              float32 
synthesis.L11_1044_102         16524       157      [4, 102, 1044, 1044]  float16 
synthesis.L12_1044_64.affine   52326       -        [4, 102]              float32 
synthesis.L12_1044_64          6592        25       [4, 64, 1044, 1044]   float16 
synthesis.L13_1024_64.affine   32832       -        [4, 64]               float32 
synthesis.L13_1024_64          4160        25       [4, 64, 1024, 1024]   float16 
synthesis.L14_1024_3.affine    32832       -        [4, 64]               float32 
synthesis.L14_1024_3           195         1        [4, 3, 1024, 1024]    float16 
synthesis                      -           -        [4, 3, 1024, 1024]    float32 
---                            ---         ---      ---                   ---     
Total                          15093151    5600     -                     -       

Setting up PyTorch plugin "upfirdn2d_plugin"... Done.

Discriminator  Parameters  Buffers  Output shape         Datatype
---            ---         ---      ---                  ---     
b1024.fromrgb  128         16       [4, 32, 1024, 1024]  float16 
b1024.skip     2048        16       [4, 64, 512, 512]    float16 
b1024.conv0    9248        16       [4, 32, 1024, 1024]  float16 
b1024.conv1    18496       16       [4, 64, 512, 512]    float16 
b1024          -           16       [4, 64, 512, 512]    float16 
b512.skip      8192        16       [4, 128, 256, 256]   float16 
b512.conv0     36928       16       [4, 64, 512, 512]    float16 
b512.conv1     73856       16       [4, 128, 256, 256]   float16 
b512           -           16       [4, 128, 256, 256]   float16 
b256.skip      32768       16       [4, 256, 128, 128]   float16 
b256.conv0     147584      16       [4, 128, 256, 256]   float16 
b256.conv1     295168      16       [4, 256, 128, 128]   float16 
b256           -           16       [4, 256, 128, 128]   float16 
b128.skip      131072      16       [4, 512, 64, 64]     float16 
b128.conv0     590080      16       [4, 256, 128, 128]   float16 
b128.conv1     1180160     16       [4, 512, 64, 64]     float16 
b128           -           16       [4, 512, 64, 64]     float16 
b64.skip       262144      16       [4, 512, 32, 32]     float32 
b64.conv0      2359808     16       [4, 512, 64, 64]     float32 
b64.conv1      2359808     16       [4, 512, 32, 32]     float32 
b64            -           16       [4, 512, 32, 32]     float32 
b32.skip       262144      16       [4, 512, 16, 16]     float32 
b32.conv0      2359808     16       [4, 512, 32, 32]     float32 
b32.conv1      2359808     16       [4, 512, 16, 16]     float32 
b32            -           16       [4, 512, 16, 16]     float32 
b16.skip       262144      16       [4, 512, 8, 8]       float32 
b16.conv0      2359808     16       [4, 512, 16, 16]     float32 
b16.conv1      2359808     16       [4, 512, 8, 8]       float32 
b16            -           16       [4, 512, 8, 8]       float32 
b8.skip        262144      16       [4, 512, 4, 4]       float32 
b8.conv0       2359808     16       [4, 512, 8, 8]       float32 
b8.conv1       2359808     16       [4, 512, 4, 4]       float32 
b8             -           16       [4, 512, 4, 4]       float32 
b4.mbstd       -           -        [4, 513, 4, 4]       float32 
b4.conv        2364416     16       [4, 512, 4, 4]       float32 
b4.fc          4194816     -        [4, 512]             float32 
b4.out         513         -        [4, 1]               float32 
---            ---         ---      ---                  ---     
Total          29012513    544      -                    -       

Setting up augmentation...
Distributing across 1 GPUs...
Setting up training phases...
Exporting sample images...
Initializing logs...
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/tensorboard/compat/__init__.py", line 42, in tf
    from tensorboard.compat import notf  # noqa: F401
ImportError: cannot import name 'notf' from 'tensorboard.compat' (/usr/local/lib/python3.7/dist-packages/tensorboard/compat/__init__.py)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "train.py", line 286, in <module>
    main() # pylint: disable=no-value-for-parameter
  File "/usr/local/lib/python3.7/dist-packages/click/core.py", line 829, in __call__
    return self.main(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/click/core.py", line 782, in main
    rv = self.invoke(ctx)
  File "/usr/local/lib/python3.7/dist-packages/click/core.py", line 1066, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/usr/local/lib/python3.7/dist-packages/click/core.py", line 610, in invoke
    return callback(*args, **kwargs)
  File "train.py", line 281, in main
    launch_training(c=c, desc=desc, outdir=opts.outdir, dry_run=opts.dry_run)
  File "train.py", line 96, in launch_training
    subprocess_fn(rank=0, c=c, temp_dir=temp_dir)
  File "train.py", line 47, in subprocess_fn
    training_loop.training_loop(rank=rank, **c)
  File "/content/stylegan3/training/training_loop.py", line 238, in training_loop
    stats_tfevents = tensorboard.SummaryWriter(run_dir)
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/tensorboard/writer.py", line 246, in __init__
    self._get_file_writer()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/tensorboard/writer.py", line 277, in _get_file_writer
    self.log_dir, self.max_queue, self.flush_secs, self.filename_suffix
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/tensorboard/writer.py", line 76, in __init__
    log_dir, max_queue, flush_secs, filename_suffix
  File "/usr/local/lib/python3.7/dist-packages/tensorboard/summary/writer/event_file_writer.py", line 72, in __init__
    tf.io.gfile.makedirs(logdir)
  File "/usr/local/lib/python3.7/dist-packages/tensorboard/lazy.py", line 65, in __getattr__
    return getattr(load_once(self), attr_name)
  File "/usr/local/lib/python3.7/dist-packages/tensorboard/lazy.py", line 97, in wrapper
    cache[arg] = f(arg)
  File "/usr/local/lib/python3.7/dist-packages/tensorboard/lazy.py", line 50, in load_once
    module = load_fn()
  File "/usr/local/lib/python3.7/dist-packages/tensorboard/compat/__init__.py", line 45, in tf
    import tensorflow
  File "/usr/local/lib/python3.7/dist-packages/tensorflow/__init__.py", line 51, in <module>
    from ._api.v2 import compat
  File "/usr/local/lib/python3.7/dist-packages/tensorflow/_api/v2/compat/__init__.py", line 37, in <module>
    from . import v1
  File "/usr/local/lib/python3.7/dist-packages/tensorflow/_api/v2/compat/v1/__init__.py", line 30, in <module>
    from . import compat
  File "/usr/local/lib/python3.7/dist-packages/tensorflow/_api/v2/compat/v1/compat/__init__.py", line 37, in <module>
    from . import v1
  File "/usr/local/lib/python3.7/dist-packages/tensorflow/_api/v2/compat/v1/compat/v1/__init__.py", line 47, in <module>
    from tensorflow._api.v2.compat.v1 import lite
  File "/usr/local/lib/python3.7/dist-packages/tensorflow/_api/v2/compat/v1/lite/__init__.py", line 9, in <module>
    from . import experimental
  File "/usr/local/lib/python3.7/dist-packages/tensorflow/_api/v2/compat/v1/lite/experimental/__init__.py", line 8, in <module>
    from . import authoring
  File "/usr/local/lib/python3.7/dist-packages/tensorflow/_api/v2/compat/v1/lite/experimental/authoring/__init__.py", line 8, in <module>
    from tensorflow.lite.python.authoring.authoring import compatible
  File "/usr/local/lib/python3.7/dist-packages/tensorflow/lite/python/authoring/authoring.py", line 43, in <module>
    from tensorflow.lite.python import convert
  File "/usr/local/lib/python3.7/dist-packages/tensorflow/lite/python/convert.py", line 29, in <module>
    from tensorflow.lite.python import util
  File "/usr/local/lib/python3.7/dist-packages/tensorflow/lite/python/util.py", line 51, in <module>
    from jax import xla_computation as _xla_computation
  File "/usr/local/lib/python3.7/dist-packages/jax/__init__.py", line 59, in <module>
    from .core import eval_context as ensure_compile_time_eval
  File "/usr/local/lib/python3.7/dist-packages/jax/core.py", line 47, in <module>
    import jax._src.pretty_printer as pp
  File "/usr/local/lib/python3.7/dist-packages/jax/_src/pretty_printer.py", line 56, in <module>
    CAN_USE_COLOR = _can_use_color()
  File "/usr/local/lib/python3.7/dist-packages/jax/_src/pretty_printer.py", line 54, in _can_use_color
    return sys.stdout.isatty()
AttributeError: 'Logger' object has no attribute 'isatty'
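
Read bottom-up, the traceback is a chain of lazy imports: `torch.utils.tensorboard.SummaryWriter` needs tensorboard's `tf.io.gfile`; because `tensorboard.compat.notf` cannot be imported, tensorboard imports the real `tensorflow` package; TensorFlow's `tf.lite` utilities import `jax`; and JAX's `pretty_printer` calls `sys.stdout.isatty()` while it is being imported. By then `sys.stdout` is a `Logger` object (presumably the console-to-log-file tee that the training script installs) that has no `isatty` method. The minimal sketch below reproduces only that last step with stand-in classes; `TeeLogger`, `TeeLoggerWithIsatty`, `can_use_color`, and `probe` are hypothetical names, not the actual stylegan3 or JAX code.

```python
import io
import sys


class TeeLogger:
    """Hypothetical stand-in for a stdout-redirecting logger.

    Like the object in the traceback, it implements write() and flush()
    but not isatty(), so it is not a complete file-like object.
    """

    def __init__(self, stream):
        self.stream = stream

    def write(self, text):
        self.stream.write(text)

    def flush(self):
        self.stream.flush()


class TeeLoggerWithIsatty(TeeLogger):
    """Same logger, but forwards isatty() to the wrapped stream."""

    def isatty(self):
        return self.stream.isatty()


def can_use_color():
    # Mimics the import-time check in the traceback: it calls
    # sys.stdout.isatty() without guarding against objects that lack it.
    return sys.stdout.isatty()


def probe(logger_cls):
    """Swap sys.stdout for logger_cls, run the check, then restore stdout."""
    original = sys.stdout
    sys.stdout = logger_cls(io.StringIO())
    try:
        return can_use_color()
    except AttributeError as err:
        return f"AttributeError: {err}"
    finally:
        sys.stdout = original


if __name__ == "__main__":
    print(probe(TeeLogger))            # AttributeError: 'TeeLogger' object has no attribute 'isatty'
    print(probe(TeeLoggerWithIsatty))  # False (a StringIO is not a terminal)
```

The contrast shows that any replacement for `sys.stdout` that is not a full file-like object can break third-party code running at import time. Pinning an older JAX, as suggested in the comments below, presumably avoids reaching this check rather than changing the logger itself.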

Setting TensorFlow to 1.x, torch to 1.9.1, and torchvision to 0.10.1 seems to get it running:

%tensorflow_version 1.x
!pip install --upgrade torch==1.9.1+cu111 torchvision==0.10.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html

Is there a better solution?

makeitrad commented 2 years ago

Try adding this:

Uninstall the new JAX:

!pip uninstall jax jaxlib -y

GPU runtime:

!pip install "jax[cuda11_cudnn805]==0.3.10" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

CPU-only runtime (instead of the GPU install above):

!pip install "jax[cpu]==0.3.10"

Downgrade PyTorch:

!pip install torch==1.10.1+cu111 torchvision==0.11.2+cu111 -f https://download.pytorch.org/whl/torch_stable.html

Per this thread: https://github.com/googlecolab/colabtools/issues/2926#issuecomment-1185979137
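
After running these commands (and restarting the runtime if Colab prompts for it), it can help to confirm that the pinned versions are the ones actually installed before relaunching `train.py`. This is a minimal sketch, assuming Python 3.8+ for `importlib.metadata`; the Python 3.7 runtime seen in the traceback would need the `importlib-metadata` backport, which the fallback import covers. `PINS`, `base_version`, and `check_pins` are hypothetical helpers, and the pinned numbers are simply the ones from the commands above.

```python
try:
    from importlib.metadata import PackageNotFoundError, version
except ImportError:  # Python 3.7 needs the importlib-metadata backport
    from importlib_metadata import PackageNotFoundError, version

# Versions pinned by the pip commands above (distribution name -> version).
PINS = {
    "jax": "0.3.10",
    "torch": "1.10.1",
    "torchvision": "0.11.2",
}


def base_version(v):
    """Drop a local version label such as '+cu111' ('1.10.1+cu111' -> '1.10.1')."""
    return v.split("+", 1)[0]


def check_pins(pins, lookup=version):
    """Return {package: problem} for every package that does not match its pin."""
    problems = {}
    for pkg, wanted in pins.items():
        try:
            found = lookup(pkg)
        except PackageNotFoundError:
            problems[pkg] = "not installed"
            continue
        if base_version(found) != wanted:
            problems[pkg] = f"found {found}, expected {wanted}"
    return problems


if __name__ == "__main__":
    mismatches = check_pins(PINS)
    for pkg, problem in mismatches.items():
        print(f"{pkg}: {problem}")
    if not mismatches:
        print("All pins satisfied.")
```

Local version labels such as `+cu111` are ignored in the comparison, so this only checks the release number, not which CUDA build was installed.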

MuzammilAhmad commented 2 years ago

Works for me!!

Tanzman commented 2 years ago

Works for me!! Thanks.