aredden / flux-fp8-api

Flux diffusion model implementation using quantized fp8 matmul; the remaining layers use faster half-precision accumulation, which is ~2x faster on consumer devices.
Apache License 2.0

How to save a "prequantized_flow" safetensor? #16

Open smuelpeng opened 2 months ago

smuelpeng commented 2 months ago

Hello,

The documentation mentions that the --prequantized-flow option can be used to load a prequantized model, which reduces the checkpoint size by about 50% and shortens the startup time (default: False).

However, I couldn’t find any interface in the repository to enable this functionality. Could you please provide guidance on how to store and load a prequantized model to save resources and initialization time?

Looking forward to your response, thank you!

aredden commented 2 months ago

Ah! Essentially it's just the checkpoint that gets created after loading the model and doing at least 12 steps of inference. You could do something like this in the root of the repo:

from flux_pipeline import FluxPipeline, ModelVersion
from safetensors.torch import save_file

prompt = "some prompt"
pipe = FluxPipeline.load_pipeline_from_config_path("./configs/your-config.json")

# Run at least 12 total inference steps so the flow model gets fully quantized.
if pipe.config.version == ModelVersion.flux_schnell:
    for x in range(3):
        pipe.generate(prompt=prompt, num_steps=4)
else:
    pipe.generate(prompt=prompt, num_steps=12)

# The flow model's state dict now holds the quantized weights.
quantized_state_dict = pipe.model.state_dict()

save_file(quantized_state_dict, "some-model-prequantized.safetensors")
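
To load it back, the idea is to point your config at the saved file and enable the prequantized option (the --prequantized-flow flag mentioned in the docs) so the on-the-fly quantization pass can be skipped; the exact config field names depend on your setup. As a purely illustrative sanity check in the same process (not the repo's loading path), the file should round-trip with safetensors:

from safetensors.torch import load_file

# Illustrative only: verify the saved quantized weights load back into the
# already-quantized model. For a fresh process, use the prequantized-flow
# config/CLI option instead of calling load_state_dict yourself.
reloaded = load_file("some-model-prequantized.safetensors")
pipe.model.load_state_dict(reloaded, strict=False)
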
smuelpeng commented 2 months ago

Thank you for your helpful response. The solution works well for saving and loading pre-quantized safetensors files.

However, do you have any suggestions for saving and loading a Torch-compiled Flux model? Currently, compiling the Flux model adds a lot of initialization time, and I'm looking for ways to streamline this process.

aredden commented 2 months ago

Ah, you can speed that up by using nightly torch. For me, compilation only takes a few (maybe 3-4) seconds at most.
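
Separately from switching to nightly torch, one general torch.compile option that may cut the repeated warm-up cost across runs is Inductor's FX graph cache. This is a PyTorch feature, not something this repo exposes, and its availability and behavior depend on your torch version; a minimal sketch, assuming PyTorch 2.2+:

import os
from torch._inductor import config as inductor_config

# Persist Inductor's compiled artifacts so a later run with the same model and
# shapes can reuse them instead of recompiling from scratch (assumes torch 2.2+;
# the cache directory path here is just an example).
os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/tmp/flux_inductor_cache"
inductor_config.fx_graph_cache = True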

Muawizodux commented 1 month ago

I appreciate your amazing work!

For me, torch nightly takes 9-18 seconds per inference on the first 3 warm-up inferences, while stable torch takes 1-1.5 minutes per inference on the first 3.

Am I missing something?

aredden commented 1 month ago

That seems correct. It's possible that it's just related to the CPU; I have a 7950X, so everything runs very fast.