patil-suraj / vit-vqgan

JAX implementation of ViT-VQGAN
MIT License

feat: add stylegan loss #15

Closed: borisdayma closed this 2 years ago

borisdayma commented 2 years ago

Add StyleGAN loss. Fixes #8 and #7.

TODO:

Feel free to take over and continue this PR
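The description does not spell out which loss "StyleGAN loss" refers to. As a reference point, here is a minimal JAX sketch of the standard StyleGAN2-style adversarial objective: the non-saturating logistic generator loss, the logistic discriminator loss, and an R1 gradient penalty on real images. This is an assumption about the intended technique, not the code from this PR. `toy_discriminator` and all function names are hypothetical, and the default weights are only borrowed from the `cost_stylegan` and `cost_gradient_penalty` values in the run config of this PR.

```python
import jax
import jax.numpy as jnp


def toy_discriminator(params, images):
    """Hypothetical stand-in for the real discriminator: one linear logit per image.

    images: [batch, height, width, channels] -> logits: [batch]
    """
    flat = images.reshape(images.shape[0], -1)
    return flat @ params["w"] + params["b"]


def generator_adv_loss(disc_params, fake_images):
    # Non-saturating logistic loss: softplus(-D(G(z))) == -log(sigmoid(D(G(z)))).
    logits = toy_discriminator(disc_params, fake_images)
    return jnp.mean(jax.nn.softplus(-logits))


def discriminator_loss(disc_params, real_images, fake_images):
    # Push D(real) up and D(fake) down; fakes are detached from the generator.
    real_logits = toy_discriminator(disc_params, real_images)
    fake_logits = toy_discriminator(disc_params, jax.lax.stop_gradient(fake_images))
    return jnp.mean(jax.nn.softplus(fake_logits)) + jnp.mean(jax.nn.softplus(-real_logits))


def r1_penalty(disc_params, real_images):
    # R1 gradient penalty: batch mean of ||dD(x)/dx||^2 on real images.
    # Summing logits before jax.grad is valid because each logit depends on one image only.
    def summed_logits(x):
        return jnp.sum(toy_discriminator(disc_params, x))

    grads = jax.grad(summed_logits)(real_images)
    sq_norms = jnp.sum(grads.reshape(grads.shape[0], -1) ** 2, axis=-1)
    return jnp.mean(sq_norms)


def weighted_adv_losses(disc_params, real_images, fake_images,
                        cost_stylegan=0.1, cost_gradient_penalty=1.0):
    # Defaults mirror "cost_stylegan" and "cost_gradient_penalty" from the run config.
    g_loss = cost_stylegan * generator_adv_loss(disc_params, fake_images)
    d_loss = discriminator_loss(disc_params, real_images, fake_images) + \
        cost_gradient_penalty * r1_penalty(disc_params, real_images)
    return g_loss, d_loss


# Smoke test on 2 random 4x4 RGB "images" (48 values each) with an all-zero discriminator.
real = jax.random.normal(jax.random.PRNGKey(0), (2, 4, 4, 3))
fake = jax.random.normal(jax.random.PRNGKey(1), (2, 4, 4, 3))
zero_params = {"w": jnp.zeros(48), "b": jnp.zeros(())}
g_loss, d_loss = weighted_adv_losses(zero_params, real, fake)
```

With an all-zero discriminator every logit is 0, so the generator term is `0.1 * log 2` and the discriminator term is `2 * log 2` with no penalty. Note that the StyleGAN2 paper scales R1 by gamma/2; this sketch applies the config weight directly, which is a choice, not something the thread confirms.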

borisdayma commented 2 years ago

Started a test run; we'll see if it does anything interesting: https://wandb.ai/borisd13/vit-vqgan/runs/cmc397m5?workspace=user-borisd13

borisdayma commented 2 years ago

Run: https://wandb.ai/borisd13/vit-vqgan/runs/37z6d1vi?workspace=user-borisd13

Command:

python train_vit_vqvae.py \
    --output_dir output --overwrite_output_dir \
    --train_folder ../dataset/openimages/train \
    --valid_folder ../dataset/openimages/valid \
    --config_name config/base/model \
    --disc_config_name config/base/discriminator \
    --do_eval --do_train \
    --batch_size_per_node 64 --gradient_accumulation_steps 4 \
    --num_train_epochs 20 \
    --format rgb \
    --optim distributed_shampoo --block_size 1024 --beta1 0.9 --beta2 0.99 \
    --learning_rate 0.0001 --disc_learning_rate 0.0001 \
    --logging_steps 20 --eval_steps 100 --save_steps 200
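The command combines a per-node batch of 64 with 4 gradient accumulation steps. Assuming the script follows the common convention (per-node micro-batch × accumulation steps × number of nodes), each optimizer update sees 256 images per node. The thread does not state how many nodes were used, so `num_nodes` below is a placeholder, and the constant and function names are hypothetical.

```python
BATCH_SIZE_PER_NODE = 64          # --batch_size_per_node
GRADIENT_ACCUMULATION_STEPS = 4   # --gradient_accumulation_steps


def images_per_update(num_nodes: int) -> int:
    """Images contributing to one optimizer update (assumed convention)."""
    return BATCH_SIZE_PER_NODE * GRADIENT_ACCUMULATION_STEPS * num_nodes


for nodes in (1, 2, 4):
    print(nodes, images_per_update(nodes))  # 256, 512, 1024 images per update
```

Under the same assumption, `--logging_steps 20 --eval_steps 100 --save_steps 200` are counted in optimizer updates, so a checkpoint is written after every second evaluation.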

Config:

{
  "attention_dropout": 0.0,
  "codebook_embed_dim": 4,
  "cost_e_latent": 0.25,
  "cost_l1": 0.0,
  "cost_l2": 1.0,
  "cost_lpips": 0.1,
  "cost_q_latent": 1.0,
  "cost_stylegan": 0.1,
  "cost_gradient_penalty": 1.0,
  "dropout": 0.0,
  "gradient_checkpointing": true,
  "hidden_act": "gelu",
  "hidden_size": 768,
  "image_size": 256,
  "initializer_range": 0.02,
  "intermediate_size": 2048,
  "layer_norm_eps": 1e-05,
  "ln_positions": "normformer",
  "mid_ffn_conv": true,
  "n_embed": 8192,
  "num_attention_heads": 12,
  "num_channels": 3,
  "num_hidden_layers": 12,
  "patch_size": 8,
  "post_attention_conv": false,
  "use_bias": false,
  "use_conv_patches": true,
  "use_glu": true
}
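Two things can be read off the config. With `image_size` 256 and `patch_size` 8, the encoder works on a 32 × 32 grid, i.e. 1,024 codebook tokens per image drawn from `n_embed` = 8192 entries. The `cost_*` entries look like per-term loss weights; a plausible reading (an assumption, not confirmed from the repo code) is that the generator objective is their weighted sum, with `cost_gradient_penalty` applying to the discriminator update instead. The term values below are made-up illustration data, and the helper names are hypothetical.

```python
CONFIG = {
    "image_size": 256,
    "patch_size": 8,
    "n_embed": 8192,
    "cost_l1": 0.0,
    "cost_l2": 1.0,
    "cost_lpips": 0.1,
    "cost_q_latent": 1.0,
    "cost_e_latent": 0.25,
    "cost_stylegan": 0.1,
}


def tokens_per_image(cfg: dict) -> int:
    # Non-overlapping patches: (image_size // patch_size) ** 2
    side = cfg["image_size"] // cfg["patch_size"]
    return side * side


def weighted_generator_loss(terms: dict, cfg: dict) -> float:
    # Assumed: total = sum over terms of cost_<name> * <name> loss.
    return sum(cfg["cost_" + name] * value for name, value in terms.items())


example_terms = {  # illustrative numbers only, not from the logged run
    "l1": 0.5, "l2": 0.2, "lpips": 0.3,
    "q_latent": 0.1, "e_latent": 0.1, "stylegan": 0.7,
}
print(tokens_per_image(CONFIG))  # 1024
print(weighted_generator_loss(example_terms, CONFIG))  # about 0.425
```

Under this reading, `cost_l1: 0.0` switches the L1 reconstruction term off for this run, and the 0.25 weight matches the usual VQ-VAE commitment coefficient.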

Notes:

borisdayma commented 2 years ago

Ready to be merged whenever you want

pcuenca commented 2 years ago


LGTM. What do you think, @patil-suraj?