sail-sg / MDT

Masked Diffusion Transformer is the SOTA for image synthesis. (ICCV 2023)
Apache License 2.0

[request] script for generating imagenet related images #13

Closed (jS5t3r closed this issue 1 year ago)

jS5t3r commented 1 year ago

I want to generate images that are close to the ImageNet dataset using the weights you provided:

import os
from huggingface_hub import snapshot_download

# Download the released checkpoint files and build the path to the weights
models_path = snapshot_download("shgao/MDT-XL2")
ckpt_model_path = os.path.join(models_path, "mdt_xl2_v1_ckpt.pt")

Any suggestions on how to do that?

Is this the right script? https://github.com/sail-sg/MDT/blob/main/run_sample.sh

gasvn commented 1 year ago

https://github.com/sail-sg/MDT/blob/main/run_sample.sh covers both inference and evaluation. You can also modify https://github.com/sail-sg/MDT/blob/main/infer_mdt.py to run inference on your own. To judge how close the generated images are to ImageNet, you need a metric such as FID or Inception Score (IS). You can then tune the cfg (classifier-free guidance) setting until that metric is as good as possible.
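The tuning loop described above could look roughly like the sketch below. It is not taken from the MDT repository: `pick_best_cfg` and `toy_fid` are hypothetical names, and `toy_fid` is only a stand-in with an arbitrary minimum so the example runs. In practice the `evaluate` callable would generate images at the given cfg value (for example with a modified infer_mdt.py) and return the FID against ImageNet reference statistics.

```python
from typing import Callable, Iterable, Tuple


def pick_best_cfg(
    candidates: Iterable[float],
    evaluate: Callable[[float], float],
    lower_is_better: bool = True,
) -> Tuple[float, float]:
    """Score every candidate cfg value and return (best_cfg, best_score).

    Use lower_is_better=True for FID and lower_is_better=False for IS.
    """
    scored = [(cfg, evaluate(cfg)) for cfg in candidates]
    if not scored:
        raise ValueError("candidates must not be empty")

    def score_of(pair: Tuple[float, float]) -> float:
        return pair[1]

    if lower_is_better:
        return min(scored, key=score_of)
    return max(scored, key=score_of)


def toy_fid(cfg: float) -> float:
    # Hypothetical stand-in, NOT a real FID computation and NOT a
    # recommended setting: a simple curve with an arbitrary minimum at 3.0.
    # Replace with: sample images at this cfg, then compute FID.
    return (cfg - 3.0) ** 2 + 2.0


if __name__ == "__main__":
    best_cfg, best_score = pick_best_cfg([1.0, 2.0, 3.0, 4.0, 5.0], toy_fid)
    print(f"best cfg: {best_cfg}, score: {best_score}")
```

Because each evaluation means sampling a large batch of images, a coarse grid of a few cfg values is usually a sensible starting point before refining around the best one.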