kohya-ss / sd-scripts

OOM using AdamW8bit since recent update #1700

Open · Tophness opened this issue 1 week ago

Tophness commented 1 week ago

I'm getting OOM errors using AdamW8bit on my 16GB 4080 with 32GB of system RAM now. The training normally doesn't fit in its VRAM anyway, but on commit 0005867ba509d2e1a5674b267e8286b561c0ed71 it was still able to train despite spilling over into shared memory, and the difference was night and day compared to Adafactor, converging faster with much better results. Is there any way to prevent it from using enough memory to OOM?

accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_network.py \
  --pretrained_model_name_or_path "flux1-dev-fp8.safetensors" --clip_l "clip_l.safetensors" \
  --t5xxl "t5xxl_fp16.safetensors" --ae "ae.safetensors" \
  --cache_latents_to_disk --save_model_as safetensors --sdpa \
  --persistent_data_loader_workers --max_data_loader_n_workers 2 --seed 42 \
  --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 \
  --network_module networks.lora_flux --network_dim 16 \
  --optimizer_type adamw8bit --learning_rate 1e-4 --fp8_base \
  --max_train_epochs 400 --save_every_n_epochs 2 \
  --dataset_config model.toml --output_dir outputs --output_name model \
  --timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw \
  --guidance_scale 1.0 --sample_every_n_epochs 1 --sample_prompts prompts.txt \
  --network_args "loraplus_unet_lr_ratio=4" "train_t5xxl=True" "split_qkv=True"

kohya-ss commented 1 week ago

The max memory usage should not have changed in the recent commits, but there may be some factor that makes it slightly higher. If you are not shuffling captions, can you specify --cache_text_encoder_outputs?
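
For reference, a minimal sketch of this suggestion, assuming the rest of the launch command above stays unchanged; the flag precomputes the text encoder outputs once and reuses them, so the text encoders do not have to run on every training step:

accelerate launch ... flux_train_network.py ... --cache_text_encoder_outputs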

Tophness commented 6 days ago

> If you are not shuffling captions, can you specify --cache_text_encoder_outputs?

ValueError: T5XXL is trained, so cache_text_encoder_outputs cannot be used

kohya-ss commented 6 days ago

It seems that you are training not only the DiT but also the Text Encoders, which requires a lot of memory. Please train only the DiT with the option --network_train_unet_only. --cache_text_encoder_outputs is also needed to reduce memory usage.
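
A minimal sketch of this DiT-only configuration, assuming the rest of the original command stays the same; train_t5xxl=True is dropped from --network_args since the text encoders would no longer be trained:

accelerate launch ... flux_train_network.py ... --network_train_unet_only --cache_text_encoder_outputs --network_args "loraplus_unet_lr_ratio=4" "split_qkv=True"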

Tophness commented 6 days ago

> It seems that you are training not only the DiT but also the Text Encoders, which requires a lot of memory. Please train only the DiT with the option --network_train_unet_only. --cache_text_encoder_outputs is also needed to reduce memory usage.

I was training only the DiT for a long time and couldn't get good results, since it's a new concept that doesn't exist in the text encoder.

kohya-ss commented 4 days ago

Unfortunately, it will be difficult to train FLUX, CLIP-L, and T5XXL with 16GB of memory. It may be possible to train only FLUX and CLIP-L by removing train_t5xxl=True and adding --cache_text_encoder_outputs.
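
A minimal sketch of this FLUX + CLIP-L variant, again assuming the rest of the original command is unchanged; --network_train_unet_only is omitted so CLIP-L is still trained, train_t5xxl=True is removed from --network_args, and --cache_text_encoder_outputs is added:

accelerate launch ... flux_train_network.py ... --cache_text_encoder_outputs --network_args "loraplus_unet_lr_ratio=4" "split_qkv=True"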