huggingface / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
https://huggingface.co/docs/diffusers
Apache License 2.0

SD3 training OOM on 4090 #8560

Closed · inspire-boy closed 2 weeks ago

inspire-boy commented 1 month ago

torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 48.00 MiB. GPU 0 has a total capacty of 23.64 GiB of which 9.69 MiB is free. Process 20296 has 23.63 GiB memory in use. Of the allocated memory 22.81 GiB is allocated by PyTorch, and 362.31 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation
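For reference, the allocator hint at the end of the message can be applied before any CUDA allocation takes place. A minimal sketch, assuming it is placed at the very top of the training script (the 128 MiB cap is an arbitrary starting point):

import os

# Cap the size of cached blocks the allocator is allowed to split; smaller
# caps reduce fragmentation at some cost in allocation speed. This must be
# set before the first CUDA tensor is created.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"

import torch  # imported after the env var so the setting takes effect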

My params:

accelerate launch train_dreambooth_sd3.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --mixed_precision="fp16" \
  --instance_prompt="a photo of sks dog" \
  --resolution=1024 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --learning_rate=1e-4 \
  --report_to="wandb" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=500 \
  --validation_prompt="A photo of sks dog in a bucket" \
  --validation_epochs=25 \
  --seed="42"

It OOMs. Can someone give a low-VRAM example (accelerate settings included)? Thanks!

tolgacangoz commented 1 month ago

You can add --gradient_checkpointing and --use_8bit_adam flags. Also, you can try --gradient_accumulation_steps=8. If it still gives OOM, you can replace bnb.optim.AdamW8bit with bnb.optim.PagedAdamW8bit when using --use_8bit_adam; but this might be slower.
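To make the suggested swap concrete, here is a minimal sketch of the optimizer selection with the paged variant dropped in (select_optimizer_class is a hypothetical helper, not a function from the script):

import torch

def select_optimizer_class(use_8bit_adam: bool):
    # Hypothetical mirror of the training script's optimizer selection.
    if use_8bit_adam:
        import bitsandbytes as bnb
        # PagedAdamW8bit keeps 8-bit optimizer state but pages it out to CPU
        # RAM under GPU memory pressure, trading speed for lower peak VRAM.
        return bnb.optim.PagedAdamW8bit
    return torch.optim.AdamW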

bghira commented 1 month ago

Better to train a LoRA on 24G VRAM with this script, or use something else that does pre-processing, e.g. SimpleTuner.

xings19 commented 1 month ago

I also encountered this problem. In fact, the OOM was not during training, but during validation.
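One way to confirm whether the spike happens in the training loop or the validation pass is to log memory around both. A small sketch (log_vram is a hypothetical helper, not part of the script):

import torch

def log_vram(tag: str) -> None:
    # Print what PyTorch has allocated vs. reserved on the current GPU.
    allocated = torch.cuda.memory_allocated() / 2**30
    reserved = torch.cuda.memory_reserved() / 2**30
    print(f"[{tag}] allocated={allocated:.2f} GiB, reserved={reserved:.2f} GiB")

# e.g. log_vram("before validation") and log_vram("after validation")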

inspire-boy commented 1 month ago

You can add --gradient_checkpointing and --use_8bit_adam flags. Also, you can try --gradient_accumulation_steps=8. If it still gives OOM, you can replace bnb.optim.AdamW8bit with bnb.optim.PagedAdamW8bit when using --use_8bit_adam; but this might be slower.

accelerate launch train_dreambooth_sd3.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --mixed_precision="bf16" \
  --instance_prompt="a photo of sks dog" \
  --resolution=768 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --learning_rate=1e-4 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=500 \
  --validation_prompt="A photo of sks dog in a bucket" \
  --validation_epochs=25 \
  --seed="42" \
  --gradient_checkpointing \
  --use_8bit_adam

Using the params above, it still OOMs. Are there other params I can adjust? xformers? ^^

xings19 commented 1 month ago

@inspire-boy Did you encounter OOM during training or during validation?

inspire-boy commented 1 month ago

Better to train a LoRA on 24G VRAM with this script, or use something else that does pre-processing, e.g. SimpleTuner.

I use a 4090 (24G).

inspire-boy commented 1 month ago

@inspire-boy Did you encounter OOM during training or during validation?

Right at the beginning:

Steps: 0%| | 0/500 [00:08<?, ?it/s]

tolgacangoz commented 1 month ago

Did you try PagedAdamW8bit?

inspire-boy commented 1 month ago

accelerate launch train_dreambooth_lora_sd3.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --mixed_precision="fp16" \
  --instance_prompt="a photo of sks dog" \
  --resolution=1024 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --learning_rate=1e-5 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=500 \
  --validation_prompt="A photo of sks dog in a bucket" \
  --validation_epochs=25 \
  --gradient_checkpointing \
  --use_8bit_adam

In train_dreambooth_lora_sd3.py, line 1239:

if args.optimizer.lower() == "adamw":
    ...
        optimizer_class = bnb.optim.PagedAdamW8bit  # replaced bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

Errors:

File "/data/miniconda/envs/torch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 810, in _apply
    module._apply(fn)
[Previous line repeated 4 more times]
File "/data/miniconda/envs/torch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 833, in _apply
    param_applied = fn(param)
File "/data/miniconda/envs/torch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1158, in convert
    return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 160.00 MiB. GPU 0 has a total capacty of 23.64 GiB of which 21.69 MiB is free. Process 20268 has 23.62 GiB memory in use. Of the allocated memory 21.71 GiB is allocated by PyTorch, and 1.43 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

Steps:   0%|▋ | 2/500 [00:34<2:21:29, 17.05s/it, loss=0.119, lr=1e-5]

Traceback (most recent call last):
  File "/data/miniconda/envs/torch/bin/accelerate", line 8, in <module>
    sys.exit(main())
  File "/data/miniconda/envs/torch/lib/python3.10/site-packages/accelerate/commands/accelerate_cli.py", line 48, in main
    args.func(args)
  File "/data/miniconda/envs/torch/lib/python3.10/site-packages/accelerate/commands/launch.py", line 1097, in launch_command
    simple_launcher(args)
  File "/data/miniconda/envs/torch/lib/python3.10/site-packages/accelerate/commands/launch.py", line 703, in simple_launcher
    raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd)
subprocess.CalledProcessError: Command '['/data/miniconda/envs/torch/bin/python', 'train_dreambooth_lora_sd3.py', '--pretrained_model_name_or_path=/data/stable-diffusion-3-medium-diffusers', '--instance_data_dir=dog', '--output_dir=trained-sd3', '--mixed_precision=fp16', '--instance_prompt=a photo of sks dog', '--resolution=1024', '--train_batch_size=1', '--gradient_accumulation_steps=4', '--learning_rate=1e-5', '--lr_scheduler=constant', '--lr_warmup_steps=0', '--max_train_steps=500', '--validation_prompt=A photo of sks dog in a bucket', '--validation_epochs=25', '--gradient_checkpointing', '--use_8bit_adam']' returned non-zero exit status 1.

I edited the Python code as above, but it still OOMs...

inspire-boy commented 1 month ago

Did you try PagedAdamW8bit?

Does it support xformers?

tolgacangoz commented 1 month ago

Currently, SD3 doesn't support xformers: https://github.com/huggingface/diffusers/issues/8535 Btw, what is your PyTorch version? Is it one of the latest versions?
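On recent PyTorch, SD3's default attention processor in diffusers already goes through the built-in memory-efficient scaled dot-product attention, so xformers isn't needed for that. A quick sanity check (a sketch, not from the script):

import torch
import torch.nn.functional as F

print(torch.__version__)  # SD3's default attention processor needs PyTorch >= 2.0
# True means memory-efficient attention is available without xformers:
print(hasattr(F, "scaled_dot_product_attention"))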

inspire-boy commented 1 month ago

Currently, SD3 doesn't support xformers: #8535 Btw, what is your PyTorch version? Is it one of the latest versions?

@tolgacangoz
Ubuntu 22.04.4 LTS + 4090 (24G VRAM) + 42G RAM
Python 3.10.14

Dataset images edited to 1000px × 1000px.

nvidia-smi
Sat Jun 15 02:38:33 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.67                 Driver Version: 550.67         CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 4090        Off |   00000000:1A:00.0 Off |                  Off |
| 30%   24C    P8             23W /  450W |       1MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+


Name: torch
Version: 2.1.2+cu118
Summary: Tensors and Dynamic neural networks in Python with strong GPU acceleration
Home-page: https://pytorch.org/
Author: PyTorch Team
Author-email: packages@pytorch.org
License: BSD-3
Location: /data/miniconda/envs/torch/lib/python3.10/site-packages
Requires: filelock, fsspec, jinja2, networkx, sympy, triton, typing-extensions
Required-by: accelerate, peft, torchaudio, torchvision, xformers

Name: torchvision
Version: 0.16.2+cu118
Summary: image and video datasets and models for torch deep learning
Home-page: https://github.com/pytorch/vision
Author: PyTorch Core Team
Author-email: soumith@pytorch.org
License: BSD
Location: /data/miniconda/envs/torch/lib/python3.10/site-packages
Requires: numpy, pillow, requests, torch
Required-by:

Name: xformers
Version: 0.0.23.post1+cu118
Summary: XFormers: A collection of composable Transformer building blocks.
Home-page: https://facebookresearch.github.io/xformers/
Author: Facebook AI Research
Author-email: oncall+xformers@xmail.facebook.com
License:
Location: /data/miniconda/envs/torch/lib/python3.10/site-packages
Requires: numpy, torch
Required-by:

Name: bitsandbytes
Version: 0.41.3.post2
Summary: k-bit optimizers and matrix multiplication routines.
Home-page: https://github.com/TimDettmers/bitsandbytes
Author: Tim Dettmers
Author-email: dettmers@cs.washington.edu
License: MIT
Location: /data/miniconda/envs/torch/lib/python3.10/site-packages
Requires:
Required-by:

I feel it happens at this output:

Generating 4 images with prompt: A photo of sks dog in a bucket.
Traceback (most recent call last):
  File "/data/diffusers/examples/dreambooth/train_dreambooth_lora_sd3.py", line 1665, in <module>
    main(args)
........ OOM ........

asomoza commented 1 month ago

yeah, with the validation prompt for the moment, you'll need at least 31GB of VRAM, with something like this:

accelerate launch examples/dreambooth/train_dreambooth_lora_sd3.py \
--pretrained_model_name_or_path="models/stable_diffusion_3_medium/" \
--instance_data_dir="./datasets/dog"  \
--output_dir="./outputs/lora/dog/"  \
--mixed_precision="fp16" \
--instance_prompt="a photo of sks dog" \
--resolution=1024 --train_batch_size=1 \
--gradient_accumulation_steps=4 \
--learning_rate=1e-4 \
--lr_scheduler="constant" \
--optimizer="AdamW" \
--use_8bit_adam \
--lr_warmup_steps=0 \
--max_train_steps=500 \
--seed="42" 

you can train with 21GB VRAM.
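Until validation memory is brought down in the script itself, one option is to run the validation generation separately with offloading. A hedged sketch (the model path and prompt below are placeholders taken from this thread):

import gc
import torch
from diffusers import StableDiffusion3Pipeline

# Build a fp16 pipeline, keep at most one submodule on the GPU at a time,
# generate the validation image, then free everything.
pipe = StableDiffusion3Pipeline.from_pretrained(
    "models/stable_diffusion_3_medium/",  # placeholder path from the command above
    torch_dtype=torch.float16,
)
pipe.enable_model_cpu_offload()

with torch.no_grad():
    image = pipe("A photo of sks dog in a bucket", num_inference_steps=25).images[0]

del pipe
gc.collect()
torch.cuda.empty_cache()  # return cached blocks to the driver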

asomoza commented 1 month ago

Also, I found that you're using --gradient_accumulation_steps=4 but not --gradient_checkpointing. You will still get the OOM with it, though.

inspire-boy commented 1 month ago

yeah, with the validation prompt for the moment, you'll need at least 31GB of VRAM, with something like this:

accelerate launch examples/dreambooth/train_dreambooth_lora_sd3.py \
--pretrained_model_name_or_path="models/stable_diffusion_3_medium/" \
--instance_data_dir="./datasets/dog"  \
--output_dir="./outputs/lora/dog/"  \
--mixed_precision="fp16" \
--instance_prompt="a photo of sks dog" \
--resolution=1024 --train_batch_size=1 \
--gradient_accumulation_steps=4 \
--learning_rate=1e-4 \
--lr_scheduler="constant" \
--optimizer="AdamW" \
--use_8bit_adam \
--lr_warmup_steps=0 \
--max_train_steps=500 \
--seed="42" 

you can train with 21GB VRAM.

It worked without validation prompts! I got "pytorch_lora_weights.safetensors". By the way, is there any way to optimize validation to reduce video memory? I tried 223 but still OOM. And can this LoRA be used in the officially provided comfyui workflow? I notice the Load Lora node can't find it. Very thankful.

inspire-boy commented 1 month ago

got prompt
lora key not loaded: transformer.transformer_blocks.0.attn.to_k.lora_A.weight
lora key not loaded: transformer.transformer_blocks.0.attn.to_k.lora_B.weight
lora key not loaded: transformer.transformer_blocks.0.attn.to_out.0.lora_A.weight
lora key not loaded: transformer.transformer_blocks.0.attn.to_out.0.lora_B.weight
lora key not loaded: transformer.transformer_blocks.0.attn.to_q.lora_A.weight
lora key not loaded: transformer.transformer_blocks.0.attn.to_q.lora_B.weight
lora key not loaded: transformer.transformer_blocks.0.attn.to_v.lora_A.weight
lora key not loaded: transformer.transformer_blocks.0.attn.to_v.lora_B.weight
[... the same to_q/to_k/to_v/to_out.0 lora_A/lora_B "lora key not loaded" lines repeat for transformer_blocks 1 through 23 ...]
Requested to load SD3
Loading 1 new model
100%|███████████

inspire-boy commented 1 month ago

import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "/data/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
)

lora_id = "/data/diffusers/examples/dreambooth/trained-sd3/checkpoint-500"
pipe.load_lora_weights(lora_id)

pipe = pipe.to("cuda")

pipe.enable_xformers_memory_efficient_attention()  # note: SD3 doesn't support xformers (#8535)

pipe.enable_sequential_cpu_offload()

pipe.enable_model_cpu_offload()  # note: sequential and model offload are alternatives; pick one

image = pipe(
    "A photo of sks dog inside a bottle containing a galaxy. The bottle printed text: SD3 Lora Dog",
    negative_prompt="low quality",
    num_inference_steps=30,
    guidance_scale=6.0,
    generator=torch.manual_seed(42),
).images[0]
image

[image: 20240615162626]

prompt: A photo of sks dog in a bucket printed text that say: SD3 Lora Dog

[image: 20240615162635]

asomoza commented 1 month ago

Glad you got it working. I don't know why you got the layers error; for me it worked from the beginning, but I did use the code from some PRs that fix the loss and the gradient accumulation.

IMO the VRAM usage with validation shouldn't be that high, but this is just the first iteration of the training scripts; as you can see, we're still updating it with the help of the community.

Also, comfyui added a commit that enables loading the diffusers SD3 lora, so it should work there too.

inspire-boy commented 1 month ago

Glad you got it working. I don't know why you got the layers error; for me it worked from the beginning, but I did use the code from some PRs that fix the loss and the gradient accumulation.

IMO the VRAM usage with validation shouldn't be that high, but this is just the first iteration of the training scripts; as you can see, we're still updating it with the help of the community.

Also, comfyui added a commit that enables loading the diffusers SD3 lora, so it should work there too.

Just ignore the errors; it's a comfyui lora loader issue, and I will update the code. Many thanks, your work is timely and valuable.

C0nsumption commented 1 month ago

Hey, are you still able to get the script running? I keep having errors about the text encoder loading (made a separate issue). The dreambooth training works, but the LoRA training just won't run at all for me.

qsunyuan commented 1 month ago

I currently have four 4090 GPUs, each with 24GB of memory. Can full fine-tuning (even without quantization tricks; either train_dreambooth_lora_sd3.py or train_dreambooth_sd3.py) be run distributed across them?

inspire-boy commented 1 month ago

Hey, are you still able to get the script running? I keep having errors about the text encoder loading (made a separate issue). The dreambooth training works, but the LoRA training just won't run at all for me.

There are many possible causes. Please list your environment and the errors, and wait for someone who knows it.

sayakpaul commented 2 weeks ago

Seems like the issue is solved. For text encoder training, we have this opened already: https://github.com/huggingface/diffusers/issues/8726