ostris / ai-toolkit

Various AI scripts. Mostly Stable Diffusion stuff.
MIT License

How to merge trained FLUX LoRA into FLUX Dev Model #213

Open wiggin66 opened 22 hours ago

wiggin66 commented 22 hours ago

I would like to know how to merge an ai-toolkit-trained LoRA back into the original FLUX model.

Are there existing scripts, or implementations in other repositories, for this merging purpose?

DopeDaditude commented 8 hours ago

It is pretty much this:

```python
if lora_cache_dir:
    lora_path = os.path.join(lora_cache_dir, "lora.safetensors")
    if os.path.exists(lora_path):
        logging.info(f"Loading custom LoRA for user: {username}")
        pipe.load_lora_weights(lora_path, adapter_name=username)
        pipe.fuse_lora(lora_scale=lora_scale)
        logging.info(f"Custom LoRA loaded and fused successfully with scale {lora_scale}")
    else:
        logging.warning(f"No custom LoRA found for user: {username}. Using base model.")
```

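For reference, `fuse_lora` folds the low-rank update into each base weight matrix in place: per linear layer the arithmetic is just W' = W + scale · (α/r) · B·A. A minimal NumPy sketch of that idea (the `merge_lora` helper, shapes, and hyperparameters here are illustrative, not the diffusers API):

```python
import numpy as np

def merge_lora(W, A, B, alpha=16.0, rank=4, scale=1.0):
    """Fold a LoRA update into a base weight matrix.

    W: (out, in) base weight; A: (rank, in) down-projection;
    B: (out, rank) up-projection. The fused weight is
    W + scale * (alpha / rank) * B @ A, which is the per-layer
    arithmetic behind fusing a LoRA into the base model.
    """
    return W + scale * (alpha / rank) * (B @ A)

rng = np.random.default_rng(0)
W = rng.normal(size=(8, 8)).astype(np.float32)
A = rng.normal(size=(4, 8)).astype(np.float32)
B = rng.normal(size=(8, 4)).astype(np.float32)

W_fused = merge_lora(W, A, B)

# With scale=0 the base weight is unchanged.
assert np.allclose(merge_lora(W, A, B, scale=0.0), W)
```

After fusing, the adapter matrices are no longer needed at inference time, which is why the merged pipeline can be saved and deployed like a plain base model.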
DopeDaditude commented 8 hours ago

It's best to just ask the AI how to do it.

wiggin66 commented 5 hours ago

> It is pretty much this:
>
> ```python
> if lora_cache_dir:
>     lora_path = os.path.join(lora_cache_dir, "lora.safetensors")
>     if os.path.exists(lora_path):
>         logging.info(f"Loading custom LoRA for user: {username}")
>         pipe.load_lora_weights(lora_path, adapter_name=username)
>         pipe.fuse_lora(lora_scale=lora_scale)
>         logging.info(f"Custom LoRA loaded and fused successfully with scale {lora_scale}")
>     else:
>         logging.warning(f"No custom LoRA found for user: {username}. Using base model.")
> ```

Thank you for your advice.

But I would like to ask how to load a LoRA, merge it into the model, and save the merged model to a local file. This is for the purpose of model quantization and deployment.

I ask because I failed to quantize the merged FLUX LoRA model with optimum, using code something like:

```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel
from optimum.quanto import freeze, qfloat8, quantize
from transformers import T5EncoderModel

dtype = torch.bfloat16

# Merge the LoRA into the base pipeline and save the merged weights.
bfl_repo = './new_model'
pipe = FluxPipeline.from_pretrained('')  # base model path left blank here
adapter_id = ''  # LoRA path left blank here
pipe.load_lora_weights(adapter_id)
pipe.fuse_lora(lora_scale=1.0)
pipe.save_pretrained(bfl_repo)

# Quantize the transformer and the T5 text encoder of the merged model.
transformer = FluxTransformer2DModel.from_pretrained(
    bfl_repo, subfolder="transformer", torch_dtype=dtype
)
quantize(transformer, weights=qfloat8)
freeze(transformer)

text_encoder_2 = T5EncoderModel.from_pretrained(
    bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype
)
quantize(text_encoder_2, weights=qfloat8)
freeze(text_encoder_2)

# Rebuild the pipeline around the quantized components.
pipe = FluxPipeline.from_pretrained(
    bfl_repo, transformer=None, text_encoder_2=None, torch_dtype=dtype
)
pipe.transformer = transformer
pipe.text_encoder_2 = text_encoder_2
```
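Independent of the failure above, it may help to see what weight-only quantization actually stores. Here is a toy NumPy sketch of symmetric per-tensor int8 quantization: one low-precision tensor plus one float scale. This is the general scheme behind `qint8` (`qfloat8` casts to a float8 format instead, but the scale-and-cast idea is similar). The helper names are illustrative, not the quanto API:

```python
import numpy as np

def quantize_int8(w):
    """Symmetric per-tensor int8 weight quantization:
    store int8 values plus a single float scale."""
    scale = np.abs(w).max() / 127.0
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    return q, scale

def dequantize(q, scale):
    """Recover an approximate float32 weight from int8 + scale."""
    return q.astype(np.float32) * scale

rng = np.random.default_rng(0)
w = rng.normal(size=(64, 64)).astype(np.float32)
q, scale = quantize_int8(w)

# Storage drops from 4 bytes to ~1 byte per weight; the reconstruction
# error is bounded by half a quantization step.
err = np.abs(dequantize(q, scale) - w).max()
assert err <= scale / 2 + 1e-6
```

This also shows why quantization should happen after the LoRA is fused: fusing changes the weight values, so quantizing first would bake in scales computed for the unmerged weights.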