I trained a LoRA with SimpleTuner using the ai-toolkit preset (I also tried all+ffs and other presets, but they don't train correctly on hard concepts).
Now I get this error when loading the LoRA:
File "/home/axel/Documents/Flux/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2183, in load
module._load_from_state_dict(
File "/home/axel/Documents/Flux/venv/lib/python3.10/site-packages/optimum/quanto/nn/qmodule.py", line 159, in _load_from_state_dict
deserialized_weight = QBytesTensor.load_from_state_dict(
File "/home/axel/Documents/Flux/venv/lib/python3.10/site-packages/optimum/quanto/tensor/qbytes.py", line 90, in load_from_state_dict
inner_tensors_dict[name] = state_dict.pop(prefix + name)
KeyError: 'time_text_embed.timestep_embedder.linear_1.weight._data'
For reference, the layer targets from the SimpleTuner code (the `all+ffs` and `ai-toolkit` presets):
```python
elif args.flux_lora_target == "all+ffs":
    target_modules = [
        "to_k",
        "to_q",
        "to_v",
        "add_k_proj",
        "add_q_proj",
        "add_v_proj",
        "to_out.0",
        "to_add_out",
        "ff.net.0.proj",
        "ff.net.2",
        "ff_context.net.0.proj",
        "ff_context.net.2",
        "proj_mlp",
        "proj_out",
    ]
elif args.flux_lora_target == "ai-toolkit":
    # from ostris' ai-toolkit, possibly required to continue finetuning one.
    target_modules = [
        "to_q",
        "to_k",
        "to_v",
        "add_q_proj",
        "add_k_proj",
        "add_v_proj",
        "to_out.0",
        "to_add_out",
        "ff.net.0.proj",
        "ff.net.2",
        "ff_context.net.0.proj",
        "ff_context.net.2",
        "norm.linear",
        "norm1.linear",
        "norm1_context.linear",
        "proj_mlp",
        "proj_out",
    ]
```
### Reproduction
here is the python code:
```python
import torch
from diffusers import FluxTransformer2DModel, FluxPipeline
from transformers import T5EncoderModel
from diffusers.utils import load_image
import cv2
import numpy as np
from PIL import Image
import logging
from optimum.quanto import freeze, qfloat8, quantize
import sys
import safetensors.torch
# Set up logging: configure the root logger at DEBUG level with a
# "timestamp - level - message" line format.
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
def apply_canny(image, low_threshold=100, high_threshold=200):
    """Return a three-channel PIL image holding the Canny edge map of ``image``.

    Args:
        image: Input image; it is converted with ``np.array`` before edge
            detection.
        low_threshold: First threshold forwarded to ``cv2.Canny``.
        high_threshold: Second threshold forwarded to ``cv2.Canny``.

    Returns:
        A ``PIL.Image`` created from the edge map replicated on three channels.
    """
    pixels = np.array(image)
    edges = cv2.Canny(pixels, low_threshold, high_threshold)
    # Replicate the single-channel edge map along a new trailing axis so the
    # result has three identical channels.
    three_channel = np.stack([edges, edges, edges], axis=2)
    return Image.fromarray(three_channel)
# Load the main model
# Base repository used below for the text encoder and the remaining pipeline
# components.
bfl_repo = "black-forest-labs/FLUX.1-dev"
dtype = torch.bfloat16
# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logging.info("Loading transformer...")
# The transformer is read from a single-file fp8 checkpoint and requested as
# bfloat16 via torch_dtype.
transformer = FluxTransformer2DModel.from_single_file("https://huggingface.co/Kijai/flux-fp8/blob/main/flux1-dev-fp8.safetensors", torch_dtype=dtype)
# Quantize the transformer weights to qfloat8 with optimum-quanto and freeze
# the module before moving it to the target device.
# NOTE(review): the traceback in this report is raised from
# optimum.quanto's _load_from_state_dict while the LoRA state dict is loaded
# into this already-quantized transformer. Presumably loading the LoRA before
# quantize()/freeze() avoids the KeyError — TODO confirm.
quantize(transformer, weights=qfloat8)
freeze(transformer)
transformer.to(device)
logging.info("Loading text encoder...")
# The T5 text encoder gets the same quantize / freeze / move treatment.
text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype)
quantize(text_encoder_2, weights=qfloat8)
freeze(text_encoder_2)
text_encoder_2.to(device)
logging.info("Creating pipeline...")
# Build the pipeline without a transformer or second text encoder, then attach
# the quantized modules created above.
pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=None, text_encoder_2=None, torch_dtype=dtype)
pipe.transformer = transformer
pipe.text_encoder_2 = text_encoder_2
pipe.to(device)
logging.info("Pipeline created")
# Load and apply multiple LoRAs
# NOTE(review): `loras` (and its per-LoRA "alpha" values) is not referenced
# anywhere else in the visible script; the calls below hard-code the paths and
# a scale of 0.5 instead.
loras = [
{"path": "v1.safetensors", "alpha": 0.5},
{"path": "v2.safetensors", "alpha": 0.3},
]
# This first call is the one that fails in the traceback reported below.
# NOTE(review): `cross_attention_kwargs` is passed to load_lora_weights here;
# verify that the loader actually honours it — TODO confirm against the
# diffusers documentation.
pipe.load_lora_weights("./v1",weight_name="v1.safetensors",
adapter_name="default",
cross_attention_kwargs={"scale": 0.5})
# NOTE(review): the second LoRA reuses adapter_name="default"; presumably each
# adapter needs a distinct name — TODO confirm.
pipe.load_lora_weights("./v2",weight_name="v2.safetensors",
adapter_name="default",
cross_attention_kwargs={"scale": 0.5})
# NOTE(review): this message is logged after both load calls have returned.
logging.info("Loading LoRAs...")
# Prepare the input image and apply Canny edge detection
# Any exception raised while loading or processing the image is logged and the
# script exits with status 1.
# NOTE(review): `control_image` is computed here but never passed to the
# pipeline call further down in the visible script.
try:
input_image = load_image("img.png")
control_image = apply_canny(input_image)
logging.info("Input image processed successfully")
except Exception as e:
logging.error(f"Error processing input image: {str(e)}")
sys.exit(1)
# Generate the image
# Runs the pipeline once under torch.no_grad() with a generator seeded to 0 on
# `device`, and keeps the first returned image.
prompt = "A fire inventory"
logging.info("Starting image generation...")
try:
with torch.no_grad():
image = pipe(
prompt,
num_inference_steps=30,
guidance_scale=7.5,
generator=torch.Generator(device=device).manual_seed(0)
).images[0]
logging.info("Image generation completed")
except Exception as e:
# The same exception is logged twice: once as a plain message and once more
# with the traceback attached (exc_info=True).
logging.error(f"Error during image generation: {str(e)}")
logging.error(f"Error details: {str(e)}", exc_info=True)
sys.exit(1)
```
### Logs
```shell
Traceback (most recent call last):
File "/home/axel/Documents/Flux/infc.py", line 82, in <module>
pipe.load_lora_weights("./v1",weight_name="v1.safetensors",
File "/home/axel/Documents/Flux/venv/lib/python3.10/site-packages/diffusers/loaders/lora_pipeline.py", line 1647, in load_lora_weights
self.load_lora_into_transformer(
File "/home/axel/Documents/Flux/venv/lib/python3.10/site-packages/diffusers/loaders/lora_pipeline.py", line 1736, in load_lora_into_transformer
incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
File "/home/axel/Documents/Flux/venv/lib/python3.10/site-packages/peft/utils/save_and_load.py", line 395, in set_peft_model_state_dict
load_result = model.load_state_dict(peft_model_state_dict, strict=False)
File "/home/axel/Documents/Flux/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2201, in load_state_dict
load(self, state_dict)
File "/home/axel/Documents/Flux/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2189, in load
load(child, child_state_dict, child_prefix) # noqa: F821
File "/home/axel/Documents/Flux/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2189, in load
load(child, child_state_dict, child_prefix) # noqa: F821
File "/home/axel/Documents/Flux/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2189, in load
load(child, child_state_dict, child_prefix) # noqa: F821
File "/home/axel/Documents/Flux/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2183, in load
module._load_from_state_dict(
File "/home/axel/Documents/Flux/venv/lib/python3.10/site-packages/optimum/quanto/nn/qmodule.py", line 159, in _load_from_state_dict
deserialized_weight = QBytesTensor.load_from_state_dict(
File "/home/axel/Documents/Flux/venv/lib/python3.10/site-packages/optimum/quanto/tensor/qbytes.py", line 90, in load_from_state_dict
inner_tensors_dict[name] = state_dict.pop(prefix + name)
KeyError: 'time_text_embed.timestep_embedder.linear_1.weight._data'
```
### System Info
Latest diffusers (cloned from source 2 hours ago),
Debian Linux.
### Describe the bug
I trained a LoRA with SimpleTuner using the ai-toolkit preset (I also tried all+ffs and other presets, but they don't train correctly on hard concepts). Now I get this error when loading the LoRA:
File "/home/axel/Documents/Flux/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2183, in load module._load_from_state_dict( File "/home/axel/Documents/Flux/venv/lib/python3.10/site-packages/optimum/quanto/nn/qmodule.py", line 159, in _load_from_state_dict deserialized_weight = QBytesTensor.load_from_state_dict( File "/home/axel/Documents/Flux/venv/lib/python3.10/site-packages/optimum/quanto/tensor/qbytes.py", line 90, in load_from_state_dict inner_tensors_dict[name] = state_dict.pop(prefix + name) KeyError: 'time_text_embed.timestep_embedder.linear_1.weight._data'
From simpletuner code ai-toolkit layer reference:
### System Info
Latest diffusers (cloned from source 2 hours ago), Debian Linux.
### Who can help?
@saya