huggingface / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
https://huggingface.co/docs/diffusers
Apache License 2.0

TorchDynamo with Pipeline loading #3403

Closed · thepowerfuldeez closed 1 year ago

thepowerfuldeez commented 1 year ago

Describe the bug

Hi! I am using accelerate + diffusers to fine-tune Stable Diffusion with this script: https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py. I have PyTorch 2.0, and accelerate recently added TorchDynamo support to speed up training. However, when I set `dynamo_backend="inductor"` in the `Accelerator` class, I get an error during the validation stage:

line 67 (https://github.com/huggingface/diffusers/blob/1a5797c6d4491a879ea5285c4efc377664e0332d/examples/text_to_image/train_text_to_image.py#L67-L76): the model is of type `<class 'torch._dynamo.eval_frame.OptimizedModule'>`, but should be `<class 'diffusers.models.modeling_utils.ModelMixin'>`

The issue is with the `unet` model.
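To make the failure mode concrete: `torch.compile` returns a wrapper object rather than the original module, so any `isinstance` check against the original base class fails, even though the underlying module is still reachable on the wrapper. The sketch below uses stand-in classes (not the real torch/diffusers types, whose names are only mirrored in the comments) to illustrate this:

```python
# Illustrative sketch with stand-in classes; the real types are
# diffusers.models.modeling_utils.ModelMixin and
# torch._dynamo.eval_frame.OptimizedModule.

class ModelMixin:
    """Stand-in for the diffusers base class the pipeline checks against."""

class UNet(ModelMixin):
    """Stand-in for the UNet model."""

class OptimizedModule:
    """Stand-in for torch.compile's wrapper: it holds the original
    module on the `_orig_mod` attribute instead of subclassing it."""
    def __init__(self, mod):
        self._orig_mod = mod

unet = UNet()
compiled_unet = OptimizedModule(unet)  # analogous to torch.compile(unet)

# The pipeline's isinstance check rejects the wrapper...
assert not isinstance(compiled_unet, ModelMixin)
# ...even though the original module is still there underneath:
assert isinstance(compiled_unet._orig_mod, ModelMixin)
```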

Reproduction

https://github.com/huggingface/diffusers/blob/1a5797c6d4491a879ea5285c4efc377664e0332d/examples/text_to_image/train_text_to_image.py#L67

Logs

No response

System Info

pytorch==2.0.1 diffusers==0.17.0.dev0 accelerate==0.19.0

thepowerfuldeez commented 1 year ago

related https://github.com/huggingface/diffusers/issues/2709

patrickvonplaten commented 1 year ago

Hey @thepowerfuldeez,

Thanks for the issue! Could you add a reproducible code snippet (the command you use to start training + the accelerate config)?

This would be of great help :-)

Also cc @sayakpaul

thepowerfuldeez commented 1 year ago

@patrickvonplaten Sure! I start training with the aforementioned script via `accelerate launch train_text_to_image.py`, specifying only `mixed_precision=bf16` and `dynamo_backend=inductor`. It auto-detected 8 GPUs, so I'm using `num_processes=8`. On the first iteration it hangs for a while as Dynamo compiles the code, and after 2000 steps I run validation as in the script, which is where the exception is raised.

I'm afraid I can't share the data to reproduce. However, if you are unable to reproduce the bug, I could prepare a gist with commands on sample data.

patrickvonplaten commented 1 year ago

@thepowerfuldeez I can reproduce your error. Running this script:

export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export MODEL_NAME="hf-internal-testing/tiny-stable-diffusion-pipe"
export dataset_name="lambdalabs/pokemon-blip-captions"

accelerate launch --mixed_precision="fp16"  train_text_to_image.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --dataset_name=$dataset_name \
  --use_ema \
  --resolution=64 --center_crop --random_flip \
  --train_batch_size=1 \
  --gradient_accumulation_steps=1 \
  --max_train_steps=10 \
  --learning_rate=1e-05 \
  --max_grad_norm=1 \
  --lr_scheduler="constant" --lr_warmup_steps=0 \
  --output_dir="sd-pokemon-model"

with the following accelerate config:

compute_environment: LOCAL_MACHINE
distributed_type: MULTI_GPU
downcast_bf16: 'no'
dynamo_config:
  dynamo_backend: INDUCTOR
  dynamo_mode: default
  dynamo_use_dynamic: false
  dynamo_use_fullgraph: false
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

and the following env:

- `diffusers` version: 0.17.0.dev0
- Platform: Linux-5.15.0-69-generic-x86_64-with-glibc2.35
- Python version: 3.10.6
- PyTorch version (GPU?): 2.0.1+cu117 (True)
- Huggingface_hub version: 0.14.1
- Transformers version: 4.30.0.dev0
- Accelerate version: 0.20.0.dev0
- xFormers version: not installed
- Using GPU in script?:  yes
- Using distributed or parallel set-up in script?: yes, 2 RTX4090

reproduces your error above. It's fixed by this PR: https://github.com/huggingface/accelerate/pull/1437. The problem is nested wrapping: the model ends up wrapped in both torch.compile's `OptimizedModule` and the data-parallel wrapper, so a single unwrap step is not enough to get back to the original module.
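The fix in accelerate amounts to peeling off wrappers in a loop rather than once. Below is a minimal sketch of that idea with hypothetical stand-in classes; in the real libraries the DDP wrapper exposes the inner model as `.module` and the compiled wrapper as `._orig_mod`, and the actual logic lives in accelerate's model-unwrapping utilities:

```python
# Sketch of loop-based unwrapping; these classes are stand-ins for
# torch.nn.parallel.DistributedDataParallel and
# torch._dynamo.eval_frame.OptimizedModule.

class Model:
    """Stand-in for the original nn.Module."""

class DDPWrapper:
    """Stand-in for DistributedDataParallel: inner model on `.module`."""
    def __init__(self, module):
        self.module = module

class CompiledWrapper:
    """Stand-in for OptimizedModule: inner model on `._orig_mod`."""
    def __init__(self, mod):
        self._orig_mod = mod

def unwrap_model(model):
    # Keep peeling until no known wrapper remains; a single-step
    # unwrap would stop at the inner wrapper and still fail the
    # pipeline's type check.
    while True:
        if isinstance(model, DDPWrapper):
            model = model.module
        elif isinstance(model, CompiledWrapper):
            model = model._orig_mod
        else:
            return model

inner = Model()
nested = CompiledWrapper(DDPWrapper(inner))  # compile wrapped around DDP
assert unwrap_model(nested) is inner
```

A single `getattr(model, "module", model)` would only remove the outer layer here, which is exactly the nested-wrapping trap described above.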

I'll discuss with the accelerate team if the PR is ok.

thepowerfuldeez commented 1 year ago

@patrickvonplaten thank you very much for proper testing!

github-actions[bot] commented 1 year ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.