universome / epigraf

[NeurIPS 2022] Official pytorch implementation of EpiGRAF
https://universome.github.io/epigraf

FFHQ training results #13

Open zhukaii opened 1 year ago

zhukaii commented 1 year ago

Thanks for your excellent work! I trained the model following python src/infra/launch.py hydra.run.dir=. exp_suffix= dataset=ffhq_posed dataset.resolution=512 model.training.gamma=0.1 (and also with model.discriminator.camera_cond=true model.discriminator.camera_cond_drop_p=0.5 appended) on the provided FFHQ dataset (https://disk.yandex.ru/d/UmglE8U3YVbuLg), but I cannot get reasonable results: the FID@2k value is around 20 after three days' training on 8 A100s. Is there anything wrong on my side? Thanks for your help!

universome commented 1 year ago

Hi Kai, I am sorry for such a late reply. I think you did everything right, but after 3 days of training (especially on A100s) it should have yielded better scores, so I've just launched training on my side to check that everything runs correctly. Note also that an FID@2k of 20 would translate to an FID of roughly 13, I guess (our released checkpoint with an FID of 9.87 has an FID@2k of 18.07), since computing FID on a smaller number of images (2k vs 50k) considerably worsens the score: the metric starts to see mode collapse and penalizes the model for it. We compute FID@2k instead of FID during training because it is faster.
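The small-sample bias can be illustrated with a minimal numpy sketch (my own simplification, not EpiGRAF's code: it fits diagonal-covariance Gaussians, whereas real FID uses full covariances of 2048-d Inception features). Even for a "perfect generator" whose samples come from the true distribution, the 2k-sample estimate overshoots the 50k-sample one:

```python
import numpy as np

def frechet_distance_diag(x, y):
    # Frechet distance between Gaussians fit to the samples, using
    # diagonal covariances for simplicity (real FID uses full covariances).
    mu1, mu2 = x.mean(0), y.mean(0)
    v1, v2 = x.var(0), y.var(0)
    return float(((mu1 - mu2) ** 2).sum() + (v1 + v2 - 2.0 * np.sqrt(v1 * v2)).sum())

rng = np.random.default_rng(0)
d = 64  # toy feature dimension
real = rng.normal(size=(50_000, d))

# "Fake" samples drawn from the *same* distribution as the real ones:
# the true distance is 0, so anything above 0 is pure estimation bias.
fid_2k = frechet_distance_diag(rng.normal(size=(2_000, d)), real)
fid_50k = frechet_distance_diag(rng.normal(size=(50_000, d)), real)
assert fid_2k > fid_50k  # fewer samples -> larger (worse) FID estimate
```

This is why the same checkpoint scores noticeably higher on fid2k_full than on fid50k_full.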

Could you please tell what is the training speed in your case (in terms of sec/kimg)?

zhukaii commented 1 year ago

Thank you very much for your reply! Part of my training information is as follows:

tick 12400 kimg 49996.9 time 4d 01h 05m sec/tick 27.1 sec/kimg 6.72 maintenance 0.0 cpumem 7.63 gpumem 10.97 reserved 21.76 augment 0.107

Evaluating metrics for /root/code/experiments/ffhq_512_posed_epigraf_patch_beta_p64_mins0.125_ffhq_epi-None/output ... {"results": {"fid2k_full": 20.693541789820014}, "metric": "fid2k_full", "total_time": 30.95756506919861, "total_time_str": "31s", "num_gpus": 8, "snapshot_pkl": "network-snapshot-049996.pkl", "timestamp": 1671372690.0510468}
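As a back-of-the-envelope check (my own arithmetic, not part of the repo), the per-tick and per-kimg rates in the log are mutually consistent and roughly match the reported wall-clock time:

```python
# Sanity-check the reported log: 12400 ticks at 27.1 s/tick, and
# 49996.9 kimg at 6.72 s/kimg, should both give a similar total time.
ticks, sec_per_tick = 12_400, 27.1
kimg, sec_per_kimg = 49_996.9, 6.72

days_from_ticks = ticks * sec_per_tick / 86_400  # ~3.9 days
days_from_kimg = kimg * sec_per_kimg / 86_400    # ~3.9 days
# The small gap to the logged "4d 01h 05m" is presumably evaluation and
# snapshot overhead that is not counted in sec/tick.
```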

universome commented 1 year ago

Hi @zhukaii, I've just recalled that we used generator pose conditioning (GPC) for FFHQ (we mention this in Sec 3.4), but even after training with GPC I couldn't reach our reported FID. I have no idea what the problem is, and I am sorry for the trouble you've had with our repo. I am working on it, but it is taking time... I have an old experiment from June which reached an FID@2k of 18.08 (translating to an FID of ~10) after 32.5k kimgs (4.5 days of training on 4 V100s), but after refactoring the repo I cannot get close to its FID curves. I am currently trying to find the problem in the new repo without rolling back to the old version...
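For readers unfamiliar with GPC, the general idea can be sketched as follows (an illustration only; the function name and the within-batch pose swapping are my own shorthand for the usual conditioning-dropout trick, not the repo's actual implementation):

```python
import numpy as np

def mapping_input_with_gpc(z, camera_pose, swap_prob=0.5, rng=None):
    # Generator pose conditioning (GPC), sketched: concatenate the camera
    # pose to the latent code before the mapping network. With probability
    # swap_prob the poses are shuffled within the batch, so the generator
    # cannot come to rely on the conditioning signal completely.
    rng = rng or np.random.default_rng()
    if rng.random() < swap_prob:
        camera_pose = camera_pose[rng.permutation(len(camera_pose))]
    return np.concatenate([z, camera_pose], axis=-1)

z = np.random.randn(4, 512)   # latent codes
pose = np.random.randn(4, 3)  # e.g. yaw/pitch/roll per sample
cond_input = mapping_input_with_gpc(z, pose)
# cond_input.shape == (4, 515)
```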

zhukaii commented 1 year ago

Thanks for looking into it! I also wonder whether this could be due to factors like the PyTorch version. Anyway, if there is any progress, please let me know when convenient. Thank you very much again!

universome commented 1 year ago

Hi @zhukaii, I am confident it is not a PyTorch/environment version problem but rather an issue in our codebase itself. After some digging, I found three culprits: 1) not using GPC (generator pose conditioning), 2) ADA augmentations, and 3) using positional embeddings to encode camera parameters in the discriminator. Fixing them yielded an FID@2k of 19.67 after 18M seen images (2.5 days of training on 4 A100s), which corresponds to an FID of ~11. I've updated the repo; here is the command I used to launch the training:

python src/infra/launch.py hydra.run.dir=. exp_suffix=dcondraw-gpc0.5-noaug dataset=ffhq_posed dataset.resolution=512 training.gamma=0.1 model.discriminator.camera_cond=true model.discriminator.camera_cond_raw=true num_gpus=4 training.resume=null model.generator.camera_cond=true training.augment.mode=noaug

An FID@2k of 19.67 after 18M images is still worse than I would expect, and I will need to dig a bit more, but right now our university cluster is under maintenance until January 8, so I will only be able to resume debugging after that. Again, I am sorry for the mess.

Also, note that enabling model.discriminator.camera_cond_raw=true for the Cats dataset will likely lead to flat geometry. To be honest, I am surprised by this behavior: it looks like for FFHQ one should not give the discriminator too much camera position information...
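To make the distinction concrete, here is a toy sketch of the two conditioning variants discussed above (illustrative only; fourier_embed and the frequency count are my own choices, not the repo's implementation): raw conditioning feeds the angles directly, while a positional (Fourier) embedding expands each angle into many sinusoidal features, giving the discriminator a much richer view of the camera pose.

```python
import numpy as np

def fourier_embed(angles, num_freqs=8):
    # Positional (Fourier) embedding: [sin(2^k * a), cos(2^k * a)] per angle.
    freqs = 2.0 ** np.arange(num_freqs)
    scaled = angles[..., None] * freqs                  # (B, A, F)
    emb = np.concatenate([np.sin(scaled), np.cos(scaled)], axis=-1)
    return emb.reshape(angles.shape[0], -1)             # (B, A * 2F)

cams = np.array([[0.30, -0.10], [1.20, 0.05]])  # yaw/pitch per sample
raw_cond = cams                  # camera_cond_raw=true: 2 features
pe_cond = fourier_embed(cams)    # positional embedding: 2 * 2 * 8 = 32 features
```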

zhukaii commented 1 year ago

Thank you very much for your reply! The latest training results with the provided command (3 days of training on 8 A100s) are as follows: {"results": {"fid2k_full": 17.265873922286076}, "metric": "fid2k_full", "total_time": 31.182702779769897, "total_time_str": "31s", "num_gpus": 8, "snapshot_pkl": "network-snapshot-045964.pkl", "timestamp": 1673161545.9709427}. I think this is close to the original results in the paper.

universome commented 1 year ago

Hi @zhukaii, it is not close, because that is on A100s instead of V100s :) The model should achieve an FID@2k of ~18 in under 25M seen images.

I will keep this issue open, if you do not mind, until I find a better setup. My current best result is an FID@2k of 18.56 at 23.3M seen images when launching the model via:

python src/infra/launch.py hydra.run.dir=. exp_suffix=dcondraw-gpc0.5-noaug dataset=ffhq_posed dataset.resolution=512 training.gamma=0.05 model.discriminator.camera_cond=true model.discriminator.camera_cond_raw=true num_gpus=4 training.resume=null model.generator.camera_cond=true training.augment.mode=noaug

While one could probably reach the expected result after several launches of the above command, I would still like to spend some time finding a better setup.

Also, note that you can compute FID for a checkpoint by launching:

python scripts/calc_metrics.py hydra.run.dir=. ckpt.network_pkl=/path/to/ckpt.pkl ckpt.reload_code=false img_resolution=512 metrics=fid50k_full data=./data/ffhq_512_posed.zip gpus=4 verbose=1
universome commented 1 year ago

Ok, if you reduce the batch size to 32 (likely the setup we had on V100s, since they cannot fit a batch size of 64), the model reaches an FID@2k of 18.17 at 26,400 kimg.