kohya-ss / sd-scripts


Cannot restart training after training tenc 2 AND using fused_backward_pass #1369

Open · araleza opened this issue 3 months ago

araleza commented 3 months ago

If you finetune SDXL base with:

--train_text_encoder --learning_rate_te1 1e-10 --learning_rate_te2 1e-10 --fused_backward_pass

Then it will train fine. But if you stop training and restart from the saved checkpoint (e.g. the <whatever>-step00001000.safetensors file), you get this error message:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [1280, 1280]], which is output 0 of AsStridedBackward0, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

This doesn't happen if you only train te1 and the unet. It also only happens when you use --fused_backward_pass.
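
For context, here is my rough understanding of why these two features might interact. This is a minimal sketch of the general idea behind a fused backward pass, not sd-scripts' actual implementation; attach_fused_step and the plain SGD-style update are made up for illustration. The point is that each parameter is updated in place, inside backward(), as soon as its gradient is ready:

import torch

def attach_fused_step(params, lr=1e-10):
    # Step each parameter as soon as its gradient has been accumulated, then free the grad,
    # instead of doing a separate optimizer.step() after backward().
    hooks = []
    for p in params:
        if not p.requires_grad:
            continue

        def step(param):
            with torch.no_grad():
                param.add_(param.grad, alpha=-lr)  # in-place weight update during backward()
            param.grad = None                      # gradient is freed immediately, saving VRAM

        # register_post_accumulate_grad_hook is available in PyTorch >= 2.1
        hooks.append(p.register_post_accumulate_grad_hook(step))
    return hooks

If the text encoder's weights are modified in place while backward() is still running, and a checkpointed forward re-run later needs the original values, that would be consistent with the "is at version 1; expected version 0" complaint above -- though I haven't confirmed this is the exact mechanism.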

Full call stack:

Traceback (most recent call last):
  File "/home/ara/m.2/Dev/sdxl/sd-scripts/./sdxl_train.py", line 944, in <module>
    train(args)
  File "/home/ara/m.2/Dev/sdxl/sd-scripts/./sdxl_train.py", line 733, in train
    accelerator.backward(loss)
  File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/accelerate/accelerator.py", line 1905, in backward
    loss.backward(**kwargs)
  File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/_tensor.py", line 522, in backward
    torch.autograd.backward(
  File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/autograd/__init__.py", line 266, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/autograd/function.py", line 289, in apply
    return user_fn(self, *args)
  File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 319, in backward
    torch.autograd.backward(outputs_with_grad, args_with_grad)
  File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/autograd/__init__.py", line 266, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass

(I also mentioned this bug back when the original pull request was made: see https://github.com/kohya-ss/sd-scripts/pull/1259)

kohya-ss commented 3 months ago

Thank you for opening this! Unfortunately, I cannot reproduce this issue. I think it may be caused by a difference in the PyTorch version. Which version are you using? I'm using 2.1.2.

araleza commented 3 months ago

Sorry, I should have provided reproduction steps for this. Here they are now.

The bug reproduces even on a clean checkout of the dev branch (I haven't used main recently).

git clone https://github.com/kohya-ss/sd-scripts
cd sd-scripts
git switch dev
python -m venv venv
source venv/bin/activate
pip install torch torchvision -r requirements.txt
pip install xformers

(I have to install xformers with a second pip command rather than adding it to the first 'pip install' line, due to version incompatibilities with torch. I still get "torchvision 0.18.1 requires torch==2.3.1, but you have torch 2.3.0 which is incompatible", but this doesn't seem to cause problems when actually running.)

You asked about my torch version. Here it is (from pip list), along with my xformers version:

torch                     2.3.0
xformers                  0.0.26.post1

Then I run training:

accelerate launch --num_cpu_threads_per_process=2 "./sdxl_train.py" --pretrained_model_name_or_path="/home/ara/Documents/sdxl/sd_xl_base_1.0.safetensors" --enable_bucket --min_bucket_reso=64 --max_bucket_reso=1024 --train_data_dir="/home/ara/Documents/sdxl/img" --resolution="1024,1024" --output_dir="/home/ara/Documents/sdxl/dreambooth" --logging_dir="/home/ara/Documents/sdxl/log" --save_model_as=safetensors --vae="/home/ara/Documents/sdxl/sdxl_vae.safetensors" --output_name="earthscape" --lr_scheduler_num_cycles="20000" --max_token_length=150 --max_data_loader_n_workers="0" --lr_scheduler="constant_with_warmup" --lr_warmup_steps="200" --max_train_steps="16000" --caption_extension=".txt" --optimizer_type="Adafactor" --optimizer_args scale_parameter=False relative_step=False warmup_init=False --max_data_loader_n_workers="0" --max_token_length=150 --bucket_reso_steps=32 --save_every_n_steps="10" --save_last_n_steps="20" --min_snr_gamma=5 --gradient_checkpointing --xformers --bucket_no_upscale --noise_offset=0.0357 --sample_sampler=k_dpm_2 --fused_backward_pass --cache_latents --train_batch_size="4" --train_text_encoder --learning_rate_te1 1e-10 --learning_rate="2e-7"

I set up this training to write a .safetensors file almost immediately, at step 10. After step 10, I stop training and run the same command line again, but this time changing sd_xl_base_1.0.safetensors to dreambooth/earthscape-step00000010.safetensors. This produces the error message:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [1280, 1280]], which is output 0 of AsStridedBackward0, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

This should hopefully reproduce the issue for you. Thank you for your attention so far. :)

If I re-run these steps, but instead of --fused_backward_pass I use --full_bf16 --mixed_precision="bf16", then training restarts with no error message.

If I re-run the steps, keeping --fused_backward_pass, but this time changing --learning_rate_te2 1e-10 to --learning_rate_te2 0 (stopping tenc 2 from training), then training is again able to restart with no error message. So the error only occurs when both --fused_backward_pass is enabled and tenc 2 is being trained.

This seems to be an important bug to fix for SDXL training, as I am seeing amazing results from training tenc 2 at a very low rate of 1e-10. That learning rate has not been usable with bf16 training, as it is below the precision that bf16 is able to handle. But with the fp32 training made possible by --fused_backward_pass, and tenc 2 being trained, I see impressive image quality changes. I just cannot restart training if I stop!
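
To make the precision point concrete, here is a small generic illustration (numbers from torch.finfo, nothing measured from sd-scripts): an additive update only survives rounding if it is larger than roughly eps times the weight's magnitude, and bf16's eps is tens of thousands of times coarser than fp32's.

import torch

print(torch.finfo(torch.bfloat16).eps)  # ~7.8e-03
print(torch.finfo(torch.float32).eps)   # ~1.2e-07

w32 = torch.tensor(0.1, dtype=torch.float32)
w16 = w32.to(torch.bfloat16)

# A 1e-6 update is below bf16's resolution around 0.1, so it rounds away entirely,
# while fp32 still registers it.
print((w16 + 1e-6) == w16)  # tensor(True)  -> update lost in bf16
print((w32 + 1e-6) == w32)  # tensor(False) -> update applied in fp32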

By the way, I saw a new warning now that I've reinstalled sd-scripts to make these reproduction instructions:

sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/nn/modules/conv.py:456: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at ../aten/src/ATen/native/cudnn/Conv_v8.cpp:919.)
  return F.conv2d(input, weight, bias, self.stride,

I didn't have that warning before, so I don't think it's relevant to this bug report.

araleza commented 3 months ago

Update to my long reproduction steps above:

If I remove --save_model_as=safetensors (running without that option from the start), then training is able to restart!

So to reproduce the issue, all three of these options need to be set:

--save_model_as=safetensors
--fused_backward_pass
--learning_rate_te2 1e-10

Being able to use .ckpt instead of .safetensors is great news, as it provides a workaround that lets me restart training even before this bug is fixed.

Edit: I just got the error message again, even with .ckpt being used. :-/ Not sure why it worked for that one test run, but it seems that the bug does not need safetensors after all.

kohya-ss commented 3 months ago

Thank you for the detailed steps! The dev branch recommends torch==2.1.2 and xformers==0.0.23.post1, as written in README.md. So I may need a new venv to reproduce the issue.

In addition, I don't think the format of the file (.ckpt or .safetensors) affects the issue. So the issue may depend on something special...

araleza commented 3 months ago

Okay, so I took my build and installed those versions:

pip install torch==2.1.2 xformers==0.0.23.post1 torchvision

which got me:

torch                     2.1.2
xformers                  0.0.23.post1

(I had to include torchvision on that installation line to get a version that works with torch 2.1.2.)

I made sure to regenerate a fresh .ckpt file rather than picking up the one I'd already made with the later torch version. But the same error still reproduces, even with a new .ckpt written out by sd-scripts and torch/xformers set to these older versions.

araleza commented 3 months ago

I started trying to get more information about this. Since the error message suggested adding torch.autograd.set_detect_anomaly(True), I did that, and got the following debugging trace associated with the error, in case it rings any bells for anyone:

  warnings.warn(
/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/autograd/__init__.py:251: UserWarning: Error detected in AddmmBackward0. Traceback of forward call that caused the error:
  File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/autograd/function.py", line 288, in apply: return user_fn(self, *args)
  File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 271, in backward: outputs = ctx.run_function(*detached_inputs)
  File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl: return self._call_impl(*args, **kwargs)
  File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl: return forward_call(*args, **kwargs)
  File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/transformers/models/clip/modeling_clip.py", line 372, in forward: hidden_states, attn_weights = self.self_attn(
  File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl: return self._call_impl(*args, **kwargs)
  File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl: return forward_call(*args, **kwargs)
  File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/transformers/models/clip/modeling_clip.py", line 262, in forward: key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
  File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl: return self._call_impl(*args, **kwargs)
  File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl: return forward_call(*args, **kwargs)
  File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 114, in forward: return F.linear(input, self.weight, self.bias)
  (Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:114.)
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass

/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/autograd/__init__.py:251: UserWarning: Previous calculation was induced by CheckpointFunctionBackward. Traceback of forward call that induced the previous calculation:
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main: return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code: exec(code, run_globals)
  File "/home/ara/.vscode/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/__main__.py", line 39, in <module>: cli.main()
  File "/home/ara/.vscode/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 430, in main: run()
  File "/home/ara/.vscode/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 284, in run_file: runpy.run_path(target, run_name="__main__")
  File "/home/ara/.vscode/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 321, in run_path: return _run_module_code(code, init_globals, run_name,
  File "/home/ara/.vscode/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 135, in _run_module_code: _run_code(code, mod_globals, init_globals,
  File "/home/ara/.vscode/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 124, in _run_code: exec(code, run_globals)
  File "sdxl_train.py", line 963, in <module>: train(args)
  File "sdxl_train.py", line 659, in train: encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
  File "/home/ara/m.2/Dev/sdxl/sd-scripts/library/train_util.py", line 4701, in get_hidden_states_sdxl: enc_out = text_encoder2(input_ids2, output_hidden_states=True, return_dict=True)
  File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl: return self._call_impl(*args, **kwargs)
  File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl: return forward_call(*args, **kwargs)
  File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/transformers/models/clip/modeling_clip.py", line 1207, in forward: text_outputs = self.text_model(
  File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl: return self._call_impl(*args, **kwargs)
  File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl: return forward_call(*args, **kwargs)
  File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/transformers/models/clip/modeling_clip.py", line 703, in forward: encoder_outputs = self.encoder(
  File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl: return self._call_impl(*args, **kwargs)
  File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl: return forward_call(*args, **kwargs)
  File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/transformers/models/clip/modeling_clip.py", line 622, in forward: layer_outputs = self._gradient_checkpointing_func(
  File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/_compile.py", line 24, in inner: return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
  File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 328, in _fn: return fn(*args, **kwargs)
  File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 17, in inner: return fn(*args, **kwargs)
  File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 451, in checkpoint: return CheckpointFunction.apply(function, preserve, *args)
  File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/autograd/function.py", line 539, in apply: return super().apply(*args, **kwargs)  # type: ignore[misc]
  (Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:121.)
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass

/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/autograd/__init__.py:251: UserWarning: Error detected in CheckpointFunctionBackward. Traceback of forward call that caused the error:
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main: return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code: exec(code, run_globals)
  File "/home/ara/.vscode/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/__main__.py", line 39, in <module>: cli.main()
  File "/home/ara/.vscode/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 430, in main: run()
  File "/home/ara/.vscode/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 284, in run_file: runpy.run_path(target, run_name="__main__")
  File "/home/ara/.vscode/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 321, in run_path: return _run_module_code(code, init_globals, run_name,
  File "/home/ara/.vscode/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 135, in _run_module_code: _run_code(code, mod_globals, init_globals,
  File "/home/ara/.vscode/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 124, in _run_code: exec(code, run_globals)
  File "sdxl_train.py", line 963, in <module>: train(args)
  File "sdxl_train.py", line 659, in train: encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
  File "/home/ara/m.2/Dev/sdxl/sd-scripts/library/train_util.py", line 4701, in get_hidden_states_sdxl: enc_out = text_encoder2(input_ids2, output_hidden_states=True, return_dict=True)
  File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl: return self._call_impl(*args, **kwargs)
  File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl: return forward_call(*args, **kwargs)
  File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/transformers/models/clip/modeling_clip.py", line 1207, in forward: text_outputs = self.text_model(
  File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl: return self._call_impl(*args, **kwargs)
  File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl: return forward_call(*args, **kwargs)
  File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/transformers/models/clip/modeling_clip.py", line 703, in forward: encoder_outputs = self.encoder(
  File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl: return self._call_impl(*args, **kwargs)
  File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl: return forward_call(*args, **kwargs)
  File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/transformers/models/clip/modeling_clip.py", line 622, in forward: layer_outputs = self._gradient_checkpointing_func(
  File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/_compile.py", line 24, in inner: return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
  File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 328, in _fn: return fn(*args, **kwargs)
  File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 17, in inner: return fn(*args, **kwargs)
  File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 451, in checkpoint: return CheckpointFunction.apply(function, preserve, *args)
  File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/autograd/function.py", line 539, in apply: return super().apply(*args, **kwargs)  # type: ignore[misc]
  (Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:114.)
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass

Tenc2 is mentioned in that trace:

enc_out = text_encoder2(input_ids2, output_hidden_states=True, return_dict=True)

Any thoughts?

araleza commented 3 months ago

Okay, I haven't fixed it, but I've now found a workaround that allows training of tenc2 to continue, and it also indicates roughly where the trouble is likely coming from.

The workaround is to edit the torch library file venv/lib/python3.10/site-packages/torch/utils/checkpoint.py inside your sd-scripts checkout (assuming you're using a venv, which you probably should be). Add the line indicated here:

[screenshot of the added line in torch/utils/checkpoint.py]

and you can then successfully continue training from a .safetensors or .ckpt file that sd-scripts wrote out during training.

The issue seems to be related to this warning, which appears even when using the recommended torch version 2.1.2:

sd-scripts/venv/lib/python3.10/site-packages/torch/utils/checkpoint.py:430: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.

I'll let someone who actually knows what they're doing with this figure out the correct fix. And I'm still not sure why this issue doesn't trigger when starting training from sd_xl_base_1.0.safetensors, but does trigger when starting from a .safetensors written out by sd-scripts.
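
For anyone unfamiliar with that warning, here is a generic PyTorch snippet (not sd-scripts code) showing the two checkpointing modes it refers to. The reentrant variant re-runs the forward pass inside backward() through a custom autograd Function, which is the CheckpointFunctionBackward path visible in the anomaly trace above; the non-reentrant variant is the one PyTorch recommends selecting explicitly:

import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Linear(8, 8)
x = torch.randn(2, 8, requires_grad=True)

y_reentrant = checkpoint(layer, x, use_reentrant=True)       # the old default behaviour
y_non_reentrant = checkpoint(layer, x, use_reentrant=False)  # the recommended explicit choice

(y_reentrant.sum() + y_non_reentrant.sum()).backward()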

Jannchie commented 3 months ago

I encountered a similar issue. When enabling optimizer_type = "AdamW" and train_text_encoder, a warning about use_reentrant appears. At this point, if I try to save the training state, the training process freezes.

Jannchie commented 3 months ago

Add {"use_reentrant": True} to sdxl_train.py could fix the use_reentrant problem.. But it still freezes on saved states. May be it is another issue.

if args.gradient_checkpointing:
    text_encoder1.gradient_checkpointing_enable({"use_reentrant": False})
    text_encoder2.gradient_checkpointing_enable({"use_reentrant": False})
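
If I understand the transformers API correctly, recent versions also accept the same setting through the gradient_checkpointing_kwargs keyword, which may read a little more clearly (this assumes a reasonably new transformers release; same effect as the positional form above):

if args.gradient_checkpointing:
    text_encoder1.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
    text_encoder2.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
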
araleza commented 3 months ago

@Jannchie, I didn't know use_reentrant could be passed into gradient_checkpointing_enable() like that. That's great news, as it lets my issue be fixed with an sd-scripts change rather than by hacking the library function like I was doing.

As for your hang, is it specific to AdamW? I've only been using Adafactor. Is there some advantage to using AdamW by the way? I haven't tried that.

Jannchie commented 3 months ago

I’m a beginner and not quite sure about the specific effects, but I am attempting to replicate the settings from https://huggingface.co/cagliostrolab/animagine-xl-3.1.

Regarding the freeze issue, by referring to this issue, I found that specifying the saving format as safetensors (instead of the default diffusers format) can resolve the problem.

TopSalad3530 commented 3 months ago

I've run into the same problem. In my case, the issue only started to appear after I switched to save_precision=float from FP16. Converting the problematic checkpoint down to FP16 seemed to resolve the issue for me. However, I don't think this necessarily has anything to do with the precision itself: it might just be that the conversion process happened to clear whatever problematic metadata ("version") from the tensors as a side effect.
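
For reference, one way to do that kind of conversion is with the safetensors library -- a minimal sketch (file names are placeholders, and metadata stored in the original file is not preserved):

import torch
from safetensors.torch import load_file, save_file

state = load_file("earthscape-step00000010.safetensors")  # placeholder input path
state = {k: (v.to(torch.float16) if v.is_floating_point() else v) for k, v in state.items()}
save_file(state, "earthscape-step00000010-fp16.safetensors")  # placeholder output path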

I did also try the {"use_reentrant": False} option, but for some reason it increased VRAM consumption so much that training was no longer possible on 24GB, even at batch size 1, so I don't believe it's a one-size-fits-all solution to this problem. Mistake, see below.

araleza commented 3 months ago

Interesting. I'm also on 24GB, and batch size 4 works great for me. Have you:

  1. passed in --gradient_checkpointing? Forgetting that flag is usually what makes me run out of memory unexpectedly.
  2. passed in --cache_latents, so it doesn't have to run the VAE repeatedly? I keep features like color augmentation / random crop switched off to allow this, but you can keep flip augmentation on.
TopSalad3530 commented 3 months ago

My bad -- it turned out that I had based my modifications on the sdxl_train.py from main, which didn't have fused_backward_pass at all, instead of on dev. Tried again and this time everything went fine.