kohya-ss / sd-scripts


Flux Dedistilled / fluxdev2pro support ? #1702

Open Tablaski opened 2 weeks ago

Tablaski commented 2 weeks ago

I'm trying to use the amazing new Dedistilled models with the trainer.

If you haven't tried them, they are groundbreaking: https://civitai.com/models/843551 For me it's the biggest thing in the Flux community since we became able to train LoRAs.

They would allow training with CFG > 1 (guidance > 1), which would probably give much better caption adherence during training and possibly better prompt adherence later on.

(It isn't certain that a LoRA trained with CFG > 1 can be used properly with distilled models. But if it can, it would probably be amazing, which is why we need to try ASAP.)
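For anyone unsure what "CFG > 1" changes, here is a minimal sketch of the standard classifier-free guidance formula on plain Python lists. It is not code from sd-scripts; the function name `cfg_combine` and the toy numbers are made up for illustration. With a scale of 1 the result is just the conditional prediction, and larger scales push further away from the unconditional one.

```python
def cfg_combine(uncond, cond, scale):
    """Standard classifier-free guidance: uncond + scale * (cond - uncond).

    `uncond` and `cond` stand in for the model's predictions without and
    with the prompt (toy lists here, real tensors in practice).
    """
    return [u + scale * (c - u) for u, c in zip(uncond, cond)]


# Toy predictions, purely illustrative.
uncond = [0.0, 1.0]
cond = [1.0, 3.0]

print(cfg_combine(uncond, cond, 1.0))  # [1.0, 3.0], identical to cond
print(cfg_combine(uncond, cond, 4.0))  # [4.0, 9.0], pushed past cond
```

Distilled Flux dev instead takes the guidance value as an input to the model, which is why the two kinds of checkpoint are expected to behave differently here.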

So far I've just tried replacing flux1-dev.sft with another file in the following parameter:

--pretrained_model_name_or_path "C:\fluxgym\models\unet\flux1-dev.sft"

But I got this error, which I haven't really investigated yet. I get the same error using fluxdev2pro, which is a fine-tuned dedistilled model enhancing training:

  File "C:\fluxgym\sd-scripts\flux_train_network.py", line 519, in <module>
    trainer.train(args)
  File "C:\fluxgym\sd-scripts\train_network.py", line 354, in train
    model_version, text_encoder, vae, unet = self.load_target_model(args, weight_dtype, accelerator)
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\fluxgym\sd-scripts\flux_train_network.py", line 82, in load_target_model
    model = self.prepare_split_model(model, weight_dtype, accelerator)
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\fluxgym\sd-scripts\flux_train_network.py", line 127, in prepare_split_model
    flux_upper.to(accelerator.device, dtype=target_dtype)
  File "C:\fluxgym\env\Lib\site-packages\torch\nn\modules\module.py", line 1340, in to
    return self._apply(convert)
    ^^^^^^^^^^^^^^^^^^^^
  File "C:\fluxgym\env\Lib\site-packages\torch\nn\modules\module.py", line 900, in _apply
    module._apply(fn)
  File "C:\fluxgym\env\Lib\site-packages\torch\nn\modules\module.py", line 900, in _apply
    module._apply(fn)
  File "C:\fluxgym\env\Lib\site-packages\torch\nn\modules\module.py", line 927, in _apply
    param_applied = fn(param)
    ^^^^^^^^^
  File "C:\fluxgym\env\Lib\site-packages\torch\nn\modules\module.py", line 1333, in convert
    raise NotImplementedError(
NotImplementedError: Cannot copy out of meta tensor; no data! Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() when moving module from meta to a different device.
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "C:\fluxgym\env\Scripts\accelerate.exe\__main__.py", line 7, in <module>
  File "C:\fluxgym\env\Lib\site-packages\accelerate\commands\accelerate_cli.py", line 48, in main
    args.func(args)
  File "C:\fluxgym\env\Lib\site-packages\accelerate\commands\launch.py", line 1174, in launch_command
    simple_launcher(args)
  File "C:\fluxgym\env\Lib\site-packages\accelerate\commands\launch.py", line 769, in simple_launcher
    raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd)
subprocess.CalledProcessError:

Ice-YY commented 2 weeks ago

Maybe you should try to pull the latest version? It works okay on my computer with the dedistilled model.

Tablaski commented 2 weeks ago

Really? That's great. May I ask what differences you have noticed when training with it?

What CFG did you set? Did it work on distilled models with CFG = 1 afterwards?

How were your results using the LoRA, both on distilled and dedistilled?

Ice-YY commented 2 weeks ago

You could check this discussion: https://huggingface.co/nyanko7/flux-dev-de-distill/discussions/3 Dsienra is conducting some tests on this model and posting feedback.

Tablaski commented 2 weeks ago

@Ice-YY thank you, I hadn't seen that discussion on nyanko7's Hugging Face page.

This is extremely interesting. Still, have you tried it yourself?

Ice-YY commented 2 weeks ago

I've been training a LoRA on a dataset with several distinct art styles, each with its own unique trigger word. When training on the base model, the output doesn't change much regardless of which trigger word I use. However, when I train the LoRA on the de-distilled base model, it does show some ability to differentiate between the trigger words, although the results are still not ideal compared to training a LoRA on SDXL. For now I'm experimenting with training at a larger guidance scale (like 6.0) to see if I can improve the results.

Tablaski commented 2 weeks ago

OK, so for the moment you've tried guidance 1 with dedistilled? Did you then generate images with the LoRA back on distilled?

I am very curious to know whether a LoRA trained on dedistilled with guidance >= 4 (I mean not just 1.5 or 2) would work with distilled Flux, meaning it is backward compatible.

Ice-YY commented 2 weeks ago

> Ok so you've tried with guidance 1 using dedistilled for the moment ? Then did you generate images with it back with distilled ?

Yes. And now I can confirm that a LoRA trained on the de-distilled model with guidance 6.0 works fine with distilled Flux.

Tophness commented 2 weeks ago

I've had incredible early success with this model in combination with AdEMAMix8bit. In 326 steps it achieved the bulk of what took 40,000 steps on distilled with adamw8bit. I used a CFG of 1 and an LR of 1e-4 for training, then the default CFG in Forge for inference.

Tablaski commented 2 weeks ago

This is very good news then. I gather that training with de-distilled + guidance >= 1 improves prompt adherence once back on distilled models, and that those models can still use distilled guidance as usual. For LoRAs at least (I'm currently asking someone who has just fine-tuned a checkpoint, and will update here).

Tophness commented 1 week ago

I changed line 152 in library/flux_train_utils.py: scale = prompt_dict.get("scale", 1.0) but the sample images are still completely distorted. I'm guessing the problem is the model itself. Maybe you'd need to switch to a non-distilled model for inference.
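One thing worth ruling out with an edit like this: `dict.get` only falls back to its default when the key is missing, so changing the default has no effect on any sample prompt that already carries its own "scale" value. A tiny sketch, where the helper name `sample_scale` and the example dicts are hypothetical and not taken from sd-scripts:

```python
def sample_scale(prompt_dict, default=1.0):
    # Same lookup pattern as the quoted line: the default is used only
    # when the prompt dict has no "scale" entry at all.
    return prompt_dict.get("scale", default)


print(sample_scale({"prompt": "a cat"}))                # 1.0 (default used)
print(sample_scale({"prompt": "a cat", "scale": 3.5}))  # 3.5 (default ignored)
```

Whether the sample prompts in this setup set a scale explicitly isn't stated in the thread, so this is just a check to make, not a diagnosis.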

Sarania commented 6 days ago

Just thought I'd add my experience: on the latest pull of the SD3 branch, I can train on DevDedistilled. It's detected as a Schnell model because of the missing guidance blocks (line 77 in library/flux_utils.py): is_schnell = not ("guidance_in.in_layer.bias" in keys or "time_text_embed.guidance_embedder.linear_1.bias" in keys) and it is initialized as such. Training works, but that's possibly why samples are broken? Forcing it to be detected as a dev model breaks it, which makes sense. I imagine it might need some special consideration for full support. Unlike Dev2Pro, DevDedistilled fully removes the distilled guidance blocks.
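To see how a given checkpoint would be classified without loading any weights, the check quoted above can be reproduced against the header of a .safetensors file (an 8-byte little-endian length followed by a JSON header). This is a rough standalone sketch, not the sd-scripts implementation: the function names are made up, the fake checkpoints are written only so the example runs on its own, and "img_in.weight" is just a placeholder tensor name.

```python
import json
import os
import struct
import tempfile

# Key names copied from the check quoted above (library/flux_utils.py).
GUIDANCE_KEYS = (
    "guidance_in.in_layer.bias",
    "time_text_embed.guidance_embedder.linear_1.bias",
)


def read_safetensors_keys(path):
    """Return the tensor names listed in a .safetensors header."""
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len).decode("utf-8"))
    # "__metadata__" is an optional header entry, not a tensor.
    return {name for name in header if name != "__metadata__"}


def detected_as_schnell(keys):
    """Mirror of the quoted test: no guidance key means 'Schnell'."""
    return not any(k in keys for k in GUIDANCE_KEYS)


def write_fake_checkpoint(path, names):
    """Write a tiny valid .safetensors file (demo helper, hypothetical)."""
    header = {
        n: {"dtype": "F32", "shape": [1], "data_offsets": [4 * i, 4 * i + 4]}
        for i, n in enumerate(names)
    }
    blob = json.dumps(header).encode("utf-8")
    with open(path, "wb") as f:
        f.write(struct.pack("<Q", len(blob)))
        f.write(blob)
        f.write(b"\x00" * 4 * len(names))  # one dummy float per tensor


with tempfile.TemporaryDirectory() as tmp:
    dev_like = os.path.join(tmp, "dev_like.safetensors")
    dedistilled_like = os.path.join(tmp, "dedistilled_like.safetensors")
    write_fake_checkpoint(dev_like, ["img_in.weight", "guidance_in.in_layer.bias"])
    write_fake_checkpoint(dedistilled_like, ["img_in.weight"])
    dev_result = detected_as_schnell(read_safetensors_keys(dev_like))
    dedistilled_result = detected_as_schnell(read_safetensors_keys(dedistilled_like))

print(dev_result, dedistilled_result)  # False True
```

Pointing `read_safetensors_keys` at a real checkpoint path should show quickly whether the guidance keys are present, which is all the quoted check looks at.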