huggingface / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
https://huggingface.co/docs/diffusers
Apache License 2.0

Redundant reinitialization of text encoders in train_dreambooth_lora_flux #9358

Open DarkMnDragon opened 2 months ago

DarkMnDragon commented 2 months ago

Describe the bug

In the train_dreambooth_lora_flux.py script, every call to log_validation reinitializes the text encoders text_encoder_one and text_encoder_two: https://github.com/huggingface/diffusers/blob/8ba90aa706a733f45d83508a5b221da3c59fe4cd/examples/dreambooth/train_dreambooth_lora_flux.py#L1768

This happens even when the text encoders are not being trained (if not args.train_text_encoder). The reinitialization is unnecessary, wastes resources, and may cause CUDA out-of-memory errors, especially on GPUs with less than 48 GiB of VRAM.

Since the validation prompt is fixed (only one prompt is used), we can optimize the process by precomputing the text embeddings during the instance prompt preprocessing. This would allow the model to fit within 40 GiB of VRAM, preventing CUDA OOM issues.
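The encode-once idea can be illustrated in isolation. The sketch below does not use the real CLIP/T5 encoders: FakeTextEncoder, validate_reloading, and validate_precomputed are hypothetical stand-ins that only count how often an encoder is instantiated, to show that a fixed prompt needs a single encoding pass.

```python
class FakeTextEncoder:
    """Hypothetical stand-in for a heavy text encoder; counts instantiations."""

    loads = 0

    def __init__(self):
        FakeTextEncoder.loads += 1

    def encode(self, prompt):
        # A real encoder returns tensors; word lengths are enough for the demo.
        return tuple(len(word) for word in prompt.split())


def validate_reloading(prompt, num_validations):
    """Behaviour described in the issue: a fresh encoder per validation run."""
    results = []
    for _ in range(num_validations):
        encoder = FakeTextEncoder()
        results.append(encoder.encode(prompt))
    return results


def validate_precomputed(prompt, num_validations):
    """Proposed behaviour: encode the fixed prompt once and reuse the result."""
    encoder = FakeTextEncoder()
    cached = encoder.encode(prompt)
    del encoder  # in a real script, this is where the encoder memory could be released
    return [cached for _ in range(num_validations)]
```

Both functions return the same embeddings for every validation run; only the number of encoder instantiations differs.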

Proposed Fix

To address this issue, add the following code snippet to precompute the validation prompt embeddings only once when the text encoders do not need to be trained and custom instance prompts are not used:

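# Encoders are frozen and every sample shares one prompt, so encode it once up front.
if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
    instance_prompt_hidden_states, instance_pooled_prompt_embeds, instance_text_ids = compute_text_embeddings(
        args.instance_prompt, text_encoders, tokenizers
    )

    # The validation prompt is fixed as well, so its embeddings can be cached here
    # instead of reloading the text encoders inside log_validation.
    if args.validation_prompt is not None:
        validation_prompt_hidden_states, validation_pooled_prompt_embeds, _ = compute_text_embeddings(
            args.validation_prompt, text_encoders, tokenizers
        )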

This change will prevent the unnecessary reinitialization of text encoders and reduce the VRAM usage during training.
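For the cached embeddings to help, the validation step also has to pass them to the pipeline instead of the raw prompt string. The helper below is a hypothetical sketch, not code from the script: build_validation_pipeline_args is an invented name, and it assumes the pipeline call accepts prompt_embeds and pooled_prompt_embeds keyword arguments (as FluxPipeline does). It falls back to the prompt string when nothing was precomputed.

```python
def build_validation_pipeline_args(validation_prompt, prompt_embeds=None, pooled_prompt_embeds=None):
    """Return keyword arguments for the validation pipeline call (hypothetical helper).

    If both cached embeddings are available, pass them so the pipeline does not
    need its text encoders; otherwise pass the raw prompt and let it encode.
    """
    if prompt_embeds is not None and pooled_prompt_embeds is not None:
        return {
            "prompt_embeds": prompt_embeds,
            "pooled_prompt_embeds": pooled_prompt_embeds,
        }
    return {"prompt": validation_prompt}
```

In log_validation, the returned dictionary would be unpacked into the pipeline call, for example pipeline(**pipeline_args, generator=generator), with validation_prompt_hidden_states and validation_pooled_prompt_embeds from the snippet above supplied as the two embedding arguments.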

Reproduction

from huggingface_hub import snapshot_download

local_dir = "./dog"
snapshot_download(
    "diffusers/dog-example",
    local_dir=local_dir, repo_type="dataset",
    ignore_patterns=".gitattributes",
)

export MODEL_NAME="black-forest-labs/FLUX.1-dev"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="trained-flux-lora"

accelerate launch train_dreambooth_lora_flux.py \
  --pretrained_model_name_or_path=$MODEL_NAME  \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --mixed_precision="bf16" \
  --instance_prompt="a photo of sks dog" \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --learning_rate=1e-5 \
  --report_to="wandb" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=500 \
  --validation_prompt="A photo of sks dog in a bucket" \
  --validation_epochs=25 \
  --seed="0" \
  --push_to_hub

Logs

No response

System Info

Who can help?

@sayakpaul

sayakpaul commented 2 months ago

Cc: @linoytsaban
