sony / ctm


How to reach FID 1.92 on Imagenet 64 #8

Schwartz-Zha commented 3 months ago

I downloaded the pretrained ImageNet 64 checkpoint and used the provided sampling commands (with slight modifications to make them run on my machine):

# Open MPI rank/size variables, set manually here (mpiexec normally provides these)
export OMPI_COMM_WORLD_RANK=0        # global rank of this process
export OMPI_COMM_WORLD_LOCAL_RANK=0  # rank within the local node
export OMPI_COMM_WORLD_SIZE=8        # total number of MPI processes

MODEL_FLAGS="--data_name=imagenet64 --class_cond=True --eval_interval=1000 --save_interval=1000 --num_classes=1000 --eval_batch=250 --eval_fid=True --eval_similarity=False --check_dm_performance=False --log_interval=100"

# CUDA_VISIBLE_DEVICES=0 mpiexec -n 1 python ./code/image_sample.py $MODEL_FLAGS --class_cond=True --num_classes=1000 --out_dir ./ctm-sample-paths/ctm_bs_1440/ --model_path=./ctm-runs/ctm_bs_1440/ema_0.999_006000.pt --training_mode=ctm --class_cond=True --eval_num_samples=6400 --batch_size=800 --device_id=0 --stochastic_seed=True --save_format=npz --ind_1=36 --ind_2=20 --use_MPI=True --sampler=exact --sampling_steps=1

CUDA_VISIBLE_DEVICES=0 mpiexec -n 1 python ./code/image_sample.py $MODEL_FLAGS --class_cond=True --num_classes=1000 --out_dir ./ctm-sample-paths/ctm_bs_1440_author/ --model_path=./ckpts/ema_0.999_049000.pt --training_mode=ctm --class_cond=True --eval_num_samples=6400 --batch_size=800 --device_id=0 --stochastic_seed=True --save_format=npz --ind_1=36 --ind_2=20 --use_MPI=True --sampler=exact --sampling_steps=1
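As a sanity check before evaluating, this minimal sketch prints what the sampler actually wrote, assuming image_sample.py saves a standard .npz archive of images and labels (the exact array keys may differ):

```python
# inspect_npz.py -- print the arrays stored in a generated sample archive.
# Illustrative only; the array keys depend on how image_sample.py saves its output.
import sys
import numpy as np

path = sys.argv[1]  # e.g. the .npz file under ctm-sample-paths/ctm_bs_1440_author/
with np.load(path) as data:
    for key in data.files:
        print(f"{key}: shape={data[key].shape}, dtype={data[key].dtype}")
# With --eval_num_samples=6400, expect roughly 6400 samples of shape (64, 64, 3).
```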

Then I evaluated the samples according to the provided instructions:

CUDA_VISIBLE_DEVICES=0 python code/evaluations/evaluator.py ref-statistics/VIRTUAL_imagenet64_labeled.npz ctm-sample-paths/ctm_bs_1440_author/ctm_exact_sampler_1_steps_049000_itrs_0.999_ema_/
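For context, the FID the evaluator reports is (as far as I understand it) the standard Fréchet distance between Gaussians fitted to Inception features of the reference batch and of the generated samples. A minimal sketch of that formula, not the repo's actual evaluator code, assuming the means and covariances have already been extracted:

```python
# frechet_distance.py -- standard FID formula between two feature Gaussians.
# (mu1, sigma1): statistics of the reference batch (e.g. derived from
#                VIRTUAL_imagenet64_labeled.npz);
# (mu2, sigma2): mean/covariance of Inception features of the generated samples.
import numpy as np
from scipy import linalg

def frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """FID = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrt(sigma1 @ sigma2))."""
    diff = mu1 - mu2
    covmean = linalg.sqrtm(sigma1 @ sigma2)
    if not np.isfinite(covmean).all():
        # Numerical safeguard: regularize the covariances slightly and retry.
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset) @ (sigma2 + offset))
    covmean = covmean.real  # sqrtm can return a complex result with tiny imaginary parts
    return float(diff @ diff + np.trace(sigma1 + sigma2 - 2.0 * covmean))
```

One general property worth keeping in mind: since the covariances are estimated from the samples themselves, FID computed from a few thousand images is noisier, and typically higher, than a figure computed from the 50k samples commonly used for published numbers.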

But the reported metrics are significantly worse:

Inception Score: 68.49456024169922
FID: 6.839029518808786
sFID: 22.465793478419187
Precision: 0.7965625
Recall: 0.6551

How exactly can I reproduce FID 1.92, given that I am using the pretrained model downloaded directly from the authors?