Open nelapetrzelkova opened 1 month ago
Hello, I am facing an issue with generating images with FLUX.1[dev] + LoRA that I trained with SimpleTuner. I need to be able to load LoRAs dynamically, so I want to quantize FLUX first and then load the LoRA into the already quantized model. With optimum-quanto version 0.2.4 and lower I got the following error:
KeyError: 'time_text_embed.timestep_embedder.linear_1.weight._data'
After bumping the version to 0.2.5 or 0.2.6, no error is thrown, but the results look like this:
[image]
My code:
import torch
from diffusers import DiffusionPipeline
from optimum.quanto import freeze, qfloat8, quantize

model_id = 'black-forest-labs/FLUX.1-dev'
# Pick the best available device once and reuse it below.
device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'

pipeline = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)

# Quantize the transformer to float8 before any LoRA is loaded.
quantize(pipeline.transformer, weights=qfloat8)
freeze(pipeline.transformer)
pipeline.to(device)

# Load the LoRA into the already quantized model.
lora_path = "<path_to_lora>"
pipeline.load_lora_weights(lora_path)

prompts = {"candy": "Candy bar surrounded by playful, abstract shapes resembling candy sprinkles and whimsical clouds of cream. The atmosphere is vibrant and joyful, filled with bright colors that evoke childhood memories of sweetness and fun. This imagery invites viewers to imagine the delight of savoring a piece of chocolate that brings happiness to any moment."}
seed = 19640904

for prompt_key, prompt_value in prompts.items():
    print(prompt_key, prompt_value)
    images = pipeline(
        prompt=prompt_value,
        num_inference_steps=10,
        num_images_per_prompt=1,
        generator=torch.Generator(device=device).manual_seed(seed),
        width=1024,
        height=1024,
    ).images
    # display() is available in a notebook (IPython) session.
    for idx, image in enumerate(images):
        display(image)
Is there a way to solve this? A workaround could be to load the LoRA into the model before quantization, save the quantized merged model, and work with that, but then I lose the benefit of working with the LoRA alone, which is much faster and uses less memory.
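The workaround mentioned above might be sketched as follows. This is a minimal sketch, not a tested recipe: the function name `build_merged_quantized_pipeline` is hypothetical, it assumes that `fuse_lora()` and `unload_lora_weights()` are available on the loaded pipeline in the installed diffusers version, and the imports sit inside the function only so that defining it does not require the heavy dependencies.

```python
def build_merged_quantized_pipeline(model_id, lora_path):
    """Fuse a LoRA into the base weights first, then quantize the merged model.

    Hypothetical helper illustrating the workaround; not part of any library.
    """
    # Imported lazily so that defining this function needs neither torch nor diffusers.
    import torch
    from diffusers import DiffusionPipeline
    from optimum.quanto import freeze, qfloat8, quantize

    pipeline = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)

    # Load and merge the LoRA while the weights are still plain bfloat16 tensors.
    pipeline.load_lora_weights(lora_path)
    pipeline.fuse_lora()
    pipeline.unload_lora_weights()  # assumed to drop the now-redundant adapter layers

    # Quantize only after the merge, so no LoRA is ever loaded into quantized weights.
    quantize(pipeline.transformer, weights=qfloat8)
    freeze(pipeline.transformer)
    return pipeline
```

The cost is the one already noted: every LoRA needs its own merged, quantized copy of the model, so adapters can no longer be swapped cheaply at runtime.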
Thanks!
I encountered a similar issue. When using optimum-quanto==0.2.6 to quantize FLUX.1-schnell, the output also turned into random noise. After investigating, I found that the issue was caused by MarlinF8QBytesTensor.
To fix it, you can modify optimum/quanto/tensor/weights/qbytes.py.
Simply change the line:
and torch.cuda.get_device_capability(data.device)[0] >= 8
to:
and torch.cuda.get_device_capability(data.device)[0] >= 20
This resolved the problem for me.
if (
    qtype == qtypes["qfloat8_e4m3fn"]
    and activation_qtype is None
    and scale.dtype in [torch.float16, torch.bfloat16]
    and len(size) == 2
    and (data.device.type == "cuda" and torch.version.cuda)
    and axis == 0
    and torch.cuda.get_device_capability(data.device)[0] >= 8
):
    out_features, in_features = size
    if (
        in_features >= 64
        and out_features >= 64
        and (
            (in_features % 64 == 0 and out_features % 128 == 0)
            or (in_features % 128 == 0 and out_features % 64 == 0)
        )
    ):
        return MarlinF8QBytesTensor(qtype, axis, size, stride, data, scale, requires_grad)
But I don't know why MarlinF8QBytesTensor doesn't work. @dacorvo
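To see which weight shapes the check quoted above would route to `MarlinF8QBytesTensor`, the shape and compute-capability part of the condition can be mirrored in plain Python. This is only an illustrative sketch: `would_use_marlin_fp8` and its `min_capability` parameter are hypothetical names, and the other conditions (qfloat8_e4m3fn weights, no activation qtype, float16/bfloat16 scale, CUDA device, axis 0) are assumed to hold already.

```python
def would_use_marlin_fp8(out_features, in_features, capability_major, min_capability=8):
    """Mirror the shape and capability part of the quoted dispatch check.

    Hypothetical helper; it does not call into optimum-quanto.
    """
    # GPU generation gate: the original code requires a major capability >= 8.
    if capability_major < min_capability:
        return False
    # Both dimensions must be at least 64.
    if in_features < 64 or out_features < 64:
        return False
    # One dimension must be a multiple of 64 and the other a multiple of 128.
    return (in_features % 64 == 0 and out_features % 128 == 0) or (
        in_features % 128 == 0 and out_features % 64 == 0
    )


# Illustrative shapes only, not read from any real model.
print(would_use_marlin_fp8(3072, 3072, 8))                      # True
print(would_use_marlin_fp8(3072, 3072, 8, min_capability=20))   # False
print(would_use_marlin_fp8(100, 64, 9))                         # False
```

Raising the threshold to 20, as in the suggested edit, makes the capability gate fail for the GPU in the example (and for any GPU whose major capability is below 20), so the Marlin FP8 path is never selected.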
@tyyff thank you for investigating this. See also #332. There might be a general issue with Marlin kernels when the size of the tensors involved in the matmul increases (could be an overflow, could be some overlaps in the intermediate result buffers, I really don't know). I will disable the FP8 kernel for now.