kohya-ss / sd-scripts

Apache License 2.0

SDXL Dreambooth OOM on 24GB VRAM #1152

Open · jellersby opened 4 months ago

jellersby commented 4 months ago

I've had good success training LoRAs, but the DreamBooth script (sdxl_train.py) runs out of VRAM on the first step. Any ideas? I feel like 24GB (7900 XTX) should be more than enough.

Resolution: 1024
Batch size: 1

./sdxl_train.py \
    --pretrained_model_name_or_path=/home/jellersby/ml/models/Stable-diffusion/sd_xl_base_1.0.safetensors \
    --vae=/home/jellersby/ml/models/VAE/sdxl_vae.safetensors \
    --no_half_vae \
    --mixed_precision=bf16 \
    --full_bf16 \
    --sdpa \
    --noise_offset=0.0357 \
    --gradient_checkpointing \
    --gradient_accumulation_steps=1 \
    --output_dir=/home/jellersby/ml/models/Stable-diffusion/train \
    --output_name=ohwx_09 \
    --save_model_as=safetensors \
    --dataset_config=/home/jellersby/ml/training/dataset.toml \
    --logging_dir=/home/jellersby/ml/training/logs \
    --log_tracker_name=ohwx_09 \
    --log_prefix=ohwx_09 \
    --cache_latents \
    --max_train_epochs=200 \
    --learning_rate=1 \
    --save_every_n_epochs=1 \
    --lr_scheduler=cosine \
    --optimizer_args='"d_coef=1" "weight_decay=0.01" "betas=0.9,0.99" "safeguard_warmup=True" "use_bias_correction=True" "decouple=True"' \
    --optimizer_type=Prodigy

CCpt5 commented 4 months ago

AFAIK you need an NVIDIA card (CUDA) to train SDXL.

jellersby commented 4 months ago

> AFAIK you need an NVIDIA card (CUDA) to train SDXL.

I've been able to train SDXL LoRAs without any problem.

kohya-ss commented 4 months ago

Please try the Adafactor optimizer with the optimizer args "scale_parameter=False" "relative_step=False" "warmup_init=False". It may use less VRAM than Prodigy.
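For anyone following along, here is a minimal sketch of how that suggestion might look as flags, replacing the `--optimizer_type=Prodigy` and `--optimizer_args=...` lines in the original command. The array name `OPTIMIZER_FLAGS` is only for illustration, and the `--learning_rate` value is an assumption rather than part of the suggestion: with `relative_step=False` Adafactor uses the learning rate you pass in, so Prodigy's `--learning_rate=1` should not be carried over as-is.

```shell
#!/usr/bin/env bash
# Sketch only: optimizer flags to swap in for the Prodigy settings.
# The learning rate below is an assumed starting point, not a tested value.
OPTIMIZER_FLAGS=(
  --optimizer_type=adafactor
  --optimizer_args "scale_parameter=False" "relative_step=False" "warmup_init=False"
  --learning_rate=4e-7
)

# Print the flags one per line so they can be checked before use.
printf '%s\n' "${OPTIMIZER_FLAGS[@]}"
```

The array can then be appended to the rest of the command, e.g. `./sdxl_train.py ... "${OPTIMIZER_FLAGS[@]}"`, with the other flags left as in the original post.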

Yo1up commented 4 months ago

> > AFAIK you need an NVIDIA card (CUDA) to train SDXL.
>
> I've been able to train SDXL LoRAs without any problem.

Yeah, LoRAs are specifically designed to use significantly less VRAM and compute, because backprop only updates small low-rank adapter matrices rather than the full model weights. The person who opened the issue has an AMD card and is attempting a full fine-tune of SDXL, which is still very finicky to get working even with the optimizations available on NVIDIA cards.
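
A rough back-of-the-envelope estimate of why a full fine-tune is so much tighter than a LoRA, written as a small shell sketch. Every input here is an assumption: roughly 2.6B parameters for the SDXL U-Net, 2 bytes per value in bf16 (as with `--full_bf16`), and about four per-parameter state tensors for a Prodigy-style optimizer. Activations, the VAE, the text encoders, and framework overhead are not counted, so real usage would be higher.

```shell
#!/usr/bin/env bash
# Rough per-parameter memory estimate in decimal GB. All inputs are assumptions.
PARAMS_MILLIONS=2600   # assumed approximate SDXL U-Net parameter count (~2.6B)
BYTES_PER_VALUE=2      # bf16

estimate_gb() {
  # $1 = assumed number of per-parameter optimizer state tensors
  local state_tensors=$1
  # One copy each for weights and gradients, plus the optimizer state tensors.
  local bytes_per_param=$(( BYTES_PER_VALUE * (2 + state_tensors) ))
  # Millions of parameters times bytes gives MB; integer-divide by 1000 for GB.
  echo $(( PARAMS_MILLIONS * bytes_per_param / 1000 ))
}

echo "Four state tensors (Prodigy-like, assumed): $(estimate_gb 4) GB"
echo "No per-parameter state (lower bound):       $(estimate_gb 0) GB"
```

Under these assumptions the first figure already exceeds 24GB before any activations are stored, while the second leaves headroom. A factored optimizer such as Adafactor sits much closer to the lower bound, which is consistent with the suggestion to switch optimizers.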