KohakuBlueleaf / LyCORIS

Lora beYond Conventional methods, Other Rank adaptation Implementations for Stable diffusion.
Apache License 2.0
2.2k stars, 152 forks

GLoRA inference fails with Flux due to weights being in bfloat16 precision #215

Open mhirki opened 1 month ago

mhirki commented 1 month ago

My simple inference script fails when calling `wrapper.merge_to()` with Flux Dev as the base model:

```
2024-09-21 19:27:53|[LyCORIS]-INFO: Loading Modules from state dict...
2024-09-21 19:27:54|[LyCORIS]-INFO: 504 Modules Loaded
Traceback (most recent call last):
  File "/nvme/home/mikaelh/Stable_Diffusion/bghira/output/models.bak_flux_sanna_marin_v0.4_fp8_multires_adan3_glora/inference2.py", line 12, in <module>
    wrapper.merge_to(0.5)
  File "/nvme/home/mikaelh/Stable_Diffusion/bghira/SimpleTuner.latest/.venv/lib/python3.11/site-packages/lycoris/wrapper.py", line 567, in merge_to
    lora.merge_to(weight)
  File "/nvme/home/mikaelh/Stable_Diffusion/bghira/SimpleTuner.latest/.venv/lib/python3.11/site-packages/lycoris/modules/base.py", line 269, in merge_to
    weight, bias = self.get_merged_weight(
                   ^^^^^^^^^^^^^^^^^^^^^^^
  File "/nvme/home/mikaelh/Stable_Diffusion/bghira/SimpleTuner.latest/.venv/lib/python3.11/site-packages/lycoris/modules/glora.py", line 208, in get_merged_weight
    diff_w, _ = self.get_diff_weight(multiplier, shape, device)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nvme/home/mikaelh/Stable_Diffusion/bghira/SimpleTuner.latest/.venv/lib/python3.11/site-packages/lycoris/modules/glora.py", line 202, in get_diff_weight
    weight = self.make_weight(device) * multiplier
             ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nvme/home/mikaelh/Stable_Diffusion/bghira/SimpleTuner.latest/.venv/lib/python3.11/site-packages/lycoris/modules/glora.py", line 198, in make_weight
    w_wa2 = (orig @ wa1) @ wa2
             ~~~~~^~~~~
RuntimeError: expected m1 and m2 to have the same dtype, but got: c10::BFloat16 != float
```

`orig` is in bfloat16 precision, while `wa1` and `wa2` are in float32. I tried both upcasting `orig` and downcasting `wa1` and `wa2`, and there was very little difference in the end result. Upcasting to float32 did run much faster on CPU. I'm not sure which approach you prefer, so I'm posting this as an issue.
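The following minimal sketch reproduces the mismatch and shows the two fixes described above. It uses random stand-in tensors. The names `orig`, `wa1` and `wa2` mirror `make_weight()` in the traceback, but the shapes and values are arbitrary assumptions, not GLoRA's real weights.

```python
import torch

# Stand-ins for the tensors in GLoRA's make_weight(): `orig` plays the base
# weight (bfloat16, as in Flux) and `wa1`/`wa2` play the adapter factors
# (float32). Shapes are arbitrary and chosen only for illustration.
torch.manual_seed(0)
orig = torch.randn(64, 32, dtype=torch.bfloat16)
wa1 = torch.randn(32, 4, dtype=torch.float32)
wa2 = torch.randn(4, 32, dtype=torch.float32)


def mismatch_raises() -> bool:
    """Reproduce the error: matmul does not promote mixed bfloat16/float32 operands."""
    try:
        _ = (orig @ wa1) @ wa2
    except RuntimeError:
        return True
    return False


# Option 1: upcast the base weight to the adapter's dtype (float32).
w_up = (orig.to(wa1.dtype) @ wa1) @ wa2

# Option 2: downcast the adapter factors to the base weight's dtype (bfloat16).
w_down = (orig @ wa1.to(orig.dtype)) @ wa2.to(orig.dtype)

# Largest absolute difference between the two results, compared in float32.
max_abs_diff = (w_up - w_down.float()).abs().max().item()
```

On random data, the two options differ only by bfloat16 rounding. This does not measure image quality, which the reporter found to differ very little either way.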

Here's my inference script for reference:

```python
import torch
from diffusers import FluxPipeline
from lycoris import create_lycoris_from_weights

torch.set_num_threads(16)

# The base model weights are loaded in bfloat16.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)

adapter_id = 'pytorch_lora_weights_fixed.safetensors' # you will have to download this manually
lora_scale = 1
wrapper, _ = create_lycoris_from_weights(lora_scale, adapter_id, pipe.transformer)
wrapper.merge_to(0.5)  # fails here with the dtype mismatch in the traceback above

pipe.enable_sequential_cpu_offload()

prompt = "sanna marin playing tennis"
generator = torch.Generator().manual_seed(1000)
out = pipe(
    prompt=prompt,
    guidance_scale=3.5,
    height=1280,
    width=832,
    num_inference_steps=20,
    generator=generator
).images[0]
out.save("image.png")
```

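
A mismatch like this can be caught before calling `merge_to()` by comparing parameter dtypes. The sketch below uses toy `nn.Linear` modules. The helper names `dtype_summary` and `dtypes_match` are hypothetical and are not part of LyCORIS or diffusers. Applying them to the real objects assumes the LyCORIS wrapper is a `torch.nn.Module`, which the `wrapper.to(device, dtype)` call in the maintainer's reply suggests.

```python
from collections import Counter

import torch
from torch import nn


def dtype_summary(module: nn.Module) -> Counter:
    """Count how many parameter tensors of each dtype a module holds."""
    return Counter(p.dtype for p in module.parameters())


def dtypes_match(base: nn.Module, adapter: nn.Module) -> bool:
    """True when all parameters across both modules share a single dtype."""
    return len(set(dtype_summary(base)) | set(dtype_summary(adapter))) == 1


# Toy stand-ins: a bfloat16 "base model" and a float32 "adapter".
base = nn.Linear(8, 8).to(torch.bfloat16)
adapter = nn.Linear(8, 2)  # float32 under PyTorch's default dtype
```

With the script's objects, the equivalent check would be `dtypes_match(pipe.transformer, wrapper)` before the merge. A `False` result indicates that the adapter needs casting first.
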
KohakuBlueleaf commented 1 month ago

Will implement some type checks. In the meantime, you can use this as a workaround:

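```python
# `device` and `dtype` should match the base model, e.g.
# pipe.transformer.device and pipe.transformer.dtype (bfloat16 in the script above).
wrapper.apply_to()
wrapper.to(device, dtype)  # cast the LyCORIS weights to the base model's device/dtype
wrapper.restore()
wrapper.merge_to()
```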
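The reply above only promises "some type checks" and does not show the actual fix. The following is one plausible shape for such a check. It is a hypothetical sketch, not LyCORIS code: cast both operands to their promoted dtype before multiplying (the reporter's upcasting option), then cast the result back to the base weight's dtype so it can be merged without another mismatch. The names `matmul_promoted` and `make_weight_sketch` are illustrative only.

```python
import torch


def matmul_promoted(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: cast both operands to their promoted dtype
    (e.g. bfloat16 with float32 -> float32), then multiply."""
    dtype = torch.promote_types(a.dtype, b.dtype)
    return a.to(dtype) @ b.to(dtype)


def make_weight_sketch(orig, wa1, wa2, out_dtype=None):
    """Hypothetical stand-in for the failing expression `(orig @ wa1) @ wa2`.
    The result is cast to `out_dtype`, defaulting to orig's dtype."""
    w = matmul_promoted(matmul_promoted(orig, wa1), wa2)
    return w.to(out_dtype if out_dtype is not None else orig.dtype)


# Toy tensors with the same dtype split as in the issue.
orig = torch.randn(16, 8, dtype=torch.bfloat16)
wa1 = torch.randn(8, 2)  # float32
wa2 = torch.randn(2, 8)  # float32
w = make_weight_sketch(orig, wa1, wa2)
```

Computing in float32 uses more memory than bfloat16 but matched the faster CPU path the reporter observed. This thread does not say whether the upstream fix chose this route or downcasting.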