Hi @PawKanarek
Our PyTorch training examples are intended to run on GPUs, not TPUs. cc @sayakpaul to confirm
Thank you @yiyixuxu. This would explain why I'm struggling so hard with it. I will now try to train Stable Diffusion v1-4, since it's mentioned that it can be done on TPU. I found more information about training with Flax/JAX here in this README.md.
I'm keeping this ticket open because perhaps @sayakpaul can provide some hints on what I need to do, to train SDXL on TPUs.
For maximizing the performance on TPUs, I welcome you to check our JAX training scripts.
Thanks @sayakpaul for your response and for directing me to the JAX training scripts. Studying them gave me the strength and knowledge to work on a new script for training SDXL with Flax, called train_text_to_image_flax_sdxl.py
:)
However, after some bigger and smaller problems that I could solve on my own, I'm now stuck.
I'm encountering a problem where I cannot create a TrainState for the SDXL UNet. I'm getting a strange XlaRuntimeError indicating that the program cannot allocate 50MB, even though there's still about 310GB available on the machine (Google TPU v3-8).
The error says:
# RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 50.00M. That was not possible. There are 27.86M free.; (0x0x0_HBM0)
I'm providing minimal code needed to reproduce this issue:
import jax.numpy as jnp
import optax
from flax.training import train_state
from diffusers import FlaxUNet2DConditionModel
print("initialize sdxl-base FlaxUNet2DConditionModel")
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
    "pcuenq/stable-diffusion-xl-base-1.0-flax",
    from_pt=False,
    subfolder="unet",
    dtype=jnp.bfloat16,
)
constant_scheduler = optax.constant_schedule(1e-5)
adamw = optax.adamw(learning_rate=constant_scheduler, b1=0.9, b2=0.999, eps=1e-08, weight_decay=1e-2)
optimizer = optax.chain(optax.clip_by_global_norm(1.0), adamw)
# I cannot create a TrainState for SDXL on a TPU v3-8 with 335 GB of RAM.
# It complains about allocating 50M of memory, despite the fact that ~310GB is still free.
# Up to this point the program allocates about ~20GB of memory, then throws RESOURCE_EXHAUSTED.
print("******* TRYING TO CREATE TrainState AND THIS WILL THROW RESOURCE_EXHAUSTED")
state = train_state.TrainState.create(apply_fn=unet.__call__, params=unet_params, tx=optimizer)
# throws: RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 50.00M. That was not possible. There are 27.86M free.; (0x0x0_HBM0)
print(f"Never goes here :( {state=}")
The full code that produces this error is in my new script (lines 603-616), but this minimal snippet is enough to reproduce the runtime error.
Why do I think this is a weird error? Because I know I have used up to 60GB more RAM on this machine without receiving any errors, and now I'm getting RESOURCE_EXHAUSTED when only 20GB is used.
Any idea or guidance in solving this problem would be greatly appreciated. Thank you!
Cc: @pcuenca
The TensorBoard profiling is much more useful! Take a look at this TensorBoard memory_profile page:
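For reference, a minimal sketch of how such a memory profile can be captured (the /tmp/jax-trace log directory is an arbitrary choice): wrap the failing section in a JAX profiler trace and point TensorBoard at the log directory.

```python
import jax

# Sketch: capture a trace around the failing code, then inspect it with
# `tensorboard --logdir /tmp/jax-trace` and open the memory_profile tab.
jax.profiler.start_trace("/tmp/jax-trace")
# ... run the code that allocates memory, e.g. TrainState.create(...) ...
jax.profiler.stop_trace()

# Alternatively, dump a pprof-compatible device memory snapshot at any point:
jax.profiler.save_device_memory_profile("memory.prof")
```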
I think I have a lot of memory overall on the entire v3-8 machine, but it looks like a single TPU core has only 16 GB of High Bandwidth Memory (source). This explains why I could allocate more memory in total in my previous examples, while any single allocation is capped at this 15.48GiB. That's my naive explanation of this issue.
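A rough back-of-the-envelope check supports this (assumed numbers, not measurements): the SDXL UNet has roughly 2.6B parameters, and TrainState.create adds two AdamW moment tensors of the same shape, so in fp32 the state alone already exceeds one core's HBM.

```python
# Rough estimate with assumed numbers: SDXL UNet ~2.6B parameters, fp32 = 4 bytes each.
n_params = 2.6e9
weights_gb = n_params * 4 / 1e9   # ~10.4 GB for the UNet weights alone
adamw_gb = 2 * weights_gb         # AdamW keeps two extra moment tensors per parameter
print(weights_gb + adamw_gb)      # ~31 GB in total, far above the ~15.48 GiB of a single v3 core
```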
In summary, I now think I have to somehow instruct Flax to allocate the TrainState across many TPU cores simultaneously. I wonder if that's even possible. I will dig deeper into that topic.
Cc: @pcuenca
For now, my workaround to get past RESOURCE_EXHAUSTED is to convert the weights to fp16 (or bf16):
if weight_dtype == jnp.float16:
    print("converting weights to fp16")
    unet_params = unet.to_fp16(unet_params)
    vae_params = vae.to_fp16(vae_params)
elif weight_dtype == jnp.bfloat16:
    print("converting weights to bf16")
    unet_params = unet.to_bf16(unet_params)
    vae_params = vae.to_bf16(vae_params)
With this I can finally create the TrainState, but IMHO it might backfire later.
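The same back-of-the-envelope estimate hints at why the 16-bit conversion helps (again assumed numbers): optax initializes the AdamW moments with the params' dtype by default, so everything roughly halves.

```python
# Same assumed ~2.6B parameters, now stored in bf16/fp16 (2 bytes each).
n_params = 2.6e9
weights_gb = n_params * 2 / 1e9   # ~5.2 GB for the weights
adamw_gb = 2 * weights_gb         # the two AdamW moments inherit the params' dtype by default
print(weights_gb + adamw_gb)      # ~15.6 GB, which just fits under one core's HBM
```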
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
It's been a long month - so long that even the bot noticed the problem was outdated! I would like to give you all an update on my progress :)
I've created a script named train_text_to_image_flax_sdxl.py within my forked diffusers repo. Here are a few important notes:
- I did not use jax.pmap parallelism, as I was constantly fighting with Out Of Memory exceptions.
- 1 epoch with 96 images and batch = 4 takes about 151 minutes.

Training parameters:
# training with these parameters took 151 minutes
python3 train_text_to_image_flax_sdxl.py --pretrained_model_name_or_path='stabilityai/stable-diffusion-xl-base-1.0' --train_data_dir='/mnt/disks/persist/repos/spraix/train_data_1024_best_96/' --resolution=1024 --center_crop --train_batch_size=4 --mixed_precision='bf16' --num_train_epochs=1 --learning_rate=1e-05 --max_grad_norm=1 --output_dir='/mnt/disks/persist/repos/spraix_sdxl_best_96_1/'
Comparison of training vs result:
Training | Output |
---|---|
label: "5-frame sprite animation of: a horned fire demon with large bloody cleaver, that: is taking damage after hit" | label: "8-frame sprite animation of: a horned devil with big lasers, that: is jumping, facing: West" |
As can be seen, the model surely picked up the grayish background from the training data :)
Currently, I'm trying to run the training for 8 epochs; it will take almost 24 hours, consuming over 300 GB of memory.
Increasing the batch size to 8 leads to training failure due to Out Of Memory exceptions. My theory is that I could get rid of the OOM problems with smart sharding and parallelism via jax.pmap; a rough sketch of the usual data-parallel pattern is below. Does anyone have experience with this? cc @sayakpaul
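Not an authoritative answer, but for reference, here is a minimal sketch of the standard data-parallel jax.pmap pattern used by the existing diffusers Flax scripts (loss_fn and dataloader are placeholders). Note that it replicates the full TrainState on every core, so it improves throughput but does not by itself reduce per-device memory; for that, parameter/optimizer-state sharding (pjit-style) would be needed.

```python
import jax
from flax import jax_utils
from flax.training.common_utils import shard

# Sketch only: `state` is a Flax TrainState, `loss_fn(params, batch)` stands in for
# the real SDXL denoising loss, and `dataloader` yields NumPy batches.
def train_step(state, batch):
    grad_fn = jax.value_and_grad(lambda params: loss_fn(params, batch))
    loss, grads = grad_fn(state.params)
    grads = jax.lax.pmean(grads, axis_name="batch")  # average gradients across the 8 cores
    new_state = state.apply_gradients(grads=grads)
    return new_state, jax.lax.pmean(loss, axis_name="batch")

p_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0,))

state = jax_utils.replicate(state)  # copy the TrainState to every local device
for batch in dataloader:            # global batch size must be divisible by jax.local_device_count()
    state, loss = p_train_step(state, shard(batch))  # shard() splits the batch across devices
```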
No JAX expert here so I am gonna cc @pcuenca.
But really, thanks for sticking to it and sharing with us your findings!
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Describe the bug
Hello 🤗. I'm trying to train a LoRA SDXL on a Google Cloud TPU v3-8 machine with this script: train_text_to_image_lora_sdxl.py. I followed the setup guide from diffusers/examples/text_to_image/README_sdxl.md but I'm stuck at the very beginning.
I also attempted to run this script on my local Apple M1 Pro. Despite the mps device not being supported, it appears that the script successfully bypasses the error encountered on the Google Cloud TPU, which I mention in this ticket.
The stacktrace suggests that accelerate cannot launch the PrepareForLaunch function. Details below. ⬇️
Any help and ideas to move forward would be appreciated <3
Reproduction
Run these CLI commands in the diffusers/examples/text_to_image/ directory. The commands are copy-pasted from README_sdxl.md; for simplicity I deleted the --report_to="wandb" and --push_to_hub parameters.
Logs
System Info
As suggested in README_lora.md, I'm using diffusers from the latest main (commit sha: c896b841e48b65e561800f829c546f4cf047e634)
Output of $ diffusers-cli env:
Accelerate config file:
This is a brand new, fresh and clean conda environment with the packages listed below. Output of $ conda list:
Who can help?
@sayakpaul