Open keenhon opened 2 years ago
Sorry, but this code is based on PyTorch. You'd have to change the entire codebase to run it on TensorFlow, be it 1.x or 2.x.
Yes, it doesn't run on TensorFlow. But I think he meant that it throws an error on Colab when training StyleGAN3; see https://github.com/googlecolab/colabtools/issues/2914
I've created a really simple Colab that reproduces the error (on a small dataset)
https://colab.research.google.com/drive/1u7FUYptdQvhq0FFjxV1xWBxk4bfdy59Q
When running:
!python train.py --outdir=./results --cfg=stylegan3-r --data=/content/circles-1024x1024.zip \
--gpus=1 --batch=32 --batch-gpu=4 --gamma=6.6 --mirror=1 --kimg=1 --snap=5 --metrics=None \
--resume=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-ffhqu-1024x1024.pkl
Output is:
Training options:
{
"G_kwargs": {
"class_name": "training.networks_stylegan3.Generator",
"z_dim": 512,
"w_dim": 512,
"mapping_kwargs": {
"num_layers": 2
},
"channel_base": 65536,
"channel_max": 1024,
"magnitude_ema_beta": 0.9988915792636801,
"conv_kernel": 1,
"use_radial_filters": true
},
"D_kwargs": {
"class_name": "training.networks_stylegan2.Discriminator",
"block_kwargs": {
"freeze_layers": 0
},
"mapping_kwargs": {},
"epilogue_kwargs": {
"mbstd_group_size": 4
},
"channel_base": 32768,
"channel_max": 512
},
"G_opt_kwargs": {
"class_name": "torch.optim.Adam",
"betas": [
0,
0.99
],
"eps": 1e-08,
"lr": 0.0025
},
"D_opt_kwargs": {
"class_name": "torch.optim.Adam",
"betas": [
0,
0.99
],
"eps": 1e-08,
"lr": 0.002
},
"loss_kwargs": {
"class_name": "training.loss.StyleGAN2Loss",
"r1_gamma": 6.6,
"blur_init_sigma": 0,
"blur_fade_kimg": 200.0
},
"data_loader_kwargs": {
"pin_memory": true,
"prefetch_factor": 2,
"num_workers": 3
},
"training_set_kwargs": {
"class_name": "training.dataset.ImageFolderDataset",
"path": "/content/circles-1024x1024.zip",
"use_labels": false,
"max_size": 248,
"xflip": true,
"resolution": 1024,
"random_seed": 0
},
"num_gpus": 1,
"batch_size": 32,
"batch_gpu": 4,
"metrics": [],
"total_kimg": 1,
"kimg_per_tick": 4,
"image_snapshot_ticks": 5,
"network_snapshot_ticks": 5,
"random_seed": 0,
"ema_kimg": 10.0,
"augment_kwargs": {
"class_name": "training.augment.AugmentPipe",
"xflip": 1,
"rotate90": 1,
"xint": 1,
"scale": 1,
"rotate": 1,
"aniso": 1,
"xfrac": 1,
"brightness": 1,
"contrast": 1,
"lumaflip": 1,
"hue": 1,
"saturation": 1
},
"ada_target": 0.6,
"resume_pkl": "https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-ffhqu-1024x1024.pkl",
"ada_kimg": 100,
"ema_rampup": null,
"run_dir": "./results/00001-stylegan3-r-circles-1024x1024-gpus1-batch32-gamma6.6"
}
Output directory: ./results/00001-stylegan3-r-circles-1024x1024-gpus1-batch32-gamma6.6
Number of GPUs: 1
Batch size: 32 images
Training duration: 1 kimg
Dataset path: /content/circles-1024x1024.zip
Dataset size: 248 images
Dataset resolution: 1024
Dataset labels: False
Dataset x-flips: True
Creating output directory...
Launching processes...
Loading training set...
Num images: 496
Image shape: [3, 1024, 1024]
Label shape: [0]
Constructing networks...
Resuming from "https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-ffhqu-1024x1024.pkl"
Setting up PyTorch plugin "bias_act_plugin"... Done.
Setting up PyTorch plugin "filtered_lrelu_plugin"... Done.
Generator Parameters Buffers Output shape Datatype
--- --- --- --- ---
mapping.fc0 262656 - [4, 512] float32
mapping.fc1 262656 - [4, 512] float32
mapping - 512 [4, 16, 512] float32
synthesis.input.affine 2052 - [4, 4] float32
synthesis.input 1048576 3081 [4, 1024, 36, 36] float32
synthesis.L0_36_1024.affine 525312 - [4, 1024] float32
synthesis.L0_36_1024 1049600 157 [4, 1024, 36, 36] float32
synthesis.L1_36_1024.affine 525312 - [4, 1024] float32
synthesis.L1_36_1024 1049600 157 [4, 1024, 36, 36] float32
synthesis.L2_52_1024.affine 525312 - [4, 1024] float32
synthesis.L2_52_1024 1049600 169 [4, 1024, 52, 52] float32
synthesis.L3_52_1024.affine 525312 - [4, 1024] float32
synthesis.L3_52_1024 1049600 157 [4, 1024, 52, 52] float32
synthesis.L4_84_1024.affine 525312 - [4, 1024] float32
synthesis.L4_84_1024 1049600 169 [4, 1024, 84, 84] float32
synthesis.L5_148_1024.affine 525312 - [4, 1024] float32
synthesis.L5_148_1024 1049600 169 [4, 1024, 148, 148] float16
synthesis.L6_148_1024.affine 525312 - [4, 1024] float32
synthesis.L6_148_1024 1049600 157 [4, 1024, 148, 148] float16
synthesis.L7_276_645.affine 525312 - [4, 1024] float32
synthesis.L7_276_645 661125 169 [4, 645, 276, 276] float16
synthesis.L8_276_406.affine 330885 - [4, 645] float32
synthesis.L8_276_406 262276 157 [4, 406, 276, 276] float16
synthesis.L9_532_256.affine 208278 - [4, 406] float32
synthesis.L9_532_256 104192 169 [4, 256, 532, 532] float16
synthesis.L10_1044_161.affine 131328 - [4, 256] float32
synthesis.L10_1044_161 41377 169 [4, 161, 1044, 1044] float16
synthesis.L11_1044_102.affine 82593 - [4, 161] float32
synthesis.L11_1044_102 16524 157 [4, 102, 1044, 1044] float16
synthesis.L12_1044_64.affine 52326 - [4, 102] float32
synthesis.L12_1044_64 6592 25 [4, 64, 1044, 1044] float16
synthesis.L13_1024_64.affine 32832 - [4, 64] float32
synthesis.L13_1024_64 4160 25 [4, 64, 1024, 1024] float16
synthesis.L14_1024_3.affine 32832 - [4, 64] float32
synthesis.L14_1024_3 195 1 [4, 3, 1024, 1024] float16
synthesis - - [4, 3, 1024, 1024] float32
--- --- --- --- ---
Total 15093151 5600 - -
Setting up PyTorch plugin "upfirdn2d_plugin"... Done.
Discriminator Parameters Buffers Output shape Datatype
--- --- --- --- ---
b1024.fromrgb 128 16 [4, 32, 1024, 1024] float16
b1024.skip 2048 16 [4, 64, 512, 512] float16
b1024.conv0 9248 16 [4, 32, 1024, 1024] float16
b1024.conv1 18496 16 [4, 64, 512, 512] float16
b1024 - 16 [4, 64, 512, 512] float16
b512.skip 8192 16 [4, 128, 256, 256] float16
b512.conv0 36928 16 [4, 64, 512, 512] float16
b512.conv1 73856 16 [4, 128, 256, 256] float16
b512 - 16 [4, 128, 256, 256] float16
b256.skip 32768 16 [4, 256, 128, 128] float16
b256.conv0 147584 16 [4, 128, 256, 256] float16
b256.conv1 295168 16 [4, 256, 128, 128] float16
b256 - 16 [4, 256, 128, 128] float16
b128.skip 131072 16 [4, 512, 64, 64] float16
b128.conv0 590080 16 [4, 256, 128, 128] float16
b128.conv1 1180160 16 [4, 512, 64, 64] float16
b128 - 16 [4, 512, 64, 64] float16
b64.skip 262144 16 [4, 512, 32, 32] float32
b64.conv0 2359808 16 [4, 512, 64, 64] float32
b64.conv1 2359808 16 [4, 512, 32, 32] float32
b64 - 16 [4, 512, 32, 32] float32
b32.skip 262144 16 [4, 512, 16, 16] float32
b32.conv0 2359808 16 [4, 512, 32, 32] float32
b32.conv1 2359808 16 [4, 512, 16, 16] float32
b32 - 16 [4, 512, 16, 16] float32
b16.skip 262144 16 [4, 512, 8, 8] float32
b16.conv0 2359808 16 [4, 512, 16, 16] float32
b16.conv1 2359808 16 [4, 512, 8, 8] float32
b16 - 16 [4, 512, 8, 8] float32
b8.skip 262144 16 [4, 512, 4, 4] float32
b8.conv0 2359808 16 [4, 512, 8, 8] float32
b8.conv1 2359808 16 [4, 512, 4, 4] float32
b8 - 16 [4, 512, 4, 4] float32
b4.mbstd - - [4, 513, 4, 4] float32
b4.conv 2364416 16 [4, 512, 4, 4] float32
b4.fc 4194816 - [4, 512] float32
b4.out 513 - [4, 1] float32
--- --- --- --- ---
Total 29012513 544 - -
Setting up augmentation...
Distributing across 1 GPUs...
Setting up training phases...
Exporting sample images...
Initializing logs...
Traceback (most recent call last):
File "/usr/local/lib/python3.7/dist-packages/tensorboard/compat/__init__.py", line 42, in tf
from tensorboard.compat import notf # noqa: F401
ImportError: cannot import name 'notf' from 'tensorboard.compat' (/usr/local/lib/python3.7/dist-packages/tensorboard/compat/__init__.py)
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "train.py", line 286, in <module>
main() # pylint: disable=no-value-for-parameter
File "/usr/local/lib/python3.7/dist-packages/click/core.py", line 829, in __call__
return self.main(*args, **kwargs)
File "/usr/local/lib/python3.7/dist-packages/click/core.py", line 782, in main
rv = self.invoke(ctx)
File "/usr/local/lib/python3.7/dist-packages/click/core.py", line 1066, in invoke
return ctx.invoke(self.callback, **ctx.params)
File "/usr/local/lib/python3.7/dist-packages/click/core.py", line 610, in invoke
return callback(*args, **kwargs)
File "train.py", line 281, in main
launch_training(c=c, desc=desc, outdir=opts.outdir, dry_run=opts.dry_run)
File "train.py", line 96, in launch_training
subprocess_fn(rank=0, c=c, temp_dir=temp_dir)
File "train.py", line 47, in subprocess_fn
training_loop.training_loop(rank=rank, **c)
File "/content/stylegan3/training/training_loop.py", line 238, in training_loop
stats_tfevents = tensorboard.SummaryWriter(run_dir)
File "/usr/local/lib/python3.7/dist-packages/torch/utils/tensorboard/writer.py", line 246, in __init__
self._get_file_writer()
File "/usr/local/lib/python3.7/dist-packages/torch/utils/tensorboard/writer.py", line 277, in _get_file_writer
self.log_dir, self.max_queue, self.flush_secs, self.filename_suffix
File "/usr/local/lib/python3.7/dist-packages/torch/utils/tensorboard/writer.py", line 76, in __init__
log_dir, max_queue, flush_secs, filename_suffix
File "/usr/local/lib/python3.7/dist-packages/tensorboard/summary/writer/event_file_writer.py", line 72, in __init__
tf.io.gfile.makedirs(logdir)
File "/usr/local/lib/python3.7/dist-packages/tensorboard/lazy.py", line 65, in __getattr__
return getattr(load_once(self), attr_name)
File "/usr/local/lib/python3.7/dist-packages/tensorboard/lazy.py", line 97, in wrapper
cache[arg] = f(arg)
File "/usr/local/lib/python3.7/dist-packages/tensorboard/lazy.py", line 50, in load_once
module = load_fn()
File "/usr/local/lib/python3.7/dist-packages/tensorboard/compat/__init__.py", line 45, in tf
import tensorflow
File "/usr/local/lib/python3.7/dist-packages/tensorflow/__init__.py", line 51, in <module>
from ._api.v2 import compat
File "/usr/local/lib/python3.7/dist-packages/tensorflow/_api/v2/compat/__init__.py", line 37, in <module>
from . import v1
File "/usr/local/lib/python3.7/dist-packages/tensorflow/_api/v2/compat/v1/__init__.py", line 30, in <module>
from . import compat
File "/usr/local/lib/python3.7/dist-packages/tensorflow/_api/v2/compat/v1/compat/__init__.py", line 37, in <module>
from . import v1
File "/usr/local/lib/python3.7/dist-packages/tensorflow/_api/v2/compat/v1/compat/v1/__init__.py", line 47, in <module>
from tensorflow._api.v2.compat.v1 import lite
File "/usr/local/lib/python3.7/dist-packages/tensorflow/_api/v2/compat/v1/lite/__init__.py", line 9, in <module>
from . import experimental
File "/usr/local/lib/python3.7/dist-packages/tensorflow/_api/v2/compat/v1/lite/experimental/__init__.py", line 8, in <module>
from . import authoring
File "/usr/local/lib/python3.7/dist-packages/tensorflow/_api/v2/compat/v1/lite/experimental/authoring/__init__.py", line 8, in <module>
from tensorflow.lite.python.authoring.authoring import compatible
File "/usr/local/lib/python3.7/dist-packages/tensorflow/lite/python/authoring/authoring.py", line 43, in <module>
from tensorflow.lite.python import convert
File "/usr/local/lib/python3.7/dist-packages/tensorflow/lite/python/convert.py", line 29, in <module>
from tensorflow.lite.python import util
File "/usr/local/lib/python3.7/dist-packages/tensorflow/lite/python/util.py", line 51, in <module>
from jax import xla_computation as _xla_computation
File "/usr/local/lib/python3.7/dist-packages/jax/__init__.py", line 59, in <module>
from .core import eval_context as ensure_compile_time_eval
File "/usr/local/lib/python3.7/dist-packages/jax/core.py", line 47, in <module>
import jax._src.pretty_printer as pp
File "/usr/local/lib/python3.7/dist-packages/jax/_src/pretty_printer.py", line 56, in <module>
CAN_USE_COLOR = _can_use_color()
File "/usr/local/lib/python3.7/dist-packages/jax/_src/pretty_printer.py", line 54, in _can_use_color
return sys.stdout.isatty()
AttributeError: 'Logger' object has no attribute 'isatty'
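The last frame suggests that, by the time JAX is imported, `sys.stdout` has been replaced by a `Logger` object that forwards writes but does not implement `isatty()`. Below is a minimal sketch of that failure mode and one possible way around it. The class names `TeeLogger` and `TtyAwareLogger` and the helper `can_use_color` are hypothetical stand-ins for illustration, not the actual StyleGAN3 or JAX code.

```python
import io


class TeeLogger:
    """Hypothetical stand-in for a logger that replaces sys.stdout.

    It forwards write() and flush() but, like the object in the
    traceback, has no isatty() method.
    """

    def __init__(self, stream):
        self.stream = stream

    def write(self, text):
        return self.stream.write(text)

    def flush(self):
        self.stream.flush()


class TtyAwareLogger(TeeLogger):
    """Same logger, with isatty() delegated to the wrapped stream."""

    def isatty(self):
        inner = getattr(self.stream, "isatty", None)
        # Report "not a terminal" if the wrapped stream cannot tell us.
        return bool(inner()) if callable(inner) else False


def can_use_color(stream):
    # Mirrors the failing call in the traceback: sys.stdout.isatty()
    return stream.isatty()


if __name__ == "__main__":
    buffer = io.StringIO()
    try:
        can_use_color(TeeLogger(buffer))
    except AttributeError as err:
        print("fails like the traceback:", err)
    print("with isatty():", can_use_color(TtyAwareLogger(buffer)))
```

If this reading is right, any library that probes `sys.stdout.isatty()` at import time would fail the same way, which would explain why pinning or removing the newer JAX avoids the error.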
Setting tensorflow to 1.x, torch to 1.9.1 and torchvision to 0.10.1 seems to get it running.
%tensorflow_version 1.x
!pip install --upgrade torch==1.9.1+cu111 torchvision==0.10.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html
Is there a better solution?
Try adding this:
# Uninstall new JAX
!pip uninstall jax jaxlib -y
# GPU frontend
!pip install "jax[cuda11_cudnn805]==0.3.10" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# CPU frontend
#!pip install jax[cpu]==0.3.10
# Downgrade PyTorch
!pip install torch==1.10.1+cu111 torchvision==0.11.2+cu111 -f https://download.pytorch.org/whl/torch_stable.html
per this thread: https://github.com/googlecolab/colabtools/issues/2926#issuecomment-1185979137
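After applying version pins like the ones above, it can help to confirm which versions the runtime actually ended up with before starting a long training run. The following is a small sketch, assuming Python 3.8+ (or the `importlib_metadata` backport on 3.7); the helper name `installed_versions` is hypothetical.

```python
try:
    from importlib import metadata  # standard library on Python 3.8+
except ImportError:
    # Assumption: the backport package is installed on Python 3.7.
    import importlib_metadata as metadata


def installed_versions(names):
    """Map each distribution name to its installed version, or None."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    packages = ["torch", "torchvision", "jax", "jaxlib", "tensorflow"]
    for name, version in installed_versions(packages).items():
        print(f"{name}: {version or 'not installed'}")
```

In a notebook, the runtime may need to be restarted after the `pip` commands before the new versions are picked up by `import`.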
Works for Me!!
Works for Me!! thanks
Describe the bug: Doesn't seem to run on TensorFlow 2.x
To Reproduce: Run on TensorFlow 2.x