aredden / flux-fp8-api

Flux diffusion model implementation using quantized fp8 matmul; the remaining layers use faster half-precision accumulation, making it ~2x faster on consumer devices.
Apache License 2.0

Potential LoRA performance issue #9

Open ashakoen opened 2 weeks ago

ashakoen commented 2 weeks ago

Thought I'd raise this in case there's an issue.

Steps to reproduce:

1. Create a FLUX.1 LoRA fine-tune at Replicate.
2. Generate images using Replicate's FLUX.1[dev] API with the FLUX.1 LoRA.
3. Result: perfect face matching to the LoRA training.
4. Load the same LoRA using the LoRA loading technique from the flux-fp8-api README.md.
5. Generate an image with the same prompt and parameters using flux-fp8-api.
6. Result: very poor face matching to the LoRA training.

Any suggestions for what I might try?

Here's how I implemented LoRA loading in main.py:

--- main.orig.py        2024-08-29 15:34:42.612578339 +0200
+++ main.py     2024-08-29 14:53:52.603088816 +0200
@@ -2,9 +2,9 @@
 import uvicorn
 from api import app

-
 def parse_args():
     parser = argparse.ArgumentParser(description="Launch Flux API server")
+    # Existing arguments...
     parser.add_argument(
         "-c",
         "--config-path",
@@ -145,9 +145,17 @@
         dest="quantize_flow_embedder_layers",
         help="Quantize the flow embedder layers in the flow model, saves ~512MB vram usage, but precision loss is very noticeable",
     )
+    
+    # New arguments for LoRA loading
+    parser.add_argument(
+        "-L", "--lora-paths", type=str, help="Comma-separated paths to LoRA checkpoint files"
+    )
+    parser.add_argument(
+        "-S", "--lora-scales", type=str, default="1.0", help="Comma-separated scales for each LoRA"
+    )
+    
     return parser.parse_args()

-
 def main():
     args = parse_args()

@@ -192,8 +200,16 @@
         )
         app.state.model = FluxPipeline.load_pipeline_from_config(config)

-    uvicorn.run(app, host=args.host, port=args.port)
+    # If LoRA paths are provided, apply them sequentially
+    if args.lora_paths:
+        lora_paths = args.lora_paths.split(',')
+        lora_scales = [float(scale) for scale in args.lora_scales.split(',')] if args.lora_scales else [1.0] * len(lora_paths)
+        
+        # Apply each LoRA sequentially
+        for lora_path, scale in zip(lora_paths, lora_scales):
+            app.state.model.load_lora(lora_path, scale=scale)

+    uvicorn.run(app, host=args.host, port=args.port)

 if __name__ == "__main__":
     main()
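
One caveat with the loop in the patch above: zip() silently truncates to the shorter list, so if fewer scales than paths are passed, the extra LoRAs are skipped. A small guard (a hypothetical addition, padding with a default scale of 1.0) would avoid that:

# Hypothetical guard before the zip() above: pad missing scales with a
# default of 1.0 so extra LoRA paths aren't silently dropped.
if len(lora_scales) < len(lora_paths):
    lora_scales += [1.0] * (len(lora_paths) - len(lora_scales))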

And I call main.py like this:

PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True python main.py --config-path configs/config-dev-offload-1-4080.json --port 7888 --host <IP> --lora-scales 1 --lora-paths /root/flux-fp8-api.working/models/lora.safetensors
aredden commented 2 weeks ago

Ah that's interesting. Are you using the latest code? There was a bug earlier where it was always setting the lora alpha to 1.0 for huggingface diffusers loras. Though it could be something else.
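
For context, the alpha matters because it scales the merged LoRA delta. A minimal sketch of the usual convention (names are illustrative, not the actual lora_loading.py API):

import torch

# Illustrative sketch of the common LoRA merge convention; the function
# name and signature are hypothetical, not the lora_loading.py API.
def apply_lora_delta(base_weight: torch.Tensor, lora_A: torch.Tensor,
                     lora_B: torch.Tensor, alpha: float,
                     scale: float = 1.0) -> torch.Tensor:
    rank = lora_A.shape[0]
    # The effective multiplier is scale * (alpha / rank); a loader that
    # hardcodes this to 1.0 will under- or over-apply most LoRAs.
    delta = (lora_B @ lora_A) * (alpha / rank) * scale
    return base_weight + delta.to(base_weight.dtype)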

ashakoen commented 2 weeks ago

Thanks for the reply! Yes, I'm using the latest code.

aredden commented 2 weeks ago

It's possible that there are some lora loading specifics that I didn't implement well, but I'm not entirely sure what those would be. I will have to look into other lora implementations.
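
One way to narrow it down would be to inspect what the checkpoint actually carries. A rough diagnostic sketch, assuming a safetensors LoRA in either kohya (lora_down) or diffusers (lora_A) naming, with a placeholder path:

from safetensors.torch import load_file

# Rough diagnostic: list the alpha entries (if any) and the down/A
# projection shapes a LoRA checkpoint carries, to compare trainer
# conventions. "lora.safetensors" is a placeholder path.
sd = load_file("lora.safetensors")
alphas = {k: float(v) for k, v in sd.items() if k.endswith(".alpha")}
shapes = {k: tuple(v.shape) for k, v in sd.items()
          if "lora_down" in k or "lora_A" in k}
print("alphas:", alphas or "none found (loader must assume a default)")
print("down/A shapes:", shapes)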

armheb commented 1 week ago

Thanks for your amazing work. I also had some issues with the LoRA.

Here's the config I used for a 4090:

{
  "version": "flux-dev",
  "params": {
    "in_channels": 64,
    "vec_in_dim": 768,
    "context_in_dim": 4096,
    "hidden_size": 3072,
    "mlp_ratio": 4.0,
    "num_heads": 24,
    "depth": 19,
    "depth_single_blocks": 38,
    "axes_dim": [
      16,
      56,
      56
    ],
    "theta": 10000,
    "qkv_bias": true,
    "guidance_embed": true
  },
  "ae_params": {
    "resolution": 256,
    "in_channels": 3,
    "ch": 128,
    "out_ch": 3,
    "ch_mult": [
      1,
      2,
      4,
      4
    ],
    "num_res_blocks": 2,
    "z_channels": 16,
    "scale_factor": 0.3611,
    "shift_factor": 0.1159
  },
  "ckpt_path": "flux1-dev.safetensors",
  "ae_path": "ae.safetensors",
  "repo_id": "black-forest-labs/FLUX.1-dev",
  "repo_flow": "flux1-dev.safetensors",
  "repo_ae": "ae.safetensors",
  "text_enc_max_length": 512,
  "text_enc_path": "city96/t5-v1_1-xxl-encoder-bf16",
  "text_enc_device": "cuda:0",
  "ae_device": "cuda:0",
  "flux_device": "cuda:0",
  "flow_dtype": "float16",
  "ae_dtype": "bfloat16",
  "text_enc_dtype": "bfloat16",
  "flow_quantization_dtype": "qfloat8",
  "text_enc_quantization_dtype": "qint4",
  "ae_quantization_dtype": "qfloat8",
  "compile_extras": true,
  "compile_blocks": true,
  "offload_text_encoder": true,
  "offload_vae": true,
  "offload_flow": false
}

And the config I used for an H100:

{
  "version": "flux-dev",
  "params": {
    "in_channels": 64,
    "vec_in_dim": 768,
    "context_in_dim": 4096,
    "hidden_size": 3072,
    "mlp_ratio": 4.0,
    "num_heads": 24,
    "depth": 19,
    "depth_single_blocks": 38,
    "axes_dim": [
      16,
      56,
      56
    ],
    "theta": 10000,
    "qkv_bias": true,
    "guidance_embed": true
  },
  "ae_params": {
    "resolution": 256,
    "in_channels": 3,
    "ch": 128,
    "out_ch": 3,
    "ch_mult": [
      1,
      2,
      4,
      4
    ],
    "num_res_blocks": 2,
    "z_channels": 16,
    "scale_factor": 0.3611,
    "shift_factor": 0.1159
  },
  "ckpt_path": "flux1-dev.safetensors",
  "ae_path": "ae.safetensors",
  "repo_id": "black-forest-labs/FLUX.1-dev",
  "repo_flow": "flux1-dev.safetensors",
  "repo_ae": "ae.safetensors",
  "text_enc_max_length": 512,
  "text_enc_path": "city96/t5-v1_1-xxl-encoder-bf16",
  "text_enc_device": "cuda:0",
  "ae_device": "cuda:0",
  "flux_device": "cuda:0",
  "flow_dtype": "float16",
  "ae_dtype": "bfloat16",
  "text_enc_dtype": "bfloat16",
  "flow_quantization_dtype": "qfloat8",
  "text_enc_quantization_dtype": "qint4",
  "ae_quantization_dtype": "qfloat8",
  "compile_extras": true,
  "compile_blocks": true,
  "offload_text_encoder": false,
  "offload_vae": false,
  "offload_flow": false
}
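
(The only difference from the 4090 config above is that offload_text_encoder and offload_vae are set to false, since the H100 has enough VRAM to keep everything resident.)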

Hope it's helpful, thanks.

aredden commented 1 week ago

If you're getting black images, I would recommend setting flow_dtype to bfloat16; it should help a bit. I'm still a bit unsure how I'm supposed to handle lora alphas when they're not given in a lora's state dict, since I believe different trainers use different default values and I have no idea which is which haha.. Sorry 😢
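
For reference, one possible fallback is the kohya convention, where a missing alpha defaults to the rank so the (alpha / rank) multiplier works out to 1.0. A sketch with a hypothetical helper (not the actual lora_loading.py code):

import torch

def resolve_alpha(state_dict: dict, module_key: str,
                  lora_down: torch.Tensor) -> float:
    # Hypothetical helper: use the stored alpha if the checkpoint has
    # one, otherwise fall back to the rank so the (alpha / rank)
    # multiplier works out to 1.0. Other trainers bake scaling into the
    # weights instead, which is exactly the ambiguity discussed here.
    alpha_key = module_key + ".alpha"
    if alpha_key in state_dict:
        return float(state_dict[alpha_key])
    return float(lora_down.shape[0])  # rank of the down-projection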

ashakoen commented 1 week ago

How can I help? Maybe I can't, but thought I'd offer.

aredden commented 1 week ago

Thanks 😄 - well, if you find anything off in my lora loading implementation here https://github.com/aredden/flux-fp8-api/blob/main/lora_loading.py, let me know and I'll change it, or you can submit a pull request and I'll look it over. Up to you 😄