aredden / flux-fp8-api

Flux diffusion model implementation using quantized fp8 matmul, with the remaining layers using faster half-precision accumulation, which is ~2x faster on consumer devices.
Apache License 2.0

Flux FP8 (true) Matmul Implementation with FastAPI

This repository contains an implementation of the Flux model, along with an API that allows you to generate images based on text prompts. It also provides a simple way to load the generator as a single object in one line of code, similar to diffusers pipelines.

Speed Comparison

Note:

Resolution   Device       Test                          Average (it/s)
1024x1024    RTX4090      bfl codebase fp8 wo quant     1.7
1024x1024    RTX4090      ❌ compile blocks & extras     2.55
1024x1024    RTX4090      ✅ compile blocks & extras     3.51
1024x1024    RTX4000ADA   ❌ compile blocks & extras     0.79
1024x1024    RTX4000ADA   ✅ compile blocks & extras     1.26
1024x1024    RTX6000ADA   bfl codebase                  1.74
1024x1024    RTX6000ADA   ❌ compile blocks & extras     2.08
1024x1024    RTX6000ADA   ✅ compile blocks & extras     2.8
1024x1024    H100         ❌ compile blocks & extras     6.1
1024x1024    H100         ✅ compile blocks & extras     11.5
768x768      RTX4090      bfl codebase fp8 wo quant     2.32
768x768      RTX4090      ❌ compile blocks & extras     4.47
768x768      RTX4090      ✅ compile blocks & extras     6.2
768x768      RTX4000      ❌ compile blocks & extras     1.41
768x768      RTX4000      ✅ compile blocks & extras     2.19
768x768      RTX6000ADA   bfl codebase                  3.01
768x768      RTX6000ADA   ❌ compile blocks & extras     3.43
768x768      RTX6000ADA   ✅ compile blocks & extras     4.46
768x768      H100         ❌ compile blocks & extras     10.3
768x768      H100         ✅ compile blocks & extras     20.8
1024x720     RTX4090      bfl codebase fp8 wo quant     3.01
1024x720     RTX4090      ❌ compile blocks & extras     3.6
1024x720     RTX4090      ✅ compile blocks & extras     4.96
1024x720     RTX4000      ❌ compile blocks & extras     1.14
1024x720     RTX4000      ✅ compile blocks & extras     1.78
1024x720     RTX6000ADA   bfl codebase                  2.37
1024x720     RTX6000ADA   ❌ compile blocks & extras     2.87
1024x720     RTX6000ADA   ✅ compile blocks & extras     3.78
1024x720     H100         ❌ compile blocks & extras     8.2
1024x720     H100         ✅ compile blocks & extras     15.7
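To put these numbers in perspective, the short snippet below computes speedup ratios from the 1024x1024 rows of the table above. The values are copied verbatim from the table. The dictionary layout and the "uncompiled"/"compiled" labels are only illustrative names for the ❌/✅ rows.

```python
# Average it/s at 1024x1024, copied from the speed comparison table.
# "uncompiled" / "compiled" = the ❌ / ✅ compile blocks & extras rows.
its_1024 = {
    "RTX4090": {"bfl": 1.7, "uncompiled": 2.55, "compiled": 3.51},
    "RTX4000ADA": {"uncompiled": 0.79, "compiled": 1.26},
    "RTX6000ADA": {"bfl": 1.74, "uncompiled": 2.08, "compiled": 2.8},
    "H100": {"uncompiled": 6.1, "compiled": 11.5},
}


def speedup(row: dict, new: str, baseline: str) -> float:
    """Throughput ratio new/baseline (both in it/s), rounded to 2 decimals."""
    return round(row[new] / row[baseline], 2)


for device, row in its_1024.items():
    line = f"{device}: compiled vs uncompiled x{speedup(row, 'compiled', 'uncompiled')}"
    if "bfl" in row:
        line += f", compiled vs bfl codebase x{speedup(row, 'compiled', 'bfl')}"
    print(line)
```

On the RTX4090, the compiled run is about 2.06x the "bfl codebase fp8 wo quant" baseline, in line with the ~2x claim at the top. The RTX6000ADA baseline is plain "bfl codebase", so the two "vs bfl" ratios are not strictly comparable.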

Updates 08/24/24

Fixes

Note: prequantized flow models only work with the same quantization settings they were created with. For example, if you create a prequantized flow model with quantize_modulation set to false, it will only work with quantize_modulation set to false; the same applies to quantize_flow_embedder_layers.
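Here is a minimal sketch of that rule: before loading a prequantized flow model, compare the quantization flags it was created with against the config you are about to load. The helper name is hypothetical. The sketch also assumes you kept a copy of the settings used to create the checkpoint, because the repo is not known to record them for you. The defaults are the ones documented in the example config further down.

```python
# Flags that must match between creating and loading a prequantized flow model.
QUANT_FLAGS = ("quantize_modulation", "quantize_flow_embedder_layers")
# Defaults as documented in the example config (used when a key is absent).
DEFAULTS = {"quantize_modulation": True, "quantize_flow_embedder_layers": False}


def check_prequantized_compat(created_with: dict, load_config: dict) -> None:
    """Hypothetical helper: raise ValueError if load_config disagrees with
    the settings the prequantized flow model was created with."""
    mismatches = []
    for flag in QUANT_FLAGS:
        made = created_with.get(flag, DEFAULTS[flag])
        now = load_config.get(flag, DEFAULTS[flag])
        if made != now:
            mismatches.append(f"{flag}: created with {made}, loading with {now}")
    if mismatches:
        raise ValueError("incompatible prequantized flow model: " + "; ".join(mismatches))


# Example: created with quantize_modulation=False, so it must be loaded the same way.
created = {"quantize_modulation": False, "quantize_flow_embedder_layers": False}
check_prequantized_compat(created, {"prequantized_flow": True, "quantize_modulation": False})  # passes
```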

Updates 08/25/24

from flux_pipeline import FluxPipeline

config_path = "path/to/config/file.json"
# optional keyword overrides for values in the config file
config_overrides = {
    #...
}

lora_path = "path/to/lora/file.safetensors"

pipeline = FluxPipeline.load_pipeline_from_config_path(config_path, **config_overrides)

# load the LoRA weights; scale controls how strongly the LoRA is applied
pipeline.load_lora(lora_path, scale=1.0)
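For context on the scale argument: in the common LoRA formulation, a LoRA contributes a low-rank update B @ A to a base weight W. A user-facing scale multiplies that update: W' = W + scale * (B @ A). Under that formulation, scale=0.0 leaves the weight unchanged and 1.0 applies the LoRA at full strength. The pure-Python sketch below shows this on tiny matrices. It is not the repo's load_lora implementation; implementations often also fold in an alpha/rank factor, and whether load_lora does so is not covered here.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def merge_lora(weight, lora_b, lora_a, scale=1.0):
    """Return weight + scale * (lora_b @ lora_a), the standard LoRA merge."""
    delta = matmul(lora_b, lora_a)
    return [
        [w + scale * d for w, d in zip(w_row, d_row)]
        for w_row, d_row in zip(weight, delta)
    ]


W = [[1.0, 0.0], [0.0, 1.0]]  # 2x2 base weight
B = [[1.0], [2.0]]            # 2x1, rank 1
A = [[0.5, 0.5]]              # 1x2
print(merge_lora(W, B, A, scale=1.0))  # [[1.5, 0.5], [1.0, 2.0]]
print(merge_lora(W, B, A, scale=0.0))  # [[1.0, 0.0], [0.0, 1.0]]
```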

Updates 09/07/24

Updates 10/3/24

Installation

This repo requires PyTorch built with CUDA 12.4 or newer and an ADA (or newer, e.g. H100) GPU with fp8 support; otherwise torch._scaled_mm will throw a CUDA error saying the operation is not supported. To install with conda/mamba:
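This sketch checks both requirements before you install the rest of the dependencies. It assumes that fp8 matmul needs compute capability 8.9 (Ada) or higher, such as 9.0 for the H100 in the table above. It also assumes torch.version.cuda reports the CUDA version PyTorch was built with. The function names are illustrative, not part of the repo.

```python
from typing import Optional


def parse_version(version: str) -> tuple:
    """Turn '12.4' (or '12.4.1') into (12, 4)."""
    return tuple(int(part) for part in version.split(".")[:2])


def fp8_matmul_supported(capability: tuple, cuda_version: Optional[str]) -> bool:
    """Assumed rule: compute capability >= 8.9 (Ada) and CUDA build >= 12.4."""
    if cuda_version is None:  # CPU-only PyTorch build
        return False
    return tuple(capability) >= (8, 9) and parse_version(cuda_version) >= (12, 4)


def report() -> None:
    """Print what PyTorch sees on GPU 0; safe to run without torch or a GPU."""
    try:
        import torch
    except ImportError:
        print("PyTorch is not installed")
        return
    if not torch.cuda.is_available():
        print("PyTorch cannot see a CUDA device")
        return
    cap = torch.cuda.get_device_capability(0)
    print(torch.cuda.get_device_name(0), cap, "built with CUDA", torch.version.cuda)
    print("fp8 matmul supported:", fp8_matmul_supported(cap, torch.version.cuda))


if __name__ == "__main__":
    report()
```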

mamba create -n flux-fp8-matmul-api python=3.11 pytorch torchvision torchaudio pytorch-cuda=12.4 -c pytorch -c nvidia
mamba activate flux-fp8-matmul-api

# or with conda
conda create -n flux-fp8-matmul-api python=3.11 pytorch torchvision torchaudio pytorch-cuda=12.4 -c pytorch -c nvidia
conda activate flux-fp8-matmul-api

# or with nightly... (which is what I am using) - also, just switch 'mamba' to 'conda' if you are using conda
mamba create -n flux-fp8-matmul-api python=3.11 pytorch torchvision torchaudio pytorch-cuda=12.4 -c pytorch-nightly -c nvidia
mamba activate flux-fp8-matmul-api

# or with pip
python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
# or pip nightly
python -m pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu124

To install the required dependencies, run:

python -m pip install -r requirements.txt

If you get errors installing torch-cublas-hgemm, you can comment it out in requirements.txt. It isn't required, but it speeds up inference for the non-fp8 linear layers.
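If you would rather script that edit, this minimal sketch comments out every uncommented line in a requirements file that mentions a given package. It assumes the package name appears literally on its line, either as a plain name or inside a git URL. The function name is hypothetical, so check the file afterwards.

```python
from pathlib import Path


def comment_out_requirement(path: str, package: str) -> int:
    """Prefix every uncommented line mentioning `package` with '# '.
    Returns the number of lines changed."""
    req = Path(path)
    lines = req.read_text().splitlines()
    changed = 0
    for i, line in enumerate(lines):
        stripped = line.lstrip()
        if stripped and not stripped.startswith("#") and package in stripped:
            lines[i] = "# " + line
            changed += 1
    req.write_text("\n".join(lines) + "\n")
    return changed
```

Usage: comment_out_requirement("requirements.txt", "torch-cublas-hgemm"), then rerun python -m pip install -r requirements.txt.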

Usage

For a single ADA GPU with more than 16GB but less than 24GB of VRAM, use the configs/config-dev-offload-1-4080.json config file as a base, then tweak the parameters to fit your needs. It offloads all models to CPU when not in use, compiles the flow model with extra optimizations, and quantizes the text encoder to nf4 and the autoencoder to qfloat8.

For a single ADA GPU with more than ~32GB of VRAM, use the configs/config-dev-1-RTX6000ADA.json config file as a base, then tweak the parameters to fit your needs. It does not offload any models to CPU, compiles the flow model with extra optimizations, quantizes the text encoder to qfloat8, and keeps the autoencoder in bfloat16.

For a single 4090 GPU, you should use the configs/config-dev-offload-1-4090.json config file as a base, and then tweak the parameters to fit your needs. It offloads the text encoder and the autoencoder to CPU, compiles the flow model with extra optimizations, and quantizes the text encoder to nf4 and the autoencoder to float8.
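The three recommendations above can be summarized as a rough lookup. This is a hypothetical helper, not part of the repo. The VRAM thresholds come from the paragraphs above. Cards outside them, such as GPUs under 16GB or non-4090 cards between 24GB and 32GB, get no recommendation.

```python
from typing import Optional


def pick_base_config(vram_gb: float, is_rtx4090: bool = False) -> Optional[str]:
    """Heuristic: map a single ADA GPU to the suggested base config file."""
    if is_rtx4090:
        # Offloads text encoder + autoencoder; T5 in nf4, autoencoder in float8.
        return "configs/config-dev-offload-1-4090.json"
    if vram_gb >= 32:
        # No offloading; T5 in qfloat8, autoencoder kept in bfloat16.
        return "configs/config-dev-1-RTX6000ADA.json"
    if 16 < vram_gb < 24:
        # Offloads all models; T5 in nf4, autoencoder in qfloat8.
        return "configs/config-dev-offload-1-4080.json"
    return None  # not covered by the recommendations above
```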

NOTE: For all of these configs, you must change the ckpt_path, ae_path, and text_enc_path parameters to point to your own checkpoint, autoencoder, and text encoder.
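One way to make those three edits without touching the shipped config is to write a copy with your paths filled in. This is a minimal sketch with a hypothetical helper name. It assumes the config file is plain JSON without comments.

```python
import json
from pathlib import Path


def write_config_with_paths(
    src: str, dst: str, ckpt_path: str, ae_path: str, text_enc_path: str
) -> dict:
    """Copy a config JSON file to dst, replacing the three model paths."""
    cfg = json.loads(Path(src).read_text(encoding="utf-8"))
    cfg.update(ckpt_path=ckpt_path, ae_path=ae_path, text_enc_path=text_enc_path)
    Path(dst).write_text(json.dumps(cfg, indent=4), encoding="utf-8")
    return cfg
```

For example, you could call write_config_with_paths("configs/config-dev-offload-1-4090.json", "configs/my-config.json", "/path/to/your/flux1-dev.sft", "/path/to/your/ae.sft", "city96/t5-v1_1-xxl-encoder-bf16"). Then start the server with python main.py --config-path configs/my-config.json. Here my-config.json is a placeholder name.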

You can run the API server using the following command:

python main.py --config-path <path_to_config> --port <port_number> --host <host_address>

API Command-Line Arguments

Configuration

The configuration files are located in the configs directory. You can specify different configurations for different model versions and devices.

Example configuration file for a single 4090 (configs/config-dev-offload-1-4090.json):

{
    "version": "flux-dev", // or flux-schnell
    "params": {
        "in_channels": 64,
        "vec_in_dim": 768,
        "context_in_dim": 4096,
        "hidden_size": 3072,
        "mlp_ratio": 4.0,
        "num_heads": 24,
        "depth": 19,
        "depth_single_blocks": 38,
        "axes_dim": [16, 56, 56],
        "theta": 10000,
        "qkv_bias": true,
        "guidance_embed": true // if you are using flux-schnell, set this to false
    },
    "ae_params": {
        "resolution": 256,
        "in_channels": 3,
        "ch": 128,
        "out_ch": 3,
        "ch_mult": [1, 2, 4, 4],
        "num_res_blocks": 2,
        "z_channels": 16,
        "scale_factor": 0.3611,
        "shift_factor": 0.1159
    },
    "ckpt_path": "/your/path/to/flux1-dev.sft", // local path to original bf16 BFL flux checkpoint
    "ae_path": "/your/path/to/ae.sft", // local path to original bf16 BFL autoencoder checkpoint
    "repo_id": "black-forest-labs/FLUX.1-dev", // can ignore
    "repo_flow": "flux1-dev.sft", // can ignore
    "repo_ae": "ae.sft", // can ignore
    "text_enc_max_length": 512, // use 256 if you are using flux-schnell
    "text_enc_path": "city96/t5-v1_1-xxl-encoder-bf16", // or custom HF full bf16 T5EncoderModel repo id
    "text_enc_device": "cuda:0",
    "ae_device": "cuda:0",
    "flux_device": "cuda:0",
    "flow_dtype": "float16",
    "ae_dtype": "bfloat16",
    "text_enc_dtype": "bfloat16",
    "flow_quantization_dtype": "qfloat8", // will always be qfloat8, so can ignore
    "text_enc_quantization_dtype": "qint4", // choose between qint4, qint8, qfloat8, qint2 or delete entry for no quantization
    "ae_quantization_dtype": "qfloat8", // can either be qfloat8 or delete entry for no quantization
    "compile_extras": true, // compile the layers not included in the single-blocks or double-blocks
    "compile_blocks": true, // compile the single-blocks and double-blocks
    "offload_text_encoder": true, // offload the text encoder to cpu when not in use
    "offload_vae": true, // offload the autoencoder to cpu when not in use
    "offload_flow": false, // offload the flow transformer to cpu when not in use
    "prequantized_flow": false, // load the flow transformer from a prequantized checkpoint, which reduces the size of the checkpoint by about 50% & reduces startup time (default: false)
    "quantize_modulation": true, // quantize the modulation layers in the flow transformer, which reduces precision moderately but saves ~2GB vram usage (default: true)
    "quantize_flow_embedder_layers": false // quantize the flow embedder layers in the flow transformer; if false, improves precision considerably at the cost of ~512MB extra vram usage (default: false)
}
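The // comments in the example above are explanatory only. Python's json module does not accept comments, so the block cannot be used as a .json file as-is. To keep the comments in your own copy, a small loader like the sketch below can strip them before parsing. It is a hypothetical helper, not something the repo provides. It removes only // comments outside string literals, so a value such as "http://host//path" is left intact.

```python
import json


def strip_line_comments(text: str) -> str:
    """Remove // comments that appear outside of JSON string literals."""
    out = []
    in_string = False
    escaped = False
    i = 0
    while i < len(text):
        ch = text[i]
        if in_string:
            out.append(ch)
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == '"':
                in_string = False
            i += 1
        elif ch == '"':
            in_string = True
            out.append(ch)
            i += 1
        elif text.startswith("//", i):
            # Skip to the end of the line but keep the newline itself.
            newline = text.find("\n", i)
            i = len(text) if newline == -1 else newline
        else:
            out.append(ch)
            i += 1
    return "".join(out)


def load_commented_config(path: str) -> dict:
    """Read a config file that may contain // comments and parse it as JSON."""
    with open(path, encoding="utf-8") as f:
        return json.loads(strip_line_comments(f.read()))
```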

In general, the only entries you should need to change are:

    "ckpt_path": "/path/to/your/flux1-dev.sft", // path to your original BFL flow transformer (not diffusers)
    "ae_path": "/path/to/your/ae.sft", // path to your original BFL autoencoder (not diffusers)
    "text_enc_path": "path/to/your/t5-v1_1-xxl-encoder-bf16", // HF T5EncoderModel - can use "city96/t5-v1_1-xxl-encoder-bf16" for a simple to download version

Other things to change can be the

API Endpoints

Generate Image

Examples

Running the Server

python main.py --config-path configs/config-dev-1-4090.json --port 8088 --host 0.0.0.0

Or, if you need more granular control over all of the settings, you can run the server with something like this:

python main.py --port 8088 --host 0.0.0.0 \
    --flow-model-path /path/to/your/flux1-dev.sft \
    --text-enc-path /path/to/your/t5-v1_1-xxl-encoder-bf16 \
    --autoencoder-path /path/to/your/ae.sft \
    --model-version flux-dev \
    --flux-device cuda:0 \
    --text-enc-device cuda:0 \
    --autoencoder-device cuda:0 \
    --compile \
    --quant-text-enc qfloat8 \
    --quant-ae

Generating an image on a client

Send a POST request to http://<host>:<port>/generate with the following JSON body:

{
    "prompt": "a beautiful asian woman in traditional clothing with golden hairpin and blue eyes, wearing a red kimono with dragon patterns",
    "width": 1024,
    "height": 1024,
    "num_steps": 24,
    "guidance": 3.0,
    "seed": 13456
}

To generate an image from a Python client using the FastAPI server:

import requests

prompt = "a beautiful asian woman in traditional clothing with golden hairpin and blue eyes, wearing a red kimono with dragon patterns"
res = requests.post(
    "http://localhost:8088/generate",
    json={
        "width": 1024,
        "height": 720,
        "num_steps": 20,
        "guidance": 4,
        "prompt": prompt,
    },
)
# fail loudly instead of saving an error response as an image
res.raise_for_status()

# the response body is the encoded image
with open("output.jpg", "wb") as f:
    f.write(res.content)

You can also generate an image by directly importing the FluxPipeline class and using it to generate an image. This is useful if you have a custom model configuration and want to generate an image without having to run the server.

import io
from flux_pipeline import FluxPipeline

pipe = FluxPipeline.load_pipeline_from_config_path(
    "configs/config-dev-offload-1-4090.json"  # or whatever your config is
)

output_jpeg_bytes: io.BytesIO = pipe.generate(
    # Required args:
    prompt="A beautiful asian woman in traditional clothing with golden hairpin and blue eyes, wearing a red kimono with dragon patterns",
    # Optional args:
    width=1024,
    height=1024,
    num_steps=20,
    guidance=3.5,
    seed=13456,
    # Optional init-image args: uncomment and point init_image at a real file.
    # init_image="path/to/your/init_image.jpg",
    # strength=0.8,
)

with open("output.jpg", "wb") as f:
    f.write(output_jpeg_bytes.getvalue())