aredden / flux-fp8-api

Flux diffusion model implementation using quantized fp8 matmul; the remaining layers use faster half-precision accumulation, which is ~2x faster on consumer devices.
Apache License 2.0

LoRA loading fails if only trained on specific blocks #13

Open fblissjr opened 1 week ago

fblissjr commented 1 week ago

Just FYI, I think this is failing because of a LoRA with only certain blocks trained:

  File "flux-fp8-api/flux_pipeline.py", line 163, in load_lora
    self.model = lora_loading.apply_lora_to_model(self.model, lora_path, scale)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../miniconda3/envs/flux/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "flux-fp8-api/lora_loading.py", line 398, in apply_lora_to_model
    lora_weights = convert_diffusers_to_flux_transformer_checkpoint(
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "flux-fp8-api/lora_loading.py", line 120, in convert_diffusers_to_flux_transformer_checkpoint
    sample_q_A = diffusers_state_dict.pop(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^
KeyError: 'transformer.transformer_blocks.0.attn.to_q.lora_A.weight'

This is for a LoRA trained on:

transformer.single_transformer_blocks.7.proj_out
transformer.single_transformer_blocks.20.proj_out

I didn't get a chance to test with any that had all layers trained since I didn't have any handy, but looking at the code super quickly, I think it's trying to reference a key that doesn't exist?

Great project btw.

fblissjr commented 1 week ago

A simple check might be something like this; I can test later. Or move the layer names out to a separate mapping file if this ever goes beyond the flux1 architecture.

def safe_pop(state_dict, key, default=None):
    # return default instead of raising KeyError when a block wasn't trained
    return state_dict.pop(key, default)

sample_q_A = safe_pop(diffusers_state_dict, f"{prefix}{block_prefix}attn.to_q.lora_A.weight")
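
The call site would then need to branch on the missing-key default, roughly like this (continuing the snippet above; just the shape of the check, not the actual lora_loading.py structure):

sample_q_A = safe_pop(diffusers_state_dict, f"{prefix}{block_prefix}attn.to_q.lora_A.weight")
sample_q_B = safe_pop(diffusers_state_dict, f"{prefix}{block_prefix}attn.to_q.lora_B.weight")
if sample_q_A is not None and sample_q_B is not None:
    # only convert this projection if the LoRA actually trained it;
    # otherwise skip it and leave the base weights untouched
    ...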

edit: looks like it's not that simple. I can get it to load now, but I don't think it's applying the LoRA weights correctly (or at all). Can't tell if it's a data type issue; I trained this LoRA with ai-toolkit about a week ago.

fblissjr commented 1 week ago

I think I got it working, though there are tons of debug statements in there from wrapping my head around some of the key naming. It now finally uses my LoRA weights, and the image outputs look good too.

Forked here; I'll PR it once it's cleaned up if that's helpful, unless you prefer to change it a different way, @aredden. Also added an inspect_lora.py to help debug.

Branch with updates: https://github.com/fblissjr/flux-fp8-api/tree/lora-loading-layer-fix

fblissjr commented 1 week ago

@aredden - If you're up for making this a pip installable library, I've been creating a bunch of fun toy examples that were really only possible in ComfyUI (at the same quality and speed) until your project. Also adding the ability to capture text embeddings and intermediate embeddings.

Let me know if you're up for collaborating more on these types of things, or if it makes sense to keep it all separate.

Can't say enough good things about this though.

0xtempest commented 1 week ago

@fblissjr I'm having a lot more issues with applying LoRAs, even after applying your changes.

I've got a version working locally that is applying the LoRA weights properly, but I need to test it further.

I'll update and make a PR when I have it fully tested and working, fingers crossed for soon.

Also, I'm down to collaborate; my next steps are to get ControlNet and inpainting working.

ashakoen commented 1 week ago

@fblissjr Thanks for your branch and work on LoRA loading. It's weird, I have issues with some LoRAs and not others. I'll continue to test.

fblissjr commented 1 week ago

@fblissjr Thanks for your branch and work on LoRA loading. It's weird, I have issues with some LoRAs and not others. I'll continue to test.

Same. I’ve become maybe a quarter of the way familiar with the code base now after the weekend, and learning a ton just from working with it. Clean and easy to find what’s happening where.

I think it may be tough to nail down every LoRA problem given there’s no standard way to make one. Doing the conversion from diffusers was a good choice to take care of the 80%.

fblissjr commented 1 week ago

@fblissjr I'm having a lot more issues with applying LoRAs, even after applying your changes.

I've got a version working locally that is applying the LoRA weights properly, but I need to test it further.

I'll update and make a PR when I have it fully tested and working, fingers crossed for soon.

Also, I'm down to collaborate; my next steps are to get ControlNet and inpainting working.

Out of curiosity, what did you use to train it?

aredden commented 1 week ago

Yeah, it is tough. I know there will probably be issues with the attention block LoRAs if not all of q/k/v are LoRA'd, since the official Flux repo and this repo fuse q/k/v, so it's hard to apply correctly. I could fix the issue with some time, but it might be a little while before I get around to it, as this is just a small side project for me and I have a full-time job I have to dedicate a lot of time to, heh.
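
For illustration, a rough sketch of that fused-qkv situation, assuming the fused weight is the row-wise concatenation of q, k and v (the function and variable names here are made up, not the repo's):

import torch

def fused_qkv_lora_delta(lora_pairs, hidden_dim, scale=1.0):
    # lora_pairs: optional "q"/"k"/"v" entries of (lora_A, lora_B) tensors,
    # lora_A of shape (rank, hidden_dim), lora_B of shape (hidden_dim, rank).
    # Projections the LoRA never trained contribute a zero delta, so the result
    # always matches the fused qkv weight shape of (3 * hidden_dim, hidden_dim).
    deltas = []
    for name in ("q", "k", "v"):
        pair = lora_pairs.get(name)
        if pair is None:
            deltas.append(torch.zeros(hidden_dim, hidden_dim))
        else:
            lora_A, lora_B = pair
            deltas.append(scale * (lora_B @ lora_A))
    return torch.cat(deltas, dim=0)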

aredden commented 1 week ago

@aredden - If you're up for making this a pip installable library, I've been creating a bunch of fun toy examples that were really only possible in ComfyUI (at the same quality and speed) until your project. Also adding the ability to capture text embeddings and intermediate embeddings.

Let me know if you're up for collaborating more on these types of things, or if it makes sense to keep it all separate.

Can't say enough good things about this though.

Oh! Yeah that could be cool. If you wanted to make a pull with your changes I could try it out. @fblissjr - I'm definitely up for working on things.

0xtempest commented 1 week ago

I have LoRAs trained with XLabs working, although I've only tested 4-5 at this point (but they work very well). The only issue is that the LoRA scale is extremely sensitive.

I'll work on fixing that

Here are some examples from their furry LoRA (lol):

[Screenshot 2024-09-10 at 4:49:57 AM]

[Screenshot 2024-09-10 at 4:47:42 AM]

I'm getting 6-7.8 it/s on a 4090 as well (depending on size)

0xtempest commented 1 week ago

My local codebase is a bit of a mess right now; I had to make some fairly big-ish changes to lora_loading and other components of the pipeline. Probably good to start the PRs early next week.

0xtempest commented 1 week ago

A quick and dirty approach may be to just have different pipelines for various popular LoRAs.

fblissjr commented 1 week ago

Yeah, it is tough. I know there will probably be issues with the attention block LoRAs if not all of q/k/v are LoRA'd, since the official Flux repo and this repo fuse q/k/v, so it's hard to apply correctly. I could fix the issue with some time, but it might be a little while before I get around to it, as this is just a small side project for me and I have a full-time job I have to dedicate a lot of time to, heh.

This is one of the bigger variables I've noticed; thanks for calling it out specifically. Whatever dtype each part of your LoRA was trained at (and if you're using a service or spaghetti code, you may not even know), it needs to match pretty closely at inference.

If I can find the time to add more granular data type settings for each q/k/v of each block, I'll add it to the branch and PR it.

This is probably what most people are seeing issues with.

Also: whether you use a pre-quantized fp8 version of the Flux model or runtime quantization of the original, that will also need to match.
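
To make the dtype point concrete, here's a minimal sketch of fusing a LoRA delta into a base weight (an assumed helper, not this repo's code): compute in float32 regardless of what the LoRA was trained in, then cast back to whatever dtype the base weight is stored as:

import torch

@torch.no_grad()
def fuse_lora_delta(weight, lora_A, lora_B, scale=1.0):
    # Upcast both sides to float32 so a bf16/fp16-trained LoRA and the base
    # weight meet in a common precision, then cast back to the base dtype.
    # (An fp8-quantized weight would additionally need re-quantization with its scale.)
    delta = (lora_B.float() @ lora_A.float()) * scale
    return (weight.float() + delta).to(weight.dtype)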

fblissjr commented 1 week ago

I have LoRAs trained with XLabs working, although I've only tested 4-5 at this point (but they work very well). The only issue is that the LoRA scale is extremely sensitive.

I'll work on fixing that

Here are some examples from their furry LoRA (lol):

[Screenshot 2024-09-10 at 4:49:57 AM]

[Screenshot 2024-09-10 at 4:47:42 AM]

I'm getting 6-7.8 it/s on a 4090 as well (depending on size)

From what I've briefly read (haven't dug into XLabs much), they use a very different process. Not that there's any sort of standard, but it's the less standard way. :)

aredden commented 1 week ago

One thing I was having trouble with was DoRAs. I have no idea how to use those. I tried looking at some other codebases to figure out how they handle them, but when there's no 'dora scale' nor alpha, it's hard to figure out what should be happening.

fblissjr commented 1 week ago

One thing I was having trouble with was DoRAs. I have no idea how to use those. I tried looking at some other codebases to figure out how they handle them, but when there's no 'dora scale' nor alpha, it's hard to figure out what should be happening.

There's even less of a standard on that right now. Not sure I'd start with DoRA. The problem with the lack of alpha comes from diffusers, which is what ai-toolkit (which uses kohya scripts) uses.
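
For reference, the usual convention as I understand it (may not match every trainer): kohya-style files carry an alpha per module and the effective multiplier is alpha / rank, while diffusers-format files often omit alpha, so loaders end up assuming alpha == rank, i.e. a multiplier of 1.0. A tiny sketch:

def effective_lora_scale(lora_A, alpha=None, user_scale=1.0):
    # rank is the inner dimension of the low-rank pair: lora_A has shape (rank, in_features)
    rank = lora_A.shape[0]
    # no stored alpha (common for diffusers-format LoRAs) -> assume alpha == rank,
    # which makes the intrinsic multiplier exactly 1.0
    if alpha is None:
        alpha = rank
    return user_scale * (alpha / rank)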

I think some of the best work figuring this all out has come from kijai's comfy node: https://github.com/kijai/ComfyUI-FluxTrainer

And very excited to get a chance to dig into mflux for MLX: https://github.com/filipstrand/mflux

fblissjr commented 6 days ago

@aredden do you want me to make a PR for the lora loading blocks that are null?

also feel 100% free to grab anything from my branch and modify as you like if it helps.

0xtempest commented 13 hours ago

I made a new server if you guys wanna sync on Discord:

https://discord.gg/BYJxEUE3Rx

Could be fun to sync on cool new features and bounce ideas around for implementation details.