lukemelas / fixed-point-diffusion-models


Reproduction problem #1

Open weleen opened 4 months ago

weleen commented 4 months ago

Hi,

Thanks for open-sourcing the code. I tried to reproduce the results of FPDM on ImageNet, but the generated images are much worse than those from the original fast-DiT model. Could you help me check whether I am running the code correctly?

I trained the model with fixed-point layers and disabled zero_snr/v_pred; the arguments are as follows:

{
    "ckpt_every": 100000,
    "compile": false,
    "dataset_name": "imagenet256",
    "debug": false,
    "dino_supervised": false,
    "dino_supervised_dim": 768,
    "epochs": 1400,
    "feature_path": "/home/yiming/mnt_dataset/ImageNet/ILSVRC2012/vae_features/",
    "fixed_point": true,
    "fixed_point_no_grad_max_iters": 10,
    "fixed_point_no_grad_min_iters": 0,
    "fixed_point_post_depth": 1,
    "fixed_point_pre_depth": 1,
    "fixed_point_pre_post_timestep_conditioning": false,
    "fixed_point_with_grad_max_iters": 12,
    "fixed_point_with_grad_min_iters": 1,
    "flow": false,
    "global_batch_size": 512,
    "global_seed": 0,
    "image_size": 256,
    "log_every": 100,
    "log_with": "wandb",
    "lr": 0.0001,
    "model": "DiT-XL/2",
    "name": "006-DiT-XL-2--diff--0.000100--fixed_point-pre_depth-1-post_depth-1-no_grad_iters-00-10-with_grad_iters-01-12-pre_post_time_cond_False",
    "num_classes": 1000,
    "num_workers": 4,
    "output_dir": "results/runs/006-DiT-XL-2--diff--0.000100--fixed_point-pre_depth-1-post_depth-1-no_grad_iters-00-10-with_grad_iters-01-12-pre_post_time_cond_False",
    "output_subdir": "runs",
    "predict_v": false,
    "reproducibility": {
        "command_line": "python train.py --fixed_point True --fixed_point_pre_depth 1 --fixed_point_post_depth 1",
        "time": "Wed Feb 21 16:49:11 2024"
    },
    "resume": null,
    "unsupervised": false,
    "use_zero_terminal_snr": false
}
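(For context on the `fixed_point_no_grad_*` / `fixed_point_with_grad_*` arguments above: my understanding is that the solver runs some iterations without gradient tracking as a cheap warmup, then some with gradients enabled for backprop. The following is only a toy scalar sketch of that split, not the repository's actual solver; `fixed_point_solve` and `f` are hypothetical names.)

```python
import math

def fixed_point_solve(f, x0, no_grad_iters=10, with_grad_iters=12):
    """Toy sketch of a fixed-point solver with a gradient-free warmup phase.

    In the real model the first loop would presumably run under
    torch.no_grad() and the second with gradients enabled; here we just
    iterate x = f(x) on plain floats, so the split is only conceptual.
    """
    x = x0
    for _ in range(no_grad_iters):    # warmup phase (no gradient tracking)
        x = f(x)
    for _ in range(with_grad_iters):  # differentiable phase
        x = f(x)
    return x

# Example: iterating the contraction f(x) = cos(x) converges to the
# Dottie number (~0.739085), the unique fixed point of cos.
solution = fixed_point_solve(math.cos, x0=1.0)
```

With 10 + 12 iterations the residual `|f(x) - x|` is already tiny for a well-behaved contraction, which is the intuition behind capping the iteration counts in the config.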

then tested by running:

python sample.py --image_size 256 --global_seed 1 --ckpt ./results/runs/006-DiT-XL-2--diff--0.000100--fixed_point-pre_depth-1-post_depth-1-no_grad_iters-00-10-with_grad_iters-01-12-pre_post_time_cond_False/checkpoints/0400000.pt  --sample_index_end 100 --fixed_point True --fixed_point_pre_depth 1 --fixed_point_post_depth 1 --cfg_scale 4.0 --num_sampling_steps 250

the testing arguments are as follows:

{
    "adaptive": false,
    "adaptive_type": "increasing",
    "batch_size": 32,
    "cfg_scale": 4.0,
    "ckpt": "/home/yiming/project/Acceleration/fixed-point-diffusion-models/results/runs/006-DiT-XL-2--diff--0.000100--fixed_point-pre_depth-1-post_depth-1-no_grad_iters-00-10-with_grad_iters-01-12-pre_post_time_cond_False/checkpoints/0400000.pt",
    "dataset_name": "imagenet256",
    "ddim": false,
    "debug": false,
    "dino_supervised": false,
    "dino_supervised_dim": 768,
    "fixed_point": true,
    "fixed_point_iters": 26,
    "fixed_point_post_depth": 1,
    "fixed_point_pre_depth": 1,
    "fixed_point_pre_post_timestep_conditioning": false,
    "fixed_point_reuse_solution": false,
    "flow": false,
    "global_seed": 1,
    "image_size": 256,
    "iteration_controller": null,
    "model": "DiT-XL/2",
    "num_classes": 1000,
    "num_sampling_steps": 250,
    "output_dir": "samples/006-DiT-XL-2--diff--0.000100--fixed_point-pre_depth-1-post_depth-1-no_grad_iters-00-10-with_grad_iters-01-12-pre_post_time_cond_False/num_sampling_steps-250--cfg_scale-4.0--fixed_point_iters-26--fixed_point_reuse_solution-False--fixed_point_pptc-False",
    "predict_v": false,
    "reproducibility": {
        "command_line": "python sample.py --image_size 256 --global_seed 1 --ckpt /home/yiming/project/Acceleration/fixed-point-diffusion-models/results/runs/006-DiT-XL-2--diff--0.000100--fixed_point-pre_depth-1-post_depth-1-no_grad_iters-00-10-with_grad_iters-01-12-pre_post_time_cond_False/checkpoints/0400000.pt --sample_index_end 100 --fixed_point True --fixed_point_pre_depth 1 --fixed_point_post_depth 1 --cfg_scale 4.0 --num_sampling_steps 250",
        "git_has_uncommitted_changes": true,
        "git_root": "/mnt/dongxu-fs2/data-ssd/yiming/project/Acceleration/fixed-point-diffusion-models",
        "git_url": "https://github.com/lukemelas/fixed-point-diffusion-models/tree/519e1286ba27c34e177e05962c5d9e66edce31e6",
        "time": "Fri Mar 22 12:54:29 2024"
    },
    "sample_index_end": 100,
    "sample_index_start": 0,
    "unsupervised": false,
    "use_zero_terminal_snr": false,
    "vae": "mse"
}
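(A side note on `cfg_scale`: 4.0 is a fairly strong guidance weight. The standard classifier-free guidance combination, which I assume sample.py applies in some form, is sketched below; the function name is illustrative, not from the repo.)

```python
def cfg_combine(eps_uncond, eps_cond, cfg_scale):
    """Standard classifier-free guidance: push the conditional prediction
    away from the unconditional one by a factor of cfg_scale.
    cfg_scale = 1.0 recovers the purely conditional prediction."""
    return eps_uncond + cfg_scale * (eps_cond - eps_uncond)

# With cfg_scale = 4.0 (as in the command above), the gap between the
# conditional and unconditional predictions is amplified fourfold.
guided = cfg_combine(eps_uncond=0.1, eps_cond=0.3, cfg_scale=4.0)  # -> 0.9
```

A very high guidance scale can itself distort samples, so it may be worth checking lower values (e.g. 1.5) when debugging quality issues.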

Here are some of the generated samples (attached): `00003--974--geyser`, `00004--088--macaw`, `00005--979--valley`, `00006--417--balloon`.

Additionally, when can we expect the pre-trained models to be released?

Best Regards

lukemelas commented 4 months ago

Hi, thanks for your issue!

The commands look right to me — how long are you training for (and on what hardware)?

Those commands should produce a model that is slightly (but not much) worse than an equivalent DiT model when sampling for 250 steps (because the DiT model has so many more parameters, and you’re sampling for so many steps). The results you posted look strange — the balloon class for example should give balloons that look ok-ish even toward the beginning of training.

For the pretrained models, I’m sorry! I meant to release them but got caught up in other things. I’ll get on that.

Hope this helps and I’ll try to update with the pretrained models soon!

Best, Luke

weleen commented 4 months ago


Dear Luke,

Thank you for your prompt response.

I trained the model on 8 RTX 3090 GPUs for 600k steps. I also ran sample.py with a model trained via fast-DiT, and its samples look good.

The generated images are really weird. Do you have any idea what might be causing this?

Best Regards, Yiming