
[NeurIPS 2022] Official PyTorch implementation of EpiGRAF
https://universome.github.io/epigraf

EpiGRAF: Rethinking training of 3D GANs

[website] [paper] [arxiv]

Generation examples for EpiGRAF

Code release progress:

Limitations / known problems:

Please create an issue if you find any problems or bugs, or have any questions about this repo.

Checkpoints (do not forget to update the repo to the latest version before using them):

Installation

To install and activate the environment, run the following commands:

conda env create -f environment.yml -p env
conda activate ./env

This repo is built on top of StyleGAN3, so make sure that it runs on your system.

Sometimes, it fails with the error:

AttributeError: module 'distutils' has no attribute 'version'

in which case you will need to install an older version of setuptools:

pip install setuptools==59.5.0

Evaluation

Download the checkpoint above and save it into checkpoints/model.pkl. To generate the videos, run:

python scripts/inference.py hydra.run.dir=. ckpt.network_pkl=$(eval pwd)/checkpoints/model.pkl vis=video_grid camera=front_circle output_dir=results num_seeds=9

You can control the sampling resolution via the img_resolution argument.

To compute FID against the /path/to/dataset.zip dataset, run:

python scripts/calc_metrics.py hydra.run.dir=. ckpt.network_pkl=$(eval pwd)/checkpoints/model.pkl ckpt.reload_code=false img_resolution=512 metrics=fid50k_full data=/path/to/dataset.zip gpus=4 verbose=true

Data

Real data

For FFHQ and Cats, we use the camera poses provided by GRAM; you can download them via the links they provide. For Cats, we used exactly the same dataset as GRAM and also upload it here (together with our pre-processed camera poses). For FFHQ, in contrast to some previous works (e.g., EG3D or GRAM), we do not re-crop it and use the original version (but with the camera poses that GRAM provides for the cropped version).

Megascans

We provide the links to the Megascans datasets, as well as the rendering code and documentation on how to use it, in a separate repo. We also prepared a script for easier downloading of the Megascans datasets; you can download them via:

python scripts/data_scripts/download_megascans.py food /my/output/dir/
python scripts/data_scripts/download_megascans.py plants /my/output/dir/

How to pre-process the datasets

Data should be stored in a zip archive; the exact structure is not important, since the script will use all the images it finds. Put your datasets into the data/ directory. If you want to train with camera pose conditioning (either in the generator or the discriminator), create a dataset.json with a camera_angles dict of "<FILE_NAME>": [yaw, pitch, roll] key/value pairs, and pass the model.discriminator.camera_cond=true model.discriminator.camera_cond_drop_p=0.5 command line arguments (or simply override them in the config). If you want to train on a custom dataset, create a config for it, e.g. configs/dataset/my_dataset.yaml, specifying the necessary parameters (see the other configs to get an idea of what should be specified).
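For illustration, here is a minimal packing sketch (not one of the official scripts) that writes such a dataset.json and zips it together with the images. The file names and angle values are hypothetical placeholders, and the angles follow the [yaw, pitch, roll] convention described above.

# Minimal sketch for packing a custom dataset (hypothetical paths and values).
import json
import zipfile
from pathlib import Path

image_dir = Path('/path/to/images')  # directory with your images

# Hypothetical camera angles per image; replace with your real annotations.
camera_angles = {
    'img_0000.png': [0.12, -0.05, 0.0],   # [yaw, pitch, roll]
    'img_0001.png': [-0.31, 0.02, 0.0],
}

with zipfile.ZipFile('data/my_dataset.zip', 'w') as zf:
    for file_name in camera_angles:
        zf.write(image_dir / file_name, arcname=file_name)
    # For class-conditional training, dataset.json also carries the class
    # annotations (see the Training section); only camera_angles is shown here.
    zf.writestr('dataset.json', json.dumps({'camera_angles': camera_angles}))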

Training

Commands

To launch training, run:

python src/infra/launch.py hydra.run.dir=. desc=<EXPERIMENT_NAME> dataset=<DATASET_NAME> dataset.resolution=<DATASET_RESOLUTION>  model.training.gamma=0.1 training.resume=null

To continue training, launch:

python src/infra/launch.py hydra.run.dir=. experiment_dir=<PATH_TO_EXPERIMENT> training.resume=latest

For Megascans Plants, we used class labels (for all the models). To enable class-conditional training, use the training.use_labels=true command line argument (the class annotations are located in dataset.json):

python src/infra/launch.py hydra.run.dir=. desc=default dataset=megascans_plants dataset.resolution=256  training.gamma=0.05 training.resume=null training.use_labels=true

Tips and tricks

Training on a cluster or with slurm

If you use slurm or train on a cluster, you might be interested in our cluster training infrastructure. We leave our A100 cluster config in configs/env/raven.yaml as an example of how to structure the environment config for your own setup. In principle, we provide two ways to train: locally and on a cluster via slurm (by passing slurm=true when launching training). By default, the simple local environment is used, but you can switch to your own by specifying the env=my_env argument (after creating the my_env.yaml config in configs/env).

Evaluation

At train time, we compute FID only on 2,048 fake images (versus all the available real images), since generating 50,000 images takes too long. To compute FID for 50k fake images after the training is done, run:

python scripts/calc_metrics.py hydra.run.dir=. ckpt.network_pkl=<CKPT_PATH> data=<PATH_TO_DATA> mirror=true gpus=4 metrics=fid50k_full img_resolution=<IMG_RESOLUTION>

If you have several checkpoints for the same experiment, you can alternatively pass ckpt.networks_dir=<CKPTS_DIR> instead of ckpt.network_pkl=<CKPT_PATH>. In this case, the script will find the best checkpoint out of all the available ones (measured by FID@2k) and compute the metrics for it.

Inference and visualization

Doing visualizations for a 3D GAN paper is pretty tedious, and we tried to structure/simplify this process as much as we could. We created a script that runs the necessary visualization types, where each visualization is defined by its own config. Below, we provide several visualization types; the rest can be found in scripts/inference.py. Everywhere we use a direct path to a checkpoint via ckpt.network_pkl, but it is often easier to pass ckpt.networks_dir, which should point to a directory with the checkpoints of your experiment; the script will then take the best checkpoint based on the fid2k_full metric. You can combine different visualization types (located in configs/scripts/vis) with different camera paths (located in configs/scripts/camera).

Please see configs/scripts/inference.yaml for the available parameters and what they influence.

Main grid visualization

It's the visualization type we used for the teaser (as an image).

python scripts/inference.py hydra.run.dir=. ckpt.network_pkl=<CKPT_PATH> vis=front_grid camera=points output_dir=<OUTPUT_DIR> num_seeds=16 truncation_psi=0.7

A "zoom-in/out-and-fly-around" video

It's the visualization type we used for the teaser (as a video).

python scripts/inference.py hydra.run.dir=. ckpt.network_pkl=<CKPT_PATH> vis=video camera=front_circle output_dir=<OUTPUT_DIR> num_seeds=16 truncation_psi=0.7

Geometry extraction

You can also extract MRC volumes from the generator by running:

python scripts/extract_geometry.py hydra.run.dir=. ckpt.network_pkl=<PATH_TO_NETWORK_PKL> num_seeds=<NUM_SHAPES_TO_GENERATE> volume_res=256 save_mrc=true cube_size=<CUBE_SIZE_VALUE> output_dir=shapes

CUBE_SIZE_VALUE depends on your hyperparameters and should be somewhere in the [0.5, 1.0] range. You can then visualize the extracted volumes with ChimeraX. You can also extract PLY/OBJ shapes by setting save_ply=true and/or save_obj=true, respectively; in that case, you might need to tweak the thresh_value parameter for marching cubes.
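If you would rather run marching cubes yourself (e.g., to experiment with different thresholds), here is a minimal standalone sketch, not part of this repo, that loads a saved MRC volume and exports a mesh. It assumes the mrcfile, scikit-image, and trimesh packages are installed; the volume path and threshold value below are hypothetical placeholders you will need to adjust.

# Minimal sketch: mesh a saved .mrc density volume at a chosen threshold.
# Assumes mrcfile, scikit-image, and trimesh are installed.
import mrcfile
import trimesh
from skimage import measure

volume_path = 'shapes/0.mrc'   # hypothetical path to one of the extracted volumes
thresh_value = 10.0            # density threshold; tune it for your checkpoint

with mrcfile.open(volume_path) as mrc:
    volume = mrc.data.copy()   # the raw density grid

# Marching cubes gives vertices, faces, and per-vertex normals.
verts, faces, normals, _ = measure.marching_cubes(volume, level=thresh_value)
trimesh.Trimesh(vertices=verts, faces=faces, vertex_normals=normals).export('shape.ply')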

Reporting bugs and issues

If something does not work as expected, please create an issue or email iskorokhodov@gmail.com.

License

This repo is built on top of StyleGAN3 and INR-GAN, so it is likely to be covered by the NVIDIA license (though we are not sure to what extent).

Bibtex

@article{epigraf,
    title={EpiGRAF: Rethinking training of 3D GANs},
    author={Skorokhodov, Ivan and Tulyakov, Sergey and Wang, Yiqun and Wonka, Peter},
    journal={arXiv preprint arXiv:2206.10535},
    year={2022},
}