huggingface / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
https://huggingface.co/docs/diffusers
Apache License 2.0

Incomplete Lora support for Flux Dev from SimpleTuner #9270

Open axel578 opened 2 weeks ago

axel578 commented 2 weeks ago

### Describe the bug

I trained a LoRA with SimpleTuner using the `ai-toolkit` preset (I also tried `all+ffs` and others, but they don't train hard concepts correctly). Now I get this error when loading the LoRA:

```shell
  File "/home/axel/Documents/Flux/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2183, in load
    module._load_from_state_dict(
  File "/home/axel/Documents/Flux/venv/lib/python3.10/site-packages/optimum/quanto/nn/qmodule.py", line 159, in _load_from_state_dict
    deserialized_weight = QBytesTensor.load_from_state_dict(
  File "/home/axel/Documents/Flux/venv/lib/python3.10/site-packages/optimum/quanto/tensor/qbytes.py", line 90, in load_from_state_dict
    inner_tensors_dict[name] = state_dict.pop(prefix + name)
KeyError: 'time_text_embed.timestep_embedder.linear_1.weight._data'
```

Here are the target-module lists for the `all+ffs` and `ai-toolkit` presets from the SimpleTuner source:

```python
elif args.flux_lora_target == "all+ffs":
    target_modules = [
        "to_k",
        "to_q",
        "to_v",
        "add_k_proj",
        "add_q_proj",
        "add_v_proj",
        "to_out.0",
        "to_add_out",
        "ff.net.0.proj",
        "ff.net.2",
        "ff_context.net.0.proj",
        "ff_context.net.2",
        "proj_mlp",
        "proj_out",
    ]
elif args.flux_lora_target == "ai-toolkit":
    # from ostris' ai-toolkit, possibly required to continue finetuning one.
    target_modules = [
        "to_q",
        "to_k",
        "to_v",
        "add_q_proj",
        "add_k_proj",
        "add_v_proj",
        "to_out.0",
        "to_add_out",
        "ff.net.0.proj",
        "ff.net.2",
        "ff_context.net.0.proj",
        "ff_context.net.2",
        "norm.linear",
        "norm1.linear",
        "norm1_context.linear",
        "proj_mlp",
        "proj_out",
    ]
```
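Comparing the two lists shows that `ai-toolkit` targets every `all+ffs` module plus three `norm*.linear` layers. A quick check, with the lists copied verbatim from the SimpleTuner snippet (the variable names are my own, for illustration only):

```python
# Target-module lists copied from the SimpleTuner snippet; variable names are illustrative.
ALL_FFS = [
    "to_k", "to_q", "to_v",
    "add_k_proj", "add_q_proj", "add_v_proj",
    "to_out.0", "to_add_out",
    "ff.net.0.proj", "ff.net.2",
    "ff_context.net.0.proj", "ff_context.net.2",
    "proj_mlp", "proj_out",
]
AI_TOOLKIT = [
    "to_q", "to_k", "to_v",
    "add_q_proj", "add_k_proj", "add_v_proj",
    "to_out.0", "to_add_out",
    "ff.net.0.proj", "ff.net.2",
    "ff_context.net.0.proj", "ff_context.net.2",
    "norm.linear", "norm1.linear", "norm1_context.linear",
    "proj_mlp", "proj_out",
]

# Modules that only the ai-toolkit preset adds.
extra = sorted(set(AI_TOOLKIT) - set(ALL_FFS))
print(extra)  # ['norm.linear', 'norm1.linear', 'norm1_context.linear']

# Every all+ffs module is also targeted by ai-toolkit.
print(set(ALL_FFS) <= set(AI_TOOLKIT))  # True
```

Neither list names `time_text_embed`, the module in the `KeyError`, which suggests the failure comes from the loading path rather than from the choice of target layers.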

### Reproduction

Here is the Python code:
```python
import torch
from diffusers import FluxTransformer2DModel, FluxPipeline
from transformers import T5EncoderModel
from diffusers.utils import load_image
import cv2
import numpy as np
from PIL import Image
import logging
from optimum.quanto import freeze, qfloat8, quantize
import sys
import safetensors.torch

# Set up logging
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')

# Helper function to apply Canny edge detection
def apply_canny(image, low_threshold=100, high_threshold=200):
    image = np.array(image)
    image = cv2.Canny(image, low_threshold, high_threshold)
    image = image[:, :, None]
    image = np.concatenate([image, image, image], axis=2)
    return Image.fromarray(image)

# Load the main model
bfl_repo = "black-forest-labs/FLUX.1-dev"
dtype = torch.bfloat16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

logging.info("Loading transformer...")
transformer = FluxTransformer2DModel.from_single_file("https://huggingface.co/Kijai/flux-fp8/blob/main/flux1-dev-fp8.safetensors", torch_dtype=dtype)
quantize(transformer, weights=qfloat8)
freeze(transformer)
transformer.to(device)

logging.info("Loading text encoder...")
text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype)
quantize(text_encoder_2, weights=qfloat8)
freeze(text_encoder_2)
text_encoder_2.to(device)

logging.info("Creating pipeline...")
pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=None, text_encoder_2=None, torch_dtype=dtype)
pipe.transformer = transformer
pipe.text_encoder_2 = text_encoder_2
pipe.to(device)

logging.info("Pipeline created")

# Load the LoRAs, each under its own adapter name
# (reusing adapter_name="default" for the second one would clash with the first)
logging.info("Loading LoRAs...")
pipe.load_lora_weights("./v1", weight_name="v1.safetensors", adapter_name="v1")
pipe.load_lora_weights("./v2", weight_name="v2.safetensors", adapter_name="v2")
# Per-adapter scales are set with set_adapters(), not cross_attention_kwargs
pipe.set_adapters(["v1", "v2"], adapter_weights=[0.5, 0.3])

# Prepare the input image and apply Canny edge detection
try:
    input_image = load_image("img.png")
    control_image = apply_canny(input_image)
    logging.info("Input image processed successfully")
except Exception as e:
    logging.error(f"Error processing input image: {str(e)}")
    sys.exit(1)

# Generate the image

prompt = "A fire inventory"
logging.info("Starting image generation...")
try:
    with torch.no_grad():
        image = pipe(
            prompt,
            num_inference_steps=30,
            guidance_scale=7.5,
            generator=torch.Generator(device=device).manual_seed(0)
        ).images[0]
    logging.info("Image generation completed")
except Exception as e:
    # exc_info=True logs the full traceback together with the message
    logging.error(f"Error during image generation: {e}", exc_info=True)
    sys.exit(1)
```

### Logs

```shell
Traceback (most recent call last):
  File "/home/axel/Documents/Flux/infc.py", line 82, in <module>
    pipe.load_lora_weights("./v1",weight_name="v1.safetensors",
  File "/home/axel/Documents/Flux/venv/lib/python3.10/site-packages/diffusers/loaders/lora_pipeline.py", line 1647, in load_lora_weights
    self.load_lora_into_transformer(
  File "/home/axel/Documents/Flux/venv/lib/python3.10/site-packages/diffusers/loaders/lora_pipeline.py", line 1736, in load_lora_into_transformer
    incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
  File "/home/axel/Documents/Flux/venv/lib/python3.10/site-packages/peft/utils/save_and_load.py", line 395, in set_peft_model_state_dict
    load_result = model.load_state_dict(peft_model_state_dict, strict=False)
  File "/home/axel/Documents/Flux/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2201, in load_state_dict
    load(self, state_dict)
  File "/home/axel/Documents/Flux/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2189, in load
    load(child, child_state_dict, child_prefix)  # noqa: F821
  File "/home/axel/Documents/Flux/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2189, in load
    load(child, child_state_dict, child_prefix)  # noqa: F821
  File "/home/axel/Documents/Flux/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2189, in load
    load(child, child_state_dict, child_prefix)  # noqa: F821
  File "/home/axel/Documents/Flux/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2183, in load
    module._load_from_state_dict(
  File "/home/axel/Documents/Flux/venv/lib/python3.10/site-packages/optimum/quanto/nn/qmodule.py", line 159, in _load_from_state_dict
    deserialized_weight = QBytesTensor.load_from_state_dict(
  File "/home/axel/Documents/Flux/venv/lib/python3.10/site-packages/optimum/quanto/tensor/qbytes.py", line 90, in load_from_state_dict
    inner_tensors_dict[name] = state_dict.pop(prefix + name)
KeyError: 'time_text_embed.timestep_embedder.linear_1.weight._data'
```

### System Info

Latest diffusers, cloned two hours before filing; Linux (Debian).
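"Latest diffusers" is hard to pin down after the fact, and with Quanto in the loop the `peft` and `optimum-quanto` versions matter as well. A minimal standard-library sketch for listing them (the package list is my assumption); `diffusers-cli env` reports similar information.

```python
from importlib.metadata import PackageNotFoundError, version

# Distributions that seem relevant to this issue (assumed list, adjust as needed).
PACKAGES = ["diffusers", "peft", "optimum-quanto", "torch", "transformers"]


def collect_versions(packages=PACKAGES):
    """Return {distribution name: installed version or 'not installed'}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = version(name)
        except PackageNotFoundError:
            versions[name] = "not installed"
    return versions


if __name__ == "__main__":
    for name, ver in collect_versions().items():
        print(f"{name}: {ver}")
```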

### Who can help?

@saya

bghira commented 2 weeks ago

It's because of Quanto. If you don't use it, this will work. There is work being done on that.
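The traceback shows the mechanism: PEFT loads the adapter with `model.load_state_dict(peft_model_state_dict, strict=False)`, and PyTorch still calls `_load_from_state_dict` on every submodule, including ones with no keys in the LoRA state dict. Quanto's `QBytesTensor.load_from_state_dict` then calls `state_dict.pop(prefix + name)`, which raises `KeyError` for a missing `_data` key instead of reporting it as missing. The toy sketch below reproduces that pattern in plain PyTorch; `PopLinear` is a hypothetical stand-in, not Quanto's actual code.

```python
import torch
from torch import nn


class PopLinear(nn.Linear):
    """Hypothetical stand-in for a quantized layer (NOT optimum-quanto's code).

    Mirrors the failing line in the traceback: it pops its serialized
    weight key unconditionally instead of checking that it exists.
    """

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        # dict.pop without a default raises KeyError when the key is absent,
        # whatever strict flag the caller passed to load_state_dict().
        data = state_dict.pop(prefix + "weight._data")
        with torch.no_grad():
            self.weight.copy_(data)


class Toy(nn.Module):
    def __init__(self, embed_cls):
        super().__init__()
        self.embed = embed_cls(4, 4)  # plays the role of timestep_embedder.linear_1
        self.proj = nn.Linear(4, 4)   # plays the role of a LoRA-targeted layer


# A partial state dict that only touches `proj`, like a LoRA checkpoint.
partial = {"proj.weight": torch.zeros(4, 4), "proj.bias": torch.zeros(4)}

# Plain nn.Linear: with strict=False the untouched keys are just reported.
missing = Toy(nn.Linear).load_state_dict(partial, strict=False).missing_keys
print(missing)  # ['embed.weight', 'embed.bias']

# Pop-style layer: the same call fails before any missing-key bookkeeping.
try:
    Toy(PopLinear).load_state_dict(partial, strict=False)
    error = None
except KeyError as err:
    error = err.args[0]
print(error)  # embed.weight._data
```

This is why dropping Quanto, as suggested above, avoids the error: standard layers tolerate a partial state dict under `strict=False`.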