segmind / SSD-1B

SSD-1B is an open-source text-to-image model that is 50% smaller and 60% faster than SDXL.
Apache License 2.0

Error message about Query/Key/Value #6

Open razvanab opened 11 months ago

razvanab commented 11 months ago

I get this error when I try to train on SSD-1B:


A matching Triton is not available, some optimizations will not be enabled.
Error caught was: No module named 'triton'
11/25/2023 23:13:08 - INFO - __main__ - Distributed environment: NO
Num processes: 1
Process index: 0
Local process index: 0
Device: cuda

Mixed precision type: fp16

You are using a model of type clip_text_model to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
You are using a model of type clip_text_model to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
{'variance_type', 'dynamic_thresholding_ratio', 'clip_sample_range', 'thresholding'} was not found in config. Values will be initialized to default values.
{'dropout', 'attention_type'} was not found in config. Values will be initialized to default values.
11/25/2023 23:13:22 - INFO - __main__ - ***** Running training *****
11/25/2023 23:13:22 - INFO - __main__ -   Num examples = 24
11/25/2023 23:13:22 - INFO - __main__ -   Num batches each epoch = 24
11/25/2023 23:13:22 - INFO - __main__ -   Num Epochs = 42
11/25/2023 23:13:22 - INFO - __main__ -   Instantaneous batch size per device = 1
11/25/2023 23:13:22 - INFO - __main__ -   Total train batch size (w. parallel, distributed & accumulation) = 2
11/25/2023 23:13:22 - INFO - __main__ -   Gradient Accumulation steps = 2
11/25/2023 23:13:22 - INFO - __main__ -   Total optimization steps = 500
Steps:   0%|                                                                                   | 0/500 [00:00<?, ?it/s]Traceback (most recent call last):
  File "G:\DEV\train_dreambooth_lora_sdxl.py", line 1732, in <module>
    main(args)
  File "G:\DEV\train_dreambooth_lora_sdxl.py", line 1451, in main
    model_pred = unet(
  File "G:\DEV\Conda\envs\temp\lib\site-packages\torch\nn\modules\module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "G:\DEV\Conda\envs\temp\lib\site-packages\torch\nn\modules\module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "G:\DEV\Conda\envs\temp\lib\site-packages\accelerate\utils\operations.py", line 659, in forward
    return model_forward(*args, **kwargs)
  File "G:\DEV\Conda\envs\temp\lib\site-packages\accelerate\utils\operations.py", line 647, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "G:\DEV\Conda\envs\temp\lib\site-packages\torch\amp\autocast_mode.py", line 16, in decorate_autocast
    return func(*args, **kwargs)
  File "G:\DEV\Conda\envs\temp\lib\site-packages\diffusers\models\unet_2d_condition.py", line 1075, in forward
    sample, res_samples = downsample_block(
  File "G:\DEV\Conda\envs\temp\lib\site-packages\torch\nn\modules\module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "G:\DEV\Conda\envs\temp\lib\site-packages\torch\nn\modules\module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "G:\DEV\Conda\envs\temp\lib\site-packages\diffusers\models\unet_2d_blocks.py", line 1150, in forward
    hidden_states = attn(
  File "G:\DEV\Conda\envs\temp\lib\site-packages\torch\nn\modules\module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "G:\DEV\Conda\envs\temp\lib\site-packages\torch\nn\modules\module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "G:\DEV\Conda\envs\temp\lib\site-packages\diffusers\models\transformer_2d.py", line 363, in forward
    hidden_states = torch.utils.checkpoint.checkpoint(
  File "G:\DEV\Conda\envs\temp\lib\site-packages\torch\_compile.py", line 24, in inner
    return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
  File "G:\DEV\Conda\envs\temp\lib\site-packages\torch\_dynamo\eval_frame.py", line 328, in _fn
    return fn(*args, **kwargs)
  File "G:\DEV\Conda\envs\temp\lib\site-packages\torch\_dynamo\external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "G:\DEV\Conda\envs\temp\lib\site-packages\torch\utils\checkpoint.py", line 458, in checkpoint
    ret = function(*args, **kwargs)
  File "G:\DEV\Conda\envs\temp\lib\site-packages\torch\nn\modules\module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "G:\DEV\Conda\envs\temp\lib\site-packages\torch\nn\modules\module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "G:\DEV\Conda\envs\temp\lib\site-packages\diffusers\models\attention.py", line 293, in forward
    attn_output = self.attn2(
  File "G:\DEV\Conda\envs\temp\lib\site-packages\torch\nn\modules\module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "G:\DEV\Conda\envs\temp\lib\site-packages\torch\nn\modules\module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "G:\DEV\Conda\envs\temp\lib\site-packages\diffusers\models\attention_processor.py", line 522, in forward
    return self.processor(
  File "G:\DEV\Conda\envs\temp\lib\site-packages\diffusers\models\attention_processor.py", line 1144, in __call__
    hidden_states = xformers.ops.memory_efficient_attention(
  File "G:\DEV\Conda\envs\temp\lib\site-packages\xformers\ops\fmha\__init__.py", line 223, in memory_efficient_attention
    return _memory_efficient_attention(
  File "G:\DEV\Conda\envs\temp\lib\site-packages\xformers\ops\fmha\__init__.py", line 326, in _memory_efficient_attention
    return _fMHA.apply(
  File "G:\DEV\Conda\envs\temp\lib\site-packages\torch\autograd\function.py", line 539, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "G:\DEV\Conda\envs\temp\lib\site-packages\xformers\ops\fmha\__init__.py", line 42, in forward
    out, op_ctx = _memory_efficient_attention_forward_requires_grad(
  File "G:\DEV\Conda\envs\temp\lib\site-packages\xformers\ops\fmha\__init__.py", line 348, in _memory_efficient_attention_forward_requires_grad
    inp.validate_inputs()
  File "G:\DEV\Conda\envs\temp\lib\site-packages\xformers\ops\fmha\common.py", line 121, in validate_inputs
    raise ValueError(
ValueError: Query/Key/Value should either all have the same dtype, or (in the quantized case) Key/Value should have dtype torch.int32
  query.dtype: torch.float32
  key.dtype  : torch.float16
  value.dtype: torch.float16
Steps:   0%|                                                                                   | 0/500 [00:15<?, ?it/s]
Traceback (most recent call last):
  File "G:\DEV\Conda\envs\temp\lib\runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "G:\DEV\Conda\envs\temp\lib\runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "G:\DEV\Conda\envs\temp\Scripts\accelerate.exe\__main__.py", line 7, in <module>
  File "G:\DEV\Conda\envs\temp\lib\site-packages\accelerate\commands\accelerate_cli.py", line 47, in main
    args.func(args)
  File "G:\DEV\Conda\envs\temp\lib\site-packages\accelerate\commands\launch.py", line 994, in launch_command
    simple_launcher(args)
  File "G:\DEV\Conda\envs\temp\lib\site-packages\accelerate\commands\launch.py", line 636, in simple_launcher
    raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd)
subprocess.CalledProcessError: Command '['G:\\DEV\\Conda\\envs\\temp\\python.exe', 'train_dreambooth_lora_sdxl.py', '--pretrained_model_name_or_path=segmind/SSD-1B', '--pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix', '--instance_data_dir=aaa', '--output_dir=lora-aaa-SSD-1B', '--instance_prompt=a photo of kud0', '--mixed_precision=fp16', '--resolution=1024', '--train_batch_size=1', '--gradient_accumulation_steps=2', '--gradient_checkpointing', '--learning_rate=1e-4', '--lr_scheduler=constant', '--lr_warmup_steps=0', '--max_train_steps=500', '--validation_prompt=A photo of kud0', '--validation_epochs=300', '--num_validation_images=1', '--seed=0', '--enable_xformers_memory_efficient_attention']' returned non-zero exit status 1.
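
The ValueError at the bottom of that trace comes from xformers' input validation: the attention query reaches `memory_efficient_attention` as float32 while key and value are float16. A minimal sketch that reproduces the same check failing, assuming xformers is installed and a CUDA device is available (the tensor shapes below are arbitrary placeholders, not the script's actual ones):

```python
import torch
import xformers.ops as xops

# Placeholder shapes: (batch, seq_len, num_heads, head_dim).
q = torch.randn(1, 64, 8, 40, device="cuda", dtype=torch.float32)  # query left in fp32
k = torch.randn(1, 77, 8, 40, device="cuda", dtype=torch.float16)  # key in fp16
v = torch.randn(1, 77, 8, 40, device="cuda", dtype=torch.float16)  # value in fp16

try:
    xops.memory_efficient_attention(q, k, v)
except ValueError as err:
    print(err)  # "Query/Key/Value should either all have the same dtype ..."

# Casting everything to a single dtype satisfies the validation.
out = xops.memory_efficient_attention(q.half(), k, v)
print(out.dtype)  # torch.float16
```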
razvanab commented 11 months ago

This is the script and the config:

accelerate launch train_dreambooth_lora_sdxl.py --pretrained_model_name_or_path="segmind/SSD-1B" --pretrained_vae_model_name_or_path="madebyollin/sdxl-vae-fp16-fix" --instance_data_dir="aaa" --output_dir="lora-aaa-SSD-1B" --instance_prompt="a photo of kud0" --mixed_precision="fp16" --resolution=1024 --train_batch_size=1 --gradient_accumulation_steps=2 --gradient_checkpointing --learning_rate=1e-4 --lr_scheduler="constant" --lr_warmup_steps=0 --max_train_steps=500 --validation_prompt="A photo of kud0" --validation_epochs=300 --num_validation_images=1 --seed="0" --enable_xformers_memory_efficient_attention

debug: false
distributed_type: 'NO'
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
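
The `--enable_xformers_memory_efficient_attention` flag in that command is what routes cross-attention through the `xformers.ops.memory_efficient_attention` call shown in the traceback. For context, a hedged sketch of the corresponding diffusers-level switches, shown on a standalone UNet rather than inside the training script (the loading code is illustrative only):

```python
import torch
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor2_0

# Load just the UNet from SSD-1B for illustration.
unet = UNet2DConditionModel.from_pretrained(
    "segmind/SSD-1B", subfolder="unet", torch_dtype=torch.float16
)

# What the flag turns on: xformers attention (the code path that raises above).
unet.enable_xformers_memory_efficient_attention()

# The alternative when the flag is dropped: PyTorch 2.x scaled_dot_product_attention,
# which diffusers selects by default when PyTorch >= 2.0 is installed.
unet.set_attn_processor(AttnProcessor2_0())
```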
razvanab commented 11 months ago

Xformers were the problem. Unfortunately, I can't make 8bit_adam work, so I thought I could use Xformers (--enable_xformers_memory_efficient_attention) instead.
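
For reference, a minimal sketch of how 8-bit Adam is typically wired up with bitsandbytes, roughly what the script's `--use_8bit_adam` path does (the parameter list and hyperparameters below are placeholders):

```python
import torch
import bitsandbytes as bnb

# Placeholder parameters standing in for the LoRA layers being trained.
params = [torch.nn.Parameter(torch.randn(16, 16, device="cuda"))]

# AdamW8bit stores optimizer state in 8 bits, roughly quartering optimizer memory.
optimizer = bnb.optim.AdamW8bit(params, lr=1e-4, betas=(0.9, 0.999), weight_decay=1e-2)

loss = params[0].sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```

Note that bitsandbytes has historically had limited native Windows support, which could be why 8bit_adam would not work in this Conda-on-Windows environment.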