huggingface / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
https://huggingface.co/docs/diffusers
Apache License 2.0

Can't Train SDXL with LoRA on M2 Mac Mini Using Script #6154

Closed sr5434 closed 10 months ago

sr5434 commented 11 months ago

Describe the bug

Following the instructions in the examples/text_to_image folder's readme_sdxl.md file, I tried to train a Stable Diffusion XL LoRA on a dataset with fp16 mixed precision. The run fails because autocast does not support the MPS device type. I cannot train in full precision either, because the MPS backend then runs out of memory.

Reproduction

Command:

accelerate launch train_text_to_image_lora_sdxl.py \
  --pretrained_model_name_or_path="stabilityai/sdxl-turbo" \
  --pretrained_vae_model_name_or_path="madebyollin/sdxl-vae-fp16-fix" \
  --dataset_name="logo-wizard/modern-logo-dataset" --caption_column="text" \
  --resolution=512 --random_flip \
  --train_batch_size=1 \
  --num_train_epochs=2 --checkpointing_steps=500 \
  --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \
  --seed=42 --mixed_precision="fp16" \
  --output_dir="sdxl-logos-large" \
  --validation_prompt="a logo of abstract, Yellow-brown heart with an inscription in English on a beige background, moccasin background, darkkhaki, burlywood foreground, minimalism, modern" --push_to_hub

Logs

/Users/samir/Library/Python/3.9/lib/python/site-packages/urllib3/__init__.py:34: NotOpenSSLWarning: urllib3 v2 only supports OpenSSL 1.1.1+, currently the 'ssl' module is compiled with 'LibreSSL 2.8.3'. See: https://github.com/urllib3/urllib3/issues/3020
  warnings.warn(
/Users/samir/Library/Python/3.9/lib/python/site-packages/urllib3/__init__.py:34: NotOpenSSLWarning: urllib3 v2 only supports OpenSSL 1.1.1+, currently the 'ssl' module is compiled with 'LibreSSL 2.8.3'. See: https://github.com/urllib3/urllib3/issues/3020
  warnings.warn(
/Users/samir/Library/Python/3.9/lib/python/site-packages/torch/cuda/amp/grad_scaler.py:125: UserWarning: torch.cuda.amp.GradScaler is enabled, but CUDA is not available.  Disabling.
  warnings.warn(
12/12/2023 17:44:13 - INFO - __main__ - Distributed environment: NO
Num processes: 1
Process index: 0
Local process index: 0
Device: mps

Mixed precision type: fp16

You are using a model of type clip_text_model to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
You are using a model of type clip_text_model to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
{'variance_type', 'dynamic_thresholding_ratio', 'thresholding', 'clip_sample_range'} was not found in config. Values will be initialized to default values.
Traceback (most recent call last):
  File "/Users/samir/Desktop/logo-diffusion/diffusers/examples/text_to_image/train_text_to_image_lora_sdxl.py", line 1279, in <module>
    main(args)
  File "/Users/samir/Desktop/logo-diffusion/diffusers/examples/text_to_image/train_text_to_image_lora_sdxl.py", line 934, in main
    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
  File "/Users/samir/Library/Python/3.9/lib/python/site-packages/accelerate/accelerator.py", line 1213, in prepare
    result = tuple(
  File "/Users/samir/Library/Python/3.9/lib/python/site-packages/accelerate/accelerator.py", line 1214, in <genexpr>
    self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
  File "/Users/samir/Library/Python/3.9/lib/python/site-packages/accelerate/accelerator.py", line 1094, in _prepare_one
    return self.prepare_model(obj, device_placement=device_placement)
  File "/Users/samir/Library/Python/3.9/lib/python/site-packages/accelerate/accelerator.py", line 1280, in prepare_model
    autocast_context = get_mixed_precision_context_manager(self.native_amp, self.autocast_handler)
  File "/Users/samir/Library/Python/3.9/lib/python/site-packages/accelerate/utils/modeling.py", line 1556, in get_mixed_precision_context_manager
    return torch.autocast(device_type=state.device.type, dtype=torch.float16, **autocast_kwargs)
  File "/Users/samir/Library/Python/3.9/lib/python/site-packages/torch/amp/autocast_mode.py", line 241, in __init__
    raise RuntimeError(
RuntimeError: User specified an unsupported autocast device_type 'mps'
Traceback (most recent call last):
  File "/Users/samir/Library/Python/3.9/bin/accelerate", line 8, in <module>
    sys.exit(main())
  File "/Users/samir/Library/Python/3.9/lib/python/site-packages/accelerate/commands/accelerate_cli.py", line 47, in main
    args.func(args)
  File "/Users/samir/Library/Python/3.9/lib/python/site-packages/accelerate/commands/launch.py", line 1017, in launch_command
    simple_launcher(args)
  File "/Users/samir/Library/Python/3.9/lib/python/site-packages/accelerate/commands/launch.py", line 637, in simple_launcher
    raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd)
subprocess.CalledProcessError: Command '['/Library/Developer/CommandLineTools/usr/bin/python3', 'train_text_to_image_lora_sdxl.py', '--pretrained_model_name_or_path=stabilityai/sdxl-turbo', '--pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix', '--dataset_name=logo-wizard/modern-logo-dataset', '--caption_column=text', '--resolution=512', '--random_flip', '--train_batch_size=1', '--num_train_epochs=2', '--checkpointing_steps=500', '--learning_rate=1e-04', '--lr_scheduler=constant', '--lr_warmup_steps=0', '--seed=42', '--mixed_precision=fp16', '--output_dir=sdxl-logos-large', '--validation_prompt=a logo of abstract, Yellow-brown heart with an inscription in English on a beige background, moccasin background, darkkhaki, burlywood foreground, minimalism, modern', '--push_to_hub']' returned non-zero exit status 1.

System Info

Who can help?

@sayakpaul @pcuenca

sayakpaul commented 11 months ago

If autocasting is not supported for MPS for certain ops, we cannot do much.

sr5434 commented 11 months ago

Could I use .half()? Would that even work?
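One general reason a plain half-precision cast is hard to vouch for: the log above shows the gradient scaler being disabled because CUDA is unavailable, and without loss scaling small gradient values can underflow to zero in fp16. The sketch below illustrates that rounding behavior using only the Python standard library. The gradient value and the scale factor are made up for illustration, and it says nothing about how the MPS backend itself behaves.

```python
import struct

def to_fp16(x):
    """Round a Python float to IEEE 754 half precision and back."""
    # struct's "e" format is binary16; packing then unpacking applies fp16 rounding.
    return struct.unpack("<e", struct.pack("<e", x))[0]

tiny_grad = 1e-8        # hypothetical gradient value, chosen for illustration
loss_scale = 65536.0    # illustrative scale factor (2**16)

# Stored directly in fp16, the value is below the smallest subnormal and becomes zero.
unscaled = to_fp16(tiny_grad)

# Scaled first, it lands in fp16's representable range; dividing afterwards
# (in full precision) recovers a value close to the original.
recovered = to_fp16(tiny_grad * loss_scale) / loss_scale

print(unscaled, recovered)
```

Whether this matters for a given LoRA run depends on the actual gradient magnitudes, which this sketch does not measure.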

sayakpaul commented 11 months ago

That we cannot guarantee. I suggest you stick to full precision and reduce the batch size and/or resolution. I don't think the MPS backend should be used for training, as it can be quite unreliable.
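As a rough sketch of that suggestion, the helper below picks training flags from the device type: fp16 only on CUDA, full precision with a lower resolution everywhere else. The function name `fallback_train_args`, the halving rule, and the 256-pixel floor are assumptions made for this example and are not part of the training script. The flag names are the ones used in the original command.

```python
import shlex

def fallback_train_args(device_type, resolution=512):
    """Return CLI flags for the training script, avoiding fp16 autocast off CUDA."""
    # Assumption: only CUDA gets fp16; any other device (e.g. "mps") runs in full precision.
    mixed_precision = "fp16" if device_type == "cuda" else "no"
    # Assumption: halve the resolution when falling back, to offset fp32 memory use.
    if mixed_precision == "no":
        resolution = max(256, resolution // 2)
    return [
        f"--mixed_precision={mixed_precision}",
        f"--resolution={resolution}",
        "--train_batch_size=1",  # already the minimum in the original command
    ]

args = fallback_train_args("mps")
print(shlex.join(["accelerate", "launch", "train_text_to_image_lora_sdxl.py", *args]))
```

The remaining flags from the original command (model, dataset, output directory, and so on) would still need to be appended. Whether the reduced settings fit in memory on a given machine is something only a trial run can confirm.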

github-actions[bot] commented 10 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.