huggingface / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
https://huggingface.co/docs/diffusers
Apache License 2.0

Training LoRA: accelerate launch train_text_to_image_lora_sdxl.py gives TypeError: main() missing 1 required positional argument: 'args' #5849

Closed PawKanarek closed 8 months ago

PawKanarek commented 11 months ago

Describe the bug

Hello 🤗. I'm trying to train a LoRA for SDXL on a Google Cloud TPU v3-8 machine with this script: train_text_to_image_lora_sdxl.py. I followed the setup guide from diffusers/examples/text_to_image/README_sdxl.md, but I'm stuck at the very beginning.

I also attempted to run this script on my local Apple M1 Pro. Even though the mps device is not supported there, the script gets past the error I'm hitting on the Google Cloud TPU, which is the one I describe in this ticket.

The stack trace suggests that accelerate cannot launch the PrepareForLaunch function. Details below. ⬇️

Any help and ideas to move forward would be appreciated <3

Reproduction

Run these CLI commands in the diffusers/examples/text_to_image/ directory. They are copy-pasted from README_sdxl.md; for simplicity I removed the --report_to="wandb" and --push_to_hub parameters.

export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
export VAE_NAME="madebyollin/sdxl-vae-fp16-fix"
export DATASET_NAME="lambdalabs/pokemon-blip-captions"

accelerate launch train_text_to_image_lora_sdxl.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --pretrained_vae_model_name_or_path=$VAE_NAME \
  --dataset_name=$DATASET_NAME --caption_column="text" \
  --resolution=1024 --random_flip \
  --train_batch_size=1 \
  --num_train_epochs=2 --checkpointing_steps=500 \
  --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \
  --mixed_precision="fp16" \
  --seed=42 \
  --output_dir="sd-pokemon-model-lora-sdxl" \
  --validation_prompt="cute dragon creature"

Logs

(lora) raix@XYZ:/mnt/disks/persist/repos/diffusers/examples/text_to_image$ accelerate launch train_text_to_image_lora_sdxl.py   --pretrained_model_name_or_path=$MODEL_NAME   --pretrained_vae_model_name_or_path=$VAE_NAME   --dataset_name=$DATASET_NAME --caption_column="text"   --resolution=1024 --random_flip   --train_batch_size=1   --num_train_epochs=2 --checkpointing_steps=500   --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0   --mixed_precision="fp16"   --seed=42   --output_dir="sd-pokemon-model-lora-sdxl"   --validation_prompt="cute dragon creature"
Traceback (most recent call last):
  File "/home/raix/miniconda3/envs/lora/bin/accelerate", line 8, in <module>
    sys.exit(main())
  File "/home/raix/miniconda3/envs/lora/lib/python3.8/site-packages/accelerate/commands/accelerate_cli.py", line 47, in main
    args.func(args)
  File "/home/raix/miniconda3/envs/lora/lib/python3.8/site-packages/accelerate/commands/launch.py", line 990, in launch_command
    tpu_launcher(args)
  File "/home/raix/miniconda3/envs/lora/lib/python3.8/site-packages/accelerate/commands/launch.py", line 733, in tpu_launcher
    xmp.spawn(PrepareForLaunch(main_function), args=(), nprocs=args.num_processes)
  File "/home/raix/miniconda3/envs/lora/lib/python3.8/site-packages/torch_xla/runtime.py", line 82, in wrapper
    return fn(*args, **kwargs)
  File "/home/raix/miniconda3/envs/lora/lib/python3.8/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 38, in spawn
    return pjrt.spawn(fn, nprocs, start_method, args)
  File "/home/raix/miniconda3/envs/lora/lib/python3.8/site-packages/torch_xla/_internal/pjrt.py", line 198, in spawn
    return _run_singleprocess(spawn_fn)
  File "/home/raix/miniconda3/envs/lora/lib/python3.8/site-packages/torch_xla/runtime.py", line 82, in wrapper
    return fn(*args, **kwargs)
  File "/home/raix/miniconda3/envs/lora/lib/python3.8/site-packages/torch_xla/_internal/pjrt.py", line 102, in _run_singleprocess
    return fn(*args, **kwargs)
  File "/home/raix/miniconda3/envs/lora/lib/python3.8/site-packages/torch_xla/_internal/pjrt.py", line 178, in __call__
    self.fn(runtime.global_ordinal(), *self.args, **self.kwargs)
  File "/home/raix/miniconda3/envs/lora/lib/python3.8/site-packages/accelerate/utils/launch.py", line 554, in __call__
    self.launcher(*args)
TypeError: main() missing 1 required positional argument: 'args'
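
If I read the traceback correctly, the TPU code path ends up calling the training function named in main_training_function with no positional arguments, while the script defines main(args) and parses its arguments under if __name__ == "__main__". A minimal, self-contained sketch of that mismatch (my interpretation only, with illustrative names, not the actual accelerate source):

```python
import argparse


def parse_args(argv=None):
    # stand-in for the training script's own argument parser
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_dir", default="out")
    return parser.parse_args(argv)


def main(args):
    # the signature style used by the PyTorch example scripts
    print(f"training, output_dir={args.output_dir}")


def tpu_entrypoint():
    # a zero-argument wrapper like this is what the launcher effectively expects
    main(parse_args([]))


if __name__ == "__main__":
    try:
        main()  # roughly what the TPU launcher does -> TypeError
    except TypeError as err:
        print(err)
    tpu_entrypoint()  # runs fine
```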

System Info

As suggested in README_lora.md, I'm using diffusers from the latest main (commit sha: c896b841e48b65e561800f829c546f4cf047e634).

Output of $ diffusers-cli env

Copy-and-paste the text below in your GitHub issue and FILL OUT the two last points.

- `diffusers` version: 0.24.0.dev0
- Platform: Linux-5.13.0-1027-gcp-x86_64-with-glibc2.17
- Python version: 3.8.18
- PyTorch version (GPU?): 2.1.1+cu121 (False)
- Huggingface_hub version: 0.19.4
- Transformers version: 4.35.2
- Accelerate version: 0.24.1
- xFormers version: not installed
- Using GPU in script?: <fill in> => no
- Using distributed or parallel set-up in script?: <fill in> => TPU 

Accelerate config file:

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: TPU
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

This is a brand new, clean conda environment with the packages listed below. Output of $ conda list

# packages in environment at /home/raix/miniconda3/envs/lora:
#
# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                        main  
_openmp_mutex             5.1                       1_gnu  
absl-py                   2.0.0                    pypi_0    pypi
accelerate                0.24.1                   pypi_0    pypi
aiohttp                   3.8.6                    pypi_0    pypi
aiosignal                 1.3.1                    pypi_0    pypi
async-timeout             4.0.3                    pypi_0    pypi
attrs                     23.1.0                   pypi_0    pypi
ca-certificates           2023.08.22           h06a4308_0  
cachetools                5.3.2                    pypi_0    pypi
certifi                   2023.7.22                pypi_0    pypi
charset-normalizer        3.3.2                    pypi_0    pypi
cloud-tpu-client          0.10                     pypi_0    pypi
datasets                  2.15.0                   pypi_0    pypi
diffusers                 0.24.0.dev0              pypi_0    pypi
dill                      0.3.7                    pypi_0    pypi
filelock                  3.13.1                   pypi_0    pypi
frozenlist                1.4.0                    pypi_0    pypi
fsspec                    2023.10.0                pypi_0    pypi
ftfy                      6.1.1                    pypi_0    pypi
google-api-core           1.34.0                   pypi_0    pypi
google-api-python-client  1.8.0                    pypi_0    pypi
google-auth               2.23.4                   pypi_0    pypi
google-auth-httplib2      0.1.1                    pypi_0    pypi
google-auth-oauthlib      1.0.0                    pypi_0    pypi
googleapis-common-protos  1.61.0                   pypi_0    pypi
grpcio                    1.59.2                   pypi_0    pypi
httplib2                  0.22.0                   pypi_0    pypi
huggingface-hub           0.19.4                   pypi_0    pypi
idna                      3.4                      pypi_0    pypi
importlib-metadata        6.8.0                    pypi_0    pypi
jinja2                    3.1.2                    pypi_0    pypi
ld_impl_linux-64          2.38                 h1181459_1  
libffi                    3.4.4                h6a678d5_0  
libgcc-ng                 11.2.0               h1234567_1  
libgomp                   11.2.0               h1234567_1  
libstdcxx-ng              11.2.0               h1234567_1  
libtpu-nightly            0.1.dev20230825+default          pypi_0    pypi
markdown                  3.5.1                    pypi_0    pypi
markupsafe                2.1.3                    pypi_0    pypi
mpmath                    1.3.0                    pypi_0    pypi
multidict                 6.0.4                    pypi_0    pypi
multiprocess              0.70.15                  pypi_0    pypi
ncurses                   6.4                  h6a678d5_0  
networkx                  3.1                      pypi_0    pypi
numpy                     1.24.4                   pypi_0    pypi
nvidia-cublas-cu12        12.1.3.1                 pypi_0    pypi
nvidia-cuda-cupti-cu12    12.1.105                 pypi_0    pypi
nvidia-cuda-nvrtc-cu12    12.1.105                 pypi_0    pypi
nvidia-cuda-runtime-cu12  12.1.105                 pypi_0    pypi
nvidia-cudnn-cu12         8.9.2.26                 pypi_0    pypi
nvidia-cufft-cu12         11.0.2.54                pypi_0    pypi
nvidia-curand-cu12        10.3.2.106               pypi_0    pypi
nvidia-cusolver-cu12      11.4.5.107               pypi_0    pypi
nvidia-cusparse-cu12      12.1.0.106               pypi_0    pypi
nvidia-nccl-cu12          2.18.1                   pypi_0    pypi
nvidia-nvjitlink-cu12     12.3.101                 pypi_0    pypi
nvidia-nvtx-cu12          12.1.105                 pypi_0    pypi
oauth2client              4.1.3                    pypi_0    pypi
oauthlib                  3.2.2                    pypi_0    pypi
openssl                   3.0.12               h7f8727e_0  
packaging                 23.2                     pypi_0    pypi
pandas                    2.0.3                    pypi_0    pypi
pillow                    10.1.0                   pypi_0    pypi
pip                       23.3             py38h06a4308_0  
protobuf                  3.20.3                   pypi_0    pypi
psutil                    5.9.6                    pypi_0    pypi
pyarrow                   14.0.1                   pypi_0    pypi
pyarrow-hotfix            0.5                      pypi_0    pypi
pyasn1                    0.5.0                    pypi_0    pypi
pyasn1-modules            0.3.0                    pypi_0    pypi
pyparsing                 3.1.1                    pypi_0    pypi
python                    3.8.18               h955ad1f_0  
python-dateutil           2.8.2                    pypi_0    pypi
pytz                      2023.3.post1             pypi_0    pypi
pyyaml                    6.0.1                    pypi_0    pypi
readline                  8.2                  h5eee18b_0  
regex                     2023.10.3                pypi_0    pypi
requests                  2.31.0                   pypi_0    pypi
requests-oauthlib         1.3.1                    pypi_0    pypi
rsa                       4.9                      pypi_0    pypi
safetensors               0.4.0                    pypi_0    pypi
setuptools                68.0.0           py38h06a4308_0  
six                       1.16.0                   pypi_0    pypi
sqlite                    3.41.2               h5eee18b_0  
sympy                     1.12                     pypi_0    pypi
tensorboard               2.14.0                   pypi_0    pypi
tensorboard-data-server   0.7.2                    pypi_0    pypi
tk                        8.6.12               h1ccaba5_0  
tokenizers                0.15.0                   pypi_0    pypi
torch                     2.1.1                    pypi_0    pypi
torch-xla                 2.1.0                    pypi_0    pypi
torchvision               0.16.1                   pypi_0    pypi
tqdm                      4.66.1                   pypi_0    pypi
transformers              4.35.2                   pypi_0    pypi
triton                    2.1.0                    pypi_0    pypi
typing-extensions         4.8.0                    pypi_0    pypi
tzdata                    2023.3                   pypi_0    pypi
uritemplate               3.0.1                    pypi_0    pypi
urllib3                   2.1.0                    pypi_0    pypi
wcwidth                   0.2.10                   pypi_0    pypi
werkzeug                  3.0.1                    pypi_0    pypi
wheel                     0.41.2           py38h06a4308_0  
xxhash                    3.4.1                    pypi_0    pypi
xz                        5.4.2                h5eee18b_0  
yarl                      1.9.2                    pypi_0    pypi
zipp                      3.17.0                   pypi_0    pypi
zlib                      1.2.13               h5eee18b_0 

Who can help?

@sayakpaul

yiyixuxu commented 11 months ago

Hi @PawKanarek

Our PyTorch training examples are intended to run on GPUs, not TPUs. cc @sayakpaul to confirm

PawKanarek commented 11 months ago

Thank you @yiyixuxu. This would explain why I'm struggling so hard with it. I will now try to train Stable Diffusion v1-4, since it's mentioned that it can be trained on TPUs. I found more information about training with Flax/JAX in this README.md.

I'm keeping this ticket open because perhaps @sayakpaul can provide some hints on what I need to do to train SDXL on TPUs.

sayakpaul commented 10 months ago

To maximize performance on TPUs, I welcome you to check our JAX training scripts.

PawKanarek commented 10 months ago

Thanks @sayakpaul for your response and for directing me to the JAX training scripts; studying them gave me the strength and knowledge to work on a new script for training SDXL with Flax, called train_text_to_image_flax_sdxl.py :)

However, after some bigger and smaller problems that I could solve on my own, I'm now stuck.

I'm encountering a problem where I cannot create a TrainState for the SDXL UNet. I'm getting a strange XlaRuntimeError indicating that the program cannot allocate 50 MB, even though there's still about 310 GB available on the machine (Google TPU v3-8). The error says:

# RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 50.00M. That was not possible. There are 27.86M free.; (0x0x0_HBM0)

I'm providing the minimal code needed to reproduce this issue:

import jax.numpy as jnp
import optax
from flax.training import train_state
from diffusers import FlaxUNet2DConditionModel

print("initialize sdxl-base FlaxUNet2DConditionModel")
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
    "pcuenq/stable-diffusion-xl-base-1.0-flax",
    from_pt=False,
    subfolder="unet",
    dtype=jnp.bfloat16,
)
constant_scheduler = optax.constant_schedule(1e-5)
adamw = optax.adamw(learning_rate=constant_scheduler, b1=0.9, b2=0.999, eps=1e-08, weight_decay=1e-2)
optimizer = optax.chain(optax.clip_by_global_norm(1.0), adamw)

# I cannot create a TrainState for SDXL on a TPU v3-8 with 335 GB of RAM.
# It complains about allocating 50M of memory, despite the fact that ~310GB is still free.
# Up to this point the program allocates about ~20GB of memory, and then throws RESOURCE_EXHAUSTED.
print("******* TRYING TO CREATE TrainState AND THIS WILL THROW RESOURCE_EXHAUSTED")
state = train_state.TrainState.create(apply_fn=unet.__call__, params=unet_params, tx=optimizer)
# throws:  RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 50.00M. That was not possible. There are 27.86M free.; (0x0x0_HBM0)
print(f"Never goes here :( {state=}")

The full code that produces this error is in my new script (lines 603-616), but this minimal snippet is enough to trigger the runtime error.

Why do I think this is a weird error? Because I have previously used up to 60 GB more RAM on this machine without receiving any errors, and now I'm getting RESOURCE_EXHAUSTED when only ~20 GB is used. Any idea or guidance in solving this problem would be greatly appreciated. Thank you!

sayakpaul commented 10 months ago

Cc: @pcuenca

PawKanarek commented 10 months ago

The content of this comment was irrelevant to the discussion. I edited it for clarity

old comment:

I changed the reproduction code a bit to show that there is enough memory on this machine.

```python
import jaxlib.xla_extension
import optax
from flax.training import train_state
from memory_profiler import profile

from diffusers import FlaxUNet2DConditionModel


@profile
def main():
    models = []
    for i in range(10):
        print(f"step: {i}")
        unet, params = FlaxUNet2DConditionModel.from_pretrained(
            "pcuenq/stable-diffusion-xl-base-1.0-flax",
            subfolder="unet",
        )
        models.append((unet, params))
        adamw = optax.adamw(learning_rate=optax.constant_schedule(1e-5))
        try:
            state = train_state.TrainState.create(apply_fn=unet.__call__, params=params, tx=adamw)
            print(f"Never goes here :( {state=}")
        except jaxlib.xla_extension.XlaRuntimeError as e:
            # RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 50.00M. That was not possible. There are 27.86M free.; (0x0x0_HBM0)
            print(e)


if __name__ == "__main__":
    main()
```

And the profile output:

```
Line #    Mem usage     Increment  Occurrences   Line Contents
=============================================================
     9    486.8 MiB     486.8 MiB           1   @profile
    10                                          def main():
    11    486.8 MiB       0.0 MiB           1       models = []
    12 116280.8 MiB       0.0 MiB          11       for i in range(10):
    13 105918.0 MiB       0.0 MiB          10           print(f"step: {i}")
    14 116280.8 MiB  115791.7 MiB          20           unet, params = FlaxUNet2DConditionModel.from_pretrained(
    15 105918.0 MiB      -0.0 MiB          10               "pcuenq/stable-diffusion-xl-base-1.0-flax",
    16 105918.0 MiB       0.0 MiB          10               subfolder="unet",
    17                                                  )
    18 116280.8 MiB       0.0 MiB          10           models.append((unet, params))
    19 116280.8 MiB       0.0 MiB          10           adamw = optax.adamw(learning_rate=optax.constant_schedule(1e-5))
    20 116280.8 MiB       0.0 MiB          10           try:
    21 116280.8 MiB       2.4 MiB          10               state = train_state.TrainState.create(apply_fn=unet.__call__, params=params, tx=adamw)
    22                                                          print(f"Never goes here :( {state=}")
    23 116280.8 MiB       0.0 MiB          10           except jaxlib.xla_extension.XlaRuntimeError as e:
    24                                                          # RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 50.00M. That was not possible. There are 27.86M free.; (0x0x0_HBM0)
    25 116280.8 MiB      -0.1 MiB          10               print(e)
```

It seems there is no problem creating 10 UNet models, which together consume ~105 GB of memory, yet there is always a problem allocating 50 MB when creating the `TrainState`. That feels wrong to me. In the next steps I will try to measure memory with [pprof](https://jax.readthedocs.io/en/latest/device_memory_profiling.html).

Once again, all ideas that help push this issue forward will be greatly appreciated! Full output logs below, if anyone is interested (click arrow to show).
logs (click arrow to show)

```
/mnt/disks/persist/repos$ python -m memory_profiler diffusers/examples/text_to_image/test_train_state_sdxl_flax.py
step: 0
The config attributes {'act_fn': 'silu', 'center_input_sample': False, 'class_embed_type': None, 'class_embeddings_concat': False, 'conv_in_kernel': 3, 'conv_out_kernel': 3, 'cross_attention_norm': None, 'downsample_padding': 1, 'dual_cross_attention': False, 'encoder_hid_dim': None, 'encoder_hid_dim_type': None, 'mid_block_only_cross_attention': None, 'mid_block_scale_factor': 1, 'mid_block_type': 'UNetMidBlock2DCrossAttn', 'norm_eps': 1e-05, 'norm_num_groups': 32, 'num_class_embeds': None, 'resnet_out_scale_factor': 1.0, 'resnet_skip_time_act': False, 'resnet_time_scale_shift': 'default', 'time_cond_proj_dim': None, 'time_embedding_act_fn': None, 'time_embedding_dim': None, 'time_embedding_type': 'positional', 'timestep_post_act': None, 'upcast_attention': None} were passed to FlaxUNet2DConditionModel, but are not expected and will be ignored. Please verify your config.json configuration file.
tcmalloc: large alloc 10269917184 bytes == 0x8b9e000 @ 0x7fc381393680 0x7fc3813b4824 0x4d562f 0x5913a7 0x4e61e5 0x5ee2da 0x590f5b 0x4e8cfb 0x4dfa44 0x4a12ee 0x430b16 0x4d70d1 0x4f50db 0x4dfa44 0x43103e 0x4e81a6 0x4f75ca 0x4da183 0x4d70d1 0x4e823c 0x4f75ca 0x4da183 0x4d70d1 0x4e823c 0x4d84a9 0x4d70d1 0x585e29 0x585deb 0x589f31 0x4e8cfb 0x4d84a9
RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 50.00M. That was not possible. There are 27.86M free.; (0x0x0_HBM0)
[steps 1-9 repeat the same config-attributes warning, tcmalloc large alloc, and RESOURCE_EXHAUSTED message; the memory_profiler table printed at the end is identical to the one shown above]
```
PawKanarek commented 10 months ago

The content of this comment was irrelevant to the discussion. I edited it for clarity

old comment:

Well, the profiling with `jax.profiler.save_device_memory_profile` gives me almost nothing. I created 3 memory profiles: on startup, after creating the unet, and after the crash.

```python
import os

import jax.lib
import jaxlib.xla_extension
import optax
from flax.training import train_state

from diffusers import FlaxUNet2DConditionModel


def main():
    jax.profiler.save_device_memory_profile(os.path.join("1.prof"))
    unet, params = FlaxUNet2DConditionModel.from_pretrained(
        "pcuenq/stable-diffusion-xl-base-1.0-flax",
        subfolder="unet",
    )
    jax.profiler.save_device_memory_profile(os.path.join("2.prof"))
    adamw = optax.adamw(learning_rate=optax.constant_schedule(1e-5))
    try:
        state = train_state.TrainState.create(apply_fn=unet.__call__, params=params, tx=adamw)
        print(f"Never goes here :( {state=}")
    except jaxlib.xla_extension.XlaRuntimeError as e:
        # RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 50.00M. That was not possible. There are 27.86M free.; (0x0x0_HBM0)
        print(e)
    jax.profiler.save_device_memory_profile(os.path.join("3.prof"))


if __name__ == "__main__":
    main()
```

And the visualizations of these steps, with the help of `pprof`, are as follows:

# 1.prof
![profile001](https://github.com/huggingface/diffusers/assets/15096514/aceb4f8b-abe2-45f5-b71c-504febf24cd7)

# 2.prof
![profile002](https://github.com/huggingface/diffusers/assets/15096514/0884b144-2dab-4d1b-bcd0-b7665a746815)

# 3.prof
![profile003](https://github.com/huggingface/diffusers/assets/15096514/accb0c2a-1231-4458-830a-5a108a1f261b)

Now I know that the profiles don't include attempts at allocations that crashed. The next thing I will try is to get rid of the warnings generated during the loading phase of `unet`:

```
The config attributes {'act_fn': 'silu', 'center_input_sample': False, 'class_embed_type': None, 'class_embeddings_concat': False, 'conv_in_kernel': 3, 'conv_out_kernel': 3, 'cross_attention_norm': None, 'downsample_padding': 1, 'dual_cross_attention': False, 'encoder_hid_dim': None, 'encoder_hid_dim_type': None, 'mid_block_only_cross_attention': None, 'mid_block_scale_factor': 1, 'mid_block_type': 'UNetMidBlock2DCrossAttn', 'norm_eps': 1e-05, 'norm_num_groups': 32, 'num_class_embeds': None, 'resnet_out_scale_factor': 1.0, 'resnet_skip_time_act': False, 'resnet_time_scale_shift': 'default', 'time_cond_proj_dim': None, 'time_embedding_act_fn': None, 'time_embedding_dim': None, 'time_embedding_type': 'positional', 'timestep_post_act': None, 'upcast_attention': None} were passed to FlaxUNet2DConditionModel, but are not expected and will be ignored. Please verify your config.json configuration file.
```
PawKanarek commented 10 months ago

The TensorBoard profiling is much more useful! Take a look at this TensorBoard memory_profile page: [image]. I have a lot of memory overall on the entire v3-8 machine, but it looks like a single TPU core has only 16 GB of High Bandwidth Memory (source). This explains why I could allocate more memory in total in my previous examples, while a single device allocation is capped at 15.48 GiB. That's my naive explanation of this issue.

In summary, I now think that I have to somehow instruct Flax to spread the TrainState across many TPU cores simultaneously. I wonder if it's even possible. I will dig deeper into that topic.
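
As a quick sanity check of that theory, listing the local devices should show 8 separate TPU cores, each with its own HBM (just a small sketch to run on the TPU VM):

```python
import jax

# Each local TPU core is a separate device with its own ~16 GB of HBM, so a
# single un-sharded allocation is bounded by one core's memory even though
# the host reports ~300 GB of RAM in total.
for device in jax.local_devices():
    print(device, device.platform, device.device_kind)
print("local device count:", jax.local_device_count())
```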

Cc: @pcuenca

PawKanarek commented 10 months ago

For now, my workaround to get past RESOURCE_EXHAUSTED is to convert the weights to fp16 (or bf16):

    if weight_dtype == jnp.float16:
        print("converting weights to fp16")
        unet_params = unet.to_fp16(unet_params)
        vae_params = vae.to_fp16(vae_params)
    elif weight_dtype == jnp.bfloat16:
        print("converting weights to bf16")
        unet_params = unet.to_bf16(unet_params)
        vae_params = vae.to_bf16(vae_params)

With this I can finally create the TrainState, but IMHO it might backfire later.
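
For context, `weight_dtype` above is derived from the script's `--mixed_precision` flag, following the same mapping as the existing Flax text-to-image example (small sketch, with a hard-coded value standing in for `args.mixed_precision`):

```python
import jax.numpy as jnp

mixed_precision = "bf16"  # stand-in for args.mixed_precision parsed from the CLI

weight_dtype = jnp.float32
if mixed_precision == "fp16":
    weight_dtype = jnp.float16
elif mixed_precision == "bf16":
    weight_dtype = jnp.bfloat16
print(weight_dtype)
```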

github-actions[bot] commented 9 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

PawKanarek commented 9 months ago

It's been a long month - so long that even the bot noticed that the issue had gone stale! I would like to give you all an update on my progress :)

I've created a script named train_text_to_image_flax_sdxl.py in my forked diffusers repo. Here are a few important notes:

Training parameters:

# training with these parameters took 151 minutes
python3 train_text_to_image_flax_sdxl.py --pretrained_model_name_or_path='stabilityai/stable-diffusion-xl-base-1.0' --train_data_dir='/mnt/disks/persist/repos/spraix/train_data_1024_best_96/' --resolution=1024 --center_crop --train_batch_size=4 --mixed_precision='bf16' --num_train_epochs=1 --learning_rate=1e-05 --max_grad_norm=1 --output_dir='/mnt/disks/persist/repos/spraix_sdxl_best_96_1/'

Comparison of training data vs. result:

| Training | Output |
| --- | --- |
| [training image] | [output image] |
| label: "5-frame sprite animation of: a horned fire demon with large bloody cleaver, that: is taking damage after hit" | label: "8-frame sprite animation of: a horned devil with big lasers, that: is jumping, facing: West" |

As can be seen, the model clearly picked up the grayish background from the training data :)

Currently, I'm trying to run the training for 8 epochs; it will take almost 24 hours and consume over 300 GB of memory. Increasing the batch size to 8 leads to training failure with Out Of Memory exceptions. My theory is that I could get rid of the OOM problems with smart sharding and parallelism via jax.pmap (rough sketch of what I mean below). Does anyone have experience with this? cc @sayakpaul
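
For reference, this is roughly the data-parallel pattern I have in mind, shown on a toy model instead of the SDXL UNet (illustrative sketch only; none of these names come from the actual script): the batch gets a leading device axis, each core computes gradients on its own shard, and `jax.lax.pmean` averages them across cores.

```python
import jax
import jax.numpy as jnp
import optax
from flax import jax_utils
from flax.training import train_state


def apply_fn(params, x):
    return x @ params["w"]  # toy stand-in for the UNet


def train_step(state, batch):
    def loss_fn(params):
        pred = apply_fn(params, batch["x"])
        return jnp.mean((pred - batch["y"]) ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    grads = jax.lax.pmean(grads, axis_name="batch")  # average gradients across cores
    loss = jax.lax.pmean(loss, axis_name="batch")
    return state.apply_gradients(grads=grads), loss


p_train_step = jax.pmap(train_step, axis_name="batch")

n_dev = jax.local_device_count()
params = {"w": jnp.zeros((4, 1))}
state = train_state.TrainState.create(apply_fn=apply_fn, params=params, tx=optax.adamw(1e-3))
state = jax_utils.replicate(state)  # one copy of params/optimizer state per core
# leading axis = number of devices, so each core only sees its own slice of the batch
batch = {"x": jnp.ones((n_dev, 2, 4)), "y": jnp.ones((n_dev, 2, 1))}
state, loss = p_train_step(state, batch)
print(jax_utils.unreplicate(loss))
```

This splits the per-step activation memory across the 8 cores, but the parameters and optimizer state are still fully replicated on every core, so it would not by itself help with the TrainState allocation.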

sayakpaul commented 9 months ago

No JAX expert here so I am gonna cc @pcuenca.

But really, thanks for sticking to it and sharing with us your findings!

github-actions[bot] commented 8 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.