My simple inference script is failing when calling wrapper.merge_to() with Flux Dev as the base model.
2024-09-21 19:27:53|[LyCORIS]-INFO: Loading Modules from state dict...
2024-09-21 19:27:54|[LyCORIS]-INFO: 504 Modules Loaded
Traceback (most recent call last):
File "/nvme/home/mikaelh/Stable_Diffusion/bghira/output/models.bak_flux_sanna_marin_v0.4_fp8_multires_adan3_glora/inference2.py", line 12, in <module>
wrapper.merge_to(0.5)
File "/nvme/home/mikaelh/Stable_Diffusion/bghira/SimpleTuner.latest/.venv/lib/python3.11/site-packages/lycoris/wrapper.py", line 567, in merge_to
lora.merge_to(weight)
File "/nvme/home/mikaelh/Stable_Diffusion/bghira/SimpleTuner.latest/.venv/lib/python3.11/site-packages/lycoris/modules/base.py", line 269, in merge_to
weight, bias = self.get_merged_weight(
^^^^^^^^^^^^^^^^^^^^^^^
File "/nvme/home/mikaelh/Stable_Diffusion/bghira/SimpleTuner.latest/.venv/lib/python3.11/site-packages/lycoris/modules/glora.py", line 208, in get_merged_weight
diff_w, _ = self.get_diff_weight(multiplier, shape, device)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/nvme/home/mikaelh/Stable_Diffusion/bghira/SimpleTuner.latest/.venv/lib/python3.11/site-packages/lycoris/modules/glora.py", line 202, in get_diff_weight
weight = self.make_weight(device) * multiplier
^^^^^^^^^^^^^^^^^^^^^^^^
File "/nvme/home/mikaelh/Stable_Diffusion/bghira/SimpleTuner.latest/.venv/lib/python3.11/site-packages/lycoris/modules/glora.py", line 198, in make_weight
w_wa2 = (orig @ wa1) @ wa2
~~~~~^~~~~
RuntimeError: expected m1 and m2 to have the same dtype, but got: c10::BFloat16 != float
orig is in bfloat16 precision while wa1 and wa2 are in float32. I tried both upcasting orig and downcasting wa1 and wa2, and there was very little difference in the end result, though upcasting to float32 ran much faster on CPU. I'm not sure which approach you'd prefer for the fix, so I'm posting this as an issue.
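For illustration, here is a minimal standalone sketch of the upcast approach (make_weight_patched is a hypothetical helper, not the actual lycoris function): cast the frozen base weight up to the adapter dtype before the matmuls, then cast the result back.

```python
import torch

def make_weight_patched(orig: torch.Tensor, wa1: torch.Tensor, wa2: torch.Tensor) -> torch.Tensor:
    # Upcast the bfloat16 base weight to the adapters' dtype (float32 here)
    # so the matmul dtypes match, then cast the product back down.
    w = (orig.to(wa1.dtype) @ wa1) @ wa2
    return w.to(orig.dtype)

# Reproduces the mismatch from the traceback: bf16 @ fp32 would raise,
# but the patched version runs and returns a bf16 tensor.
orig = torch.randn(4, 4, dtype=torch.bfloat16)
wa1 = torch.randn(4, 4)  # float32 by default
wa2 = torch.randn(4, 4)
w = make_weight_patched(orig, wa1, wa2)
```

The downcast variant is symmetric (wa1.to(orig.dtype) etc.); as noted above, the upcast version was faster on CPU in my testing.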
Here's my inference script for reference:
import torch
from diffusers import FluxPipeline
from lycoris import create_lycoris_from_weights
torch.set_num_threads(16)
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
adapter_id = 'pytorch_lora_weights_fixed.safetensors' # you will have to download this manually
lora_scale = 1
wrapper, _ = create_lycoris_from_weights(lora_scale, adapter_id, pipe.transformer)
wrapper.merge_to(0.5)
pipe.enable_sequential_cpu_offload()
prompt = "sanna marin playing tennis"
generator = torch.Generator().manual_seed(1000)
out = pipe(
    prompt=prompt,
    guidance_scale=3.5,
    height=1280,
    width=832,
    num_inference_steps=20,
    generator=generator,
).images[0]
out.save("image.png")
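As a call-site workaround (untested sketch, assuming the wrapper behaves like a standard torch.nn.Module, which nn.Module.to() covers), the adapter weights can be downcast to the base model's dtype before merging. DummyWrapper below is a hypothetical stand-in for the LyCORIS wrapper, just to show the cast:

```python
import torch
import torch.nn as nn

class DummyWrapper(nn.Module):
    """Stand-in for the LyCORIS wrapper: holds float32 adapter params."""
    def __init__(self):
        super().__init__()
        self.wa1 = nn.Parameter(torch.randn(4, 4))  # float32 by default
        self.wa2 = nn.Parameter(torch.randn(4, 4))

wrapper = DummyWrapper()
# Downcast every adapter parameter to match the bf16 base model,
# so the matmuls in make_weight see a single dtype.
wrapper.to(torch.bfloat16)
```

In the script above this would be wrapper.to(torch.bfloat16) right before wrapper.merge_to(0.5), at the cost of doing the merge arithmetic in bfloat16.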