CiaoHe / DiM

Diffusion Image Mamba

Install

Install Mamba-ssm

Refer to the Vision Mamba repository for installation details.

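# Install causal_conv1d (version 1.1.0 or later) and mamba from their local source directories.
# Note: an unquoted ">=1.1.0" is parsed by the shell as an output redirect, so it is omitted here.
pip install -e causal_conv1d
pip install -e mamba-1p1p1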

Train

# Point HF_HOME at your Hugging Face cache directory (replace the path with your own)
export HF_HOME="/comp_robot/rentianhe/caohe/cache"

# Train with 4 GPUs
accelerate launch --multi_gpu --num_processes 4 --mixed_precision fp16 \
    train_mamba.py --model DiM-S/2 \
    --feature-path /shared_space/caohe/DATA/imagenet1k/train_vae \
    --lr 5e-4

Checkpoints are saved under results/xxx-MODEL_VERSION/checkpoints/
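
To pick up the most recent checkpoint from a run, a small shell helper can be useful. The sketch below is not part of the repository: `latest_ckpt` is a hypothetical name, and it assumes checkpoint files are named with zero-padded step counts (such as 0100000.pt, as in the sampling command), so that lexicographic order matches training order.

```shell
# Hypothetical helper (not part of the repository): print the newest
# checkpoint file inside a run directory.
# Assumes zero-padded names such as 0100000.pt, so a plain sort
# puts the latest checkpoint last.
latest_ckpt() {
  run_dir="$1"   # e.g. results/005-DiM-S-2
  for f in "$run_dir"/checkpoints/*.pt; do
    # Skip the unexpanded glob when the directory has no .pt files
    [ -e "$f" ] && printf '%s\n' "$f"
  done | sort | tail -n 1
}
```

For example, `latest_ckpt results/005-DiM-S-2` prints the path of the highest-numbered .pt file in that run, or nothing if no checkpoint exists yet.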

Evaluation

Sampling scripts

# Point HF_HOME at your Hugging Face cache directory (replace the path with your own)
export HF_HOME="/comp_robot/rentianhe/caohe/cache"
# Select the run and checkpoint to sample from
MODEL_VERSION=005-DiM-S-2 # The model version
CKPT=0100000 # The checkpoint number
torchrun --nnodes=1 --nproc_per_node=4 sample_ddp.py \
    --model DiM-S/2 \
    --num-fid-samples 50000 \
    --ckpt results/$MODEL_VERSION/checkpoints/$CKPT.pt \
    --sample-dir samples/$MODEL_VERSION-$CKPT \
    --per-proc-batch-size 64
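
Since sampling 50000 images across 4 processes takes a while, it may be worth checking that the checkpoint exists before launching. The following is a minimal sketch, not part of the repository: `ckpt_path` and `require_ckpt` are hypothetical helper names, and the path layout is simply the one used by the --ckpt argument above.

```shell
# Hypothetical pre-flight check (not part of the repository).
# Builds the checkpoint path the same way as the sampling command.
ckpt_path() {
  # $1: model version (e.g. 005-DiM-S-2), $2: checkpoint number (e.g. 0100000)
  printf 'results/%s/checkpoints/%s.pt\n' "$1" "$2"
}

# Prints the checkpoint path if the file exists; otherwise reports
# the missing path on stderr and returns a non-zero status.
require_ckpt() {
  path="$(ckpt_path "$1" "$2")"
  if [ ! -f "$path" ]; then
    echo "checkpoint not found: $path" >&2
    return 1
  fi
  printf '%s\n' "$path"
}
```

One way to use it is to run `CKPT_PATH=$(require_ckpt "$MODEL_VERSION" "$CKPT")` first and start the sampling command only if that succeeds.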

FID evaluation

First, install the required packages by following the instructions in ADM's TensorFlow evaluation suite.

IN_50K_REFERENCE=/shared_space/caohe/DATA/imagenet1k/VIRTUAL_imagenet256_labeled.npz
# Adapt the second argument to the path of your own sampled .npz file
python evaluations/evaluator.py $IN_50K_REFERENCE samples/DiM-B-2-ckpt-0100000/DiM-B-2-0100000-size-256-vae-ema-cfg-1.5-seed-0.npz