huggingface / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
https://huggingface.co/docs/diffusers
Apache License 2.0

When fine-tuning sdxl I get Out of Memory when mapping the dataset #6362

Closed AeroDEmi closed 10 months ago

AeroDEmi commented 10 months ago

Describe the bug

When using train_text_to_image_sdxl.py I get an out-of-memory error while mapping the dataset:

Resolving data files: 100% 1512/1512 [00:00<00:00, 44098.99it/s]
Map:   0% 0/1511 [00:44<?, ? examples/s]
Traceback (most recent call last):
  File "/content/drive/MyDrive/Colab/SD_LoRA/train_text_to_image_sdxl.py", line 1281, in <module>
    main(args)
  File "/content/drive/MyDrive/Colab/SD_LoRA/train_text_to_image_sdxl.py", line 891, in main
    train_dataset = train_dataset.map(compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint)
...
...
...
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 752.00 MiB. GPU 0 has a total capacty of 39.56 GiB of which 535.44 MiB is free.

Reproduction

accelerate launch train_text_to_image_sdxl.py \
  --pretrained_model_name_or_path="segmind/Segmind-Vega" \
  --pretrained_vae_model_name_or_path="madebyollin/sdxl-vae-fp16-fix" \
  --dataset_name="lambdalabs/pokemon-blip-captions" \
  --enable_xformers_memory_efficient_attention \
  --resolution=1024 --center_crop --random_flip \
  --proportion_empty_prompts=0.2 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 --gradient_checkpointing \
  --max_train_steps=10000 \
  --use_8bit_adam \
  --learning_rate=1e-06 --lr_scheduler="constant" --lr_warmup_steps=0 \
  --mixed_precision="fp16" \
  --validation_prompt="a cute Sundar Pichai creature" --validation_epochs 5 \
  --checkpointing_steps=5000 \
  --output_dir="vega-pokemon-model" 

Logs

No response

System Info

I'm using Google Colab with an A100 (40 GB).

Who can help?

@sayakpaul @patrickvonplaten

sayakpaul commented 10 months ago

Can you share your Google Colab?

AeroDEmi commented 10 months ago

Sure! I only have three cells of code:

!pip install accelerate -q
!pip install torchvision -q
!pip install transformers -q
!pip install datasets -q
!pip install ftfy -q
!pip install tensorboard -q
!pip install Jinja2 -q
!pip install diffusers -q
!pip install peft -q

from accelerate.utils import write_basic_config
write_basic_config()

!accelerate launch /content/drive/MyDrive/Colab/SD_LoRA/train_text_to_image_sdxl.py \
  --pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0" \
  --pretrained_vae_model_name_or_path="madebyollin/sdxl-vae-fp16-fix" \
  --dataset_name="lambdalabs/pokemon-blip-captions" \
  --resolution=1024 --random_flip \
  --train_batch_size=1 \
  --num_train_epochs=5 --checkpointing_steps=635 \
  --learning_rate=1e-05 --lr_scheduler="constant" --lr_warmup_steps=100 \
  --mixed_precision="fp16" \
  --seed=42 \
  --output_dir="/content/pokemon"

AeroDEmi commented 10 months ago

Seems that it hangs here: train_dataset = train_dataset.map(compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint) in the script train_text_to_image_sdxl.py
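
For reference, with batched=True and no explicit batch_size, datasets.Dataset.map hands compute_embeddings_fn up to 1000 examples per call (the library default), so the text encoders run over up to 1000 prompts in a single forward pass. A minimal sketch of a possible workaround, reusing the script's own variables; the batch_size value is an illustrative assumption, not something the script sets:

# Sketch: cap how many prompts compute_embeddings_fn encodes per call,
# so GPU memory used during the precomputation step stays bounded.
train_dataset = train_dataset.map(
    compute_embeddings_fn,
    batched=True,
    batch_size=16,  # assumed value; tune to the available VRAM
    new_fingerprint=new_fingerprint,
)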

sayakpaul commented 10 months ago

Can you ensure the use of the latest diffusers script cloned from main?

AeroDEmi commented 10 months ago

Let me pull again from main

AeroDEmi commented 10 months ago

Still having the error

AeroDEmi commented 10 months ago

It works when I set batched=False: train_dataset = train_dataset.map(compute_embeddings_fn, batched=False, new_fingerprint=new_fingerprint)

Now, it hangs around the 500th iteration

sayakpaul commented 10 months ago

500th iteration of training?

AeroDEmi commented 10 months ago

In the caching of the dataset during the first mapping. But I was using a 24 GB GPU; let me try it with 40 GB.

AeroDEmi commented 10 months ago

It seems it was the caching of the tokenizer and VAE outputs. I ran out of RAM; when I increased the RAM, the problem went away. I wonder if there's a way to avoid using RAM and use disk storage instead.
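
A minimal sketch of one way the cached results could be written to disk rather than held in RAM, using standard datasets.Dataset.map arguments; the file path and the numeric values are assumptions for illustration:

# Sketch: stream map() results into an on-disk Arrow cache file and flush rows
# to disk more often, so the computed embeddings do not accumulate in RAM.
train_dataset = train_dataset.map(
    compute_embeddings_fn,
    batched=True,
    batch_size=16,  # assumed value
    keep_in_memory=False,  # library default, shown here for clarity
    cache_file_name="/content/embeddings_cache.arrow",  # assumed path
    writer_batch_size=100,  # assumed value; smaller means more frequent flushes
    new_fingerprint=new_fingerprint,
)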

AeroDEmi commented 10 months ago

@sayakpaul is it possible to not load everything into RAM if my dataset is really big?
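
One possible approach, sketched under the assumption that the dataset with its precomputed columns fits on disk: save it once after the map() step and reload it on later runs. datasets memory-maps the saved Arrow files, so the data is read from storage rather than kept in RAM. The path below is a placeholder:

from datasets import load_from_disk

# Sketch: persist the mapped dataset once, then memory-map it on later runs.
cache_path = "/content/drive/MyDrive/pokemon_with_embeddings"  # placeholder path

train_dataset.save_to_disk(cache_path)  # writes Arrow files to disk
train_dataset = load_from_disk(cache_path)  # memory-mapped; low RAM usage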