kohya-ss / sd-scripts

Apache License 2.0

Stable Cascade (Würstchen) Support #1119

Open alfredplpl opened 7 months ago

alfredplpl commented 7 months ago

Stable Cascade (Würstchen) has been released: https://ja.stability.ai/blog/stable-cascade

So I'd like to fine-tune it with this script. You have already implemented Würstchen on the Würstchen branch.

Does that branch support Stable Cascade?

alfredplpl commented 7 months ago

FYI: https://github.com/Stability-AI/StableCascade/tree/master/train

dill-shower commented 7 months ago

+1

ashrafbay commented 7 months ago

+1

rdcoder33 commented 7 months ago

+1

kohya-ss commented 7 months ago

I'm working on it in the stable-cascade branch 😀

alfredplpl commented 7 months ago

Awesome. I can now run a full fine-tune of Stage C.


alfredplpl commented 7 months ago

BTW, I got this error:

2024-02-18 11:32:04 INFO     use 8-bit Lion optimizer | {}                                               train_util.py:3563
enable full bf16 training.
running training / 学習開始
  num examples / サンプル数: 390503
  num batches per epoch / 1epochのバッチ数: 97639
  num epochs / epoch数: 1
  batch size per device / バッチサイズ: 4
  gradient accumulation steps / 勾配を合計するステップ数 = 16
  total optimization steps / 学習ステップ数: 1600
steps:   0%|                                                                                      | 0/1600 [00:00<?, ?it/s]
epoch 1/1
Traceback (most recent call last):
  File "/mnt/my_raid/github/sd-scripts/stable_cascade_train_stage_c.py", line 526, in <module>
    train(args)
  File "/mnt/my_raid/github/sd-scripts/stable_cascade_train_stage_c.py", line 442, in train
    current_loss = loss.detach().item()  # it's an average, so batch size shouldn't matter
RuntimeError: a Tensor with 4 elements cannot be converted to Scalar
steps:   0%|                                                                                      | 0/1600 [00:03<?, ?it/s]

I think we can fix this error by changing line 442 of stable_cascade_train_stage_c.py to current_loss = loss.detach().mean().item(). What do you think?
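
A minimal standalone reproduction of the error and the proposed fix. The 4-element `loss` tensor below is made up to mimic an unreduced loss with batch size 4, as in the log above; it is an assumption, not the actual tensor from stable_cascade_train_stage_c.py.

```python
import torch

# Made-up per-sample losses for a batch of 4 (stand-in for the script's loss).
loss = torch.tensor([0.25, 0.20, 0.18, 0.22])

# .item() only works on one-element tensors, so this raises a RuntimeError.
error = None
try:
    loss.detach().item()
except RuntimeError as e:
    error = e
print("error:", error)

# Reducing to a single value first makes .item() valid.
current_loss = loss.detach().mean().item()
print("current_loss:", current_loss)
```

This only changes the logged value. If the same unreduced tensor were also passed to `backward()`, it would need a reduction there too, since `backward()` without an explicit gradient argument also requires a scalar.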

alfredplpl commented 7 months ago

And I got another error:

steps:   2%|█                                                    | 1000/48816 [2:19:08<110:53:03,  8.35s/it, avr_loss=0.21]Traceback (most recent call last):
  File "/mnt/my_raid/github/sd-scripts/stable_cascade_train_stage_c.py", line 529, in <module>
    train(args)
  File "/mnt/my_raid/github/sd-scripts/stable_cascade_train_stage_c.py", line 442, in train
    accelerator.accelerator.unwrap_model(stage_c),
AttributeError: 'Accelerator' object has no attribute 'accelerator'
steps:   2%|█                                                    | 1000/48816 [2:19:08<110:53:22,  8.35s/it, avr_loss=0.21]

Should accelerator.accelerator.unwrap_model(stage_c) be accelerator.unwrap_model(stage_c)?

kohya-ss commented 7 months ago

Thank you for testing! The batch size issue and the accelerator.accelerator issue should both be fixed now.

dill-shower commented 7 months ago

How much VRAM is required?

kohya-ss commented 7 months ago

> How much VRAM is required?

16GB of VRAM seems to be required with batch size 1, bf16 mixed precision, and the Adafactor optimizer with relative_step=False:

optimizer_type = "adafactor"
optimizer_args = [ "scale_parameter=False", "relative_step=False", "warmup_init=False" ]
lr_scheduler = "constant_with_warmup"
lr_warmup_steps = 100
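
A sketch of what `constant_with_warmup` with `lr_warmup_steps = 100` does to the learning rate, assuming it follows the same definition as `get_constant_schedule_with_warmup` in Hugging Face transformers (linear ramp from 0, then constant). With `relative_step=False`, Adafactor no longer computes its own step size, and `scale_parameter=False` disables scaling the step by each parameter's RMS, so the learning rate you set is used as given. The `base_lr` value below is a placeholder, not a value from this thread.

```python
def constant_with_warmup_multiplier(step: int, warmup_steps: int) -> float:
    """LR multiplier: linear warmup from 0 to 1, then constant at 1."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    return 1.0


base_lr = 1e-5  # hypothetical; with relative_step=False an explicit LR is required
warmup_steps = 100  # lr_warmup_steps from the config above

for step in (0, 25, 50, 100, 1000):
    lr = base_lr * constant_with_warmup_multiplier(step, warmup_steps)
    print(f"step {step:>4}: lr = {lr:.2e}")
```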

ashrafbay commented 7 months ago

And the lite model?

tetsuoo-online commented 7 months ago

Oh no, I'm doomed, I only have 12GB. Maybe with fp16... :/

> How much VRAM is required?

> 16GB of VRAM seems to be required with batch size 1, bf16 mixed precision, and the Adafactor optimizer with relative_step=False.

terrificdm commented 7 months ago

@alfredplpl How do your SC training results compare with SDXL? I tried the SC training branch, and the colors and brightness were not as good as with SDXL; everything looked dull...

alfredplpl commented 7 months ago

@terrificdm I trained SC on an anime dataset for 30,000 steps with a learning rate of 5e-4. Here are the results:

before training:

[image: ComfyUI_temp_fesqj_00014_]

after training:

[image: ComfyUI_temp_fesqj_00010_]

prompt: 1girl, white hair, blue eyes, white T-shirt saying "Hello", yellow background

negative prompt: photo, bad hands, bad anatomy, bad pupil

terrificdm commented 7 months ago

> @terrificdm I trained SC on an anime dataset for 30,000 steps with a learning rate of 5e-4. Here are the results:

Thanks, that looks nice. I'll try it later.

alfredplpl commented 7 months ago

@terrificdm Sorry, the learning rate was 5e-6.

2kpr commented 6 months ago

@kohya-ss, just wanted to make you aware of a known issue (confirmed by the Stable Cascade devs themselves; see the attached image) with 'loading' the models in bfloat16 versus 'loading' them in float32 when training. I say 'loading' because even when the models are 'loaded' in float32, a small part of the forward pass still runs in bfloat16 via torch amp, as you know.

[image: scbf16vsfp32]

As the attached image mentions, when the models are 'loaded' in float32, Stable Cascade trains much faster and better. When they are 'loaded' in bfloat16, it barely seems to train at all, or only very slowly. That forces you to increase the learning rate, and if you increase it too much, it becomes much easier to overtrain the model.
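
One plausible contributor, offered as an illustration rather than a confirmed diagnosis: bfloat16 keeps only 8 bits of significand precision, so near 1.0 its representable values are roughly 0.004 to 0.008 apart, and a weight update much smaller than that is rounded away entirely. The toy loop below (hypothetical values, plain subtraction instead of a real optimizer) applies the same small update to a float32 weight and a bfloat16 weight.

```python
import torch

# Gap between 1.0 and the next representable bfloat16 value above it.
print("bf16 eps:", torch.finfo(torch.bfloat16).eps)

update = 1e-4  # hypothetical lr * grad for a single step
w_fp32 = torch.tensor(1.0, dtype=torch.float32)
w_bf16 = torch.tensor(1.0, dtype=torch.bfloat16)

for _ in range(100):
    w_fp32 = w_fp32 - update  # small steps accumulate in float32
    w_bf16 = w_bf16 - update  # 0.9999 rounds back to 1.0 in bfloat16 every step

print("fp32 weight:", w_fp32.item())
print("bf16 weight:", w_bf16.item())
```

This is consistent with the observation above: float32 weights keep accumulating small updates, while bf16 weights only move once the per-step update is large enough to survive rounding, which pushes you toward a higher learning rate.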

kohya-ss commented 6 months ago

@2kpr Thank you for letting me know.

If mixed precision is enabled, the script currently loads the weights in bfloat16. From some brief testing, this appears to behave the same as full_bf16, which is not intended.

It would be better to load the model without type casting, and cast it to bf16 only if full_bf16 is specified.
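
A minimal sketch of that loading policy. The `load_model` helper and the small `nn.Linear` standing in for Stage C are hypothetical, not sd-scripts code: weights keep their float32 dtype by default, the forward pass runs in bf16 under `torch.autocast`, and the weights themselves are cast to bf16 only when `full_bf16` is requested.

```python
import torch
import torch.nn as nn


def load_model(full_bf16: bool) -> nn.Module:
    """Hypothetical loader: keep float32 weights unless full_bf16 is requested."""
    model = nn.Linear(8, 8)  # stand-in for Stage C; parameters are float32 by default
    if full_bf16:
        model = model.to(torch.bfloat16)  # the weights themselves become bf16
    return model


x = torch.randn(2, 8)

# Mixed precision: float32 weights, bf16 compute via autocast.
mixed = load_model(full_bf16=False)
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    out_mixed = mixed(x)

# full_bf16: weights and inputs are both bf16.
full = load_model(full_bf16=True)
out_full = full(x.to(torch.bfloat16))

print(next(mixed.parameters()).dtype, out_mixed.dtype)
print(next(full.parameters()).dtype, out_full.dtype)
```

`device_type="cpu"` keeps the sketch runnable anywhere; in an actual training run on a GPU it would be `"cuda"`.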

thojmr commented 6 months ago

@kohya-ss Will there be support for training stage_c_lite as well? It has slightly different tensor shapes and fails to load during training:

Traceback (most recent call last):
  File "/root/data/stable_cascade_train_stage_c.py", line 564, in <module>
    train(args)
  File "/root/data/stable_cascade_train_stage_c.py", line 140, in train
    stage_c = sc_utils.load_stage_c_model(args.stage_c_checkpoint_path, dtype=weight_dtype, device=loading_device)
  File "/root/data/library/stable_cascade_utils.py", line 110, in load_stage_c_model
    info = _load_state_dict_on_device(generator_c, stage_c_checkpoint, device, dtype=dtype)
  File "/root/data/library/sdxl_model_util.py", line 163, in _load_state_dict_on_device

RuntimeError: Error(s) in loading state_dict for StageC:
        Missing key(s) in state_dict: "up_blocks.0.41.attention.attn.to_out.weight", "down_blocks.1.62.attention.attn.to_out.bias", "up_blocks.0.60.channelwise.2.beta", "up_blocks.0.41.attention.attn.to_v.bias",
...hundreds more listed

stage_c_lite.safetensors: https://huggingface.co/stabilityai/stable-cascade/tree/main
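
One way to see why the lite checkpoint does not fit is to compare block depths per level. The sketch below is a hypothetical diagnostic, not part of sd-scripts: it parses keys of the form `down_blocks.<level>.<index>.…` and `up_blocks.<level>.<index>.…` (the pattern visible in the error above) and reports how many blocks each level has. The toy key lists are made up for illustration and do not describe the real stage_c_lite checkpoint; real keys could come from `load_checkpoint_keys` (which uses `safetensors.safe_open`) and from `model.state_dict().keys()`.

```python
import re
from collections import defaultdict

# Matches e.g. "down_blocks.1.62.attention.attn.to_out.bias"
BLOCK_KEY = re.compile(r"^(down_blocks|up_blocks)\.(\d+)\.(\d+)\.")


def block_depths(keys):
    """Return {(kind, level): number_of_blocks} inferred from state_dict keys."""
    max_index = defaultdict(lambda: -1)
    for key in keys:
        m = BLOCK_KEY.match(key)
        if m:
            kind, level, index = m.group(1), int(m.group(2)), int(m.group(3))
            max_index[(kind, level)] = max(max_index[(kind, level)], index)
    return {k: v + 1 for k, v in sorted(max_index.items())}


def load_checkpoint_keys(path):
    """Read tensor names from a .safetensors file without loading the tensors."""
    from safetensors import safe_open  # only needed when reading a real file

    with safe_open(path, framework="pt") as f:
        return list(f.keys())


# Made-up key lists standing in for the model definition and a smaller checkpoint.
model_keys = ["down_blocks.1.62.attention.attn.to_out.bias", "up_blocks.0.60.channelwise.2.beta"]
ckpt_keys = ["down_blocks.1.23.attention.attn.to_out.bias", "up_blocks.0.23.channelwise.2.beta"]

print("model:", block_depths(model_keys))
print("checkpoint:", block_depths(ckpt_keys))
```

If the depths differ, the lite weights need a model instantiated with the matching block counts rather than the full Stage C definition.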