huggingface / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
https://huggingface.co/docs/diffusers
Apache License 2.0

Support dynamic LoRA loading with `torch.compile` model #9279

Open apolinario opened 2 weeks ago

apolinario commented 2 weeks ago

Is your feature request related to a problem? Please describe. It would be great to be able to load a LoRA into a model compiled with torch.compile.

Describe the solution you'd like. Being able to call load_lora_weights() on a compiled pipe (ideally without triggering recompilation).

Currently, running this code:

import torch
from diffusers import DiffusionPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe = pipe.to(device)
pipe.transformer = pipe.transformer.to(memory_format=torch.channels_last)
pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead")

pipe.load_lora_weights("multimodalart/flux-tarot-v1")

It errors:

Loading adapter weights from state_dict led to unexpected keys not found in the model:  ['single_transformer_blocks.0.attn.to_k.lora_A.default_3.weight', 'single_transformer_blocks.0.attn.to_k.lora_B.default_3.weight', 'single_transformer_blocks.0.attn.to_q.lora_A.default_3.weight', 'single_transformer_blocks.0.attn.to_q.lora_B.default_3.weight',

When compiled, the state dict of the model seems to add an _orig_mod prefix to all keys:

odict_keys(['_orig_mod.time_text_embed.timestep_embedder.linear_1.weight',...
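
For reference, a minimal sketch (assuming `pipe.transformer` has already been wrapped by `torch.compile`) of how the prefixed keys could be remapped back to their original names:

```python
# Sketch: the compiled wrapper (OptimizedModule) prefixes every state-dict key
# with "_orig_mod.", so a remapped copy can be built by stripping that prefix.
compiled_sd = pipe.transformer.state_dict()
remapped_sd = {k.removeprefix("_orig_mod."): v for k, v in compiled_sd.items()}
print(next(iter(remapped_sd)))  # e.g. time_text_embed.timestep_embedder.linear_1.weight
```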

Describe alternatives you've considered. An alternative is to fuse the LoRA into the model and then compile; however, this does not allow for hot-swapping LoRAs (a new pipeline and a new compilation are needed for every LoRA).

Additional context. This seems to have been achieved by @chengzeyi , author of the now-paused https://github.com/chengzeyi/stable-fast , but it appears to be part of FAL's non-open-source optimized inference (if you'd like to contribute this upstream, feel free!)

apolinario commented 2 weeks ago

Some references:

sayakpaul commented 2 weeks ago

@apolinario thanks for the detailed thread. It would also be nice to include a small snippet that we could quickly verify. And yes, ccing @BenjaminBossan for his comments here.

I assume we could just add checks in load_lora_weights() just before we pass the adapter weights to the underlying model: https://github.com/huggingface/diffusers/blob/4cfb2164fb05d54dd594373b4bd1fbb101fef70c/src/diffusers/loaders/lora_pipeline.py#L1218

BenjaminBossan commented 2 weeks ago

I have not dealt with hot-swapping LoRA weights on compiled models in PEFT, so I'm not surprised that it doesn't work out of the box. The PEFT experiments with compiled models all apply the compilation step after loading the PEFT weights. Maybe it's as easy as remapping the state dict and then calling set_peft_model_state_dict, but I wouldn't be surprised if there are more pitfalls related to torch.compile implementation details.

I'll make a note to look into this on the PEFT side when I have a bit of extra time. But do let me know if you make any progress.
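
For what it's worth, a rough sketch of the remap-then-set idea (hypothetical and untested; `peft_state_dict` is an assumed variable that already uses PEFT's key format, and the LoRA layers are assumed to be injected already):

```python
from peft import set_peft_model_state_dict

# Hypothetical sketch: load into the underlying module rather than the compiled
# wrapper, so the "_orig_mod." prefix never enters the picture. There may well
# be more torch.compile pitfalls beyond the key names.
unwrapped = getattr(pipe.transformer, "_orig_mod", pipe.transformer)
load_result = set_peft_model_state_dict(unwrapped, peft_state_dict, adapter_name="default")
print(load_result.unexpected_keys)  # should be empty if the remapping is right
```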

apolinario commented 2 weeks ago

It would also be nice to include a small snippet that we could quickly verify

Added!

sayakpaul commented 2 weeks ago

Applying this change seems to work: https://github.com/huggingface/diffusers/commit/0e7204abcdd1418645e85c541c15da0dfbbbd410

However, I get the following when doing torch.compile() with mode="max-autotune", fullgraph=True:

skipping cudagraphs due to skipping cudagraphs due to cpu device (arg25_1). Found from : 
   File "/fsx/sayak/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 494, in forward
    encoder_hidden_states, hidden_states = block(
  File "/fsx/sayak/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 165, in forward
    norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
  File "/fsx/sayak/diffusers/src/diffusers/models/normalization.py", line 137, in forward
    emb = self.linear(self.silu(emb))
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/peft/tuners/lora/layer.py", line 509, in forward
    result = result + lora_B(lora_A(dropout(x))) * scaling
apolinario commented 2 weeks ago

Cool, https://github.com/huggingface/diffusers/commit/0e7204abcdd1418645e85c541c15da0dfbbbd410 works here too for reduce-overhead

However it triggers a recompilation when you load/swap a LoRA:

import torch
from diffusers import DiffusionPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"

pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe = pipe.to(device)
pipe.transformer = pipe.transformer.to(memory_format=torch.channels_last)
pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead")

prompt = "a photo of an astronaut riding a horse on mars"
#This will compile for the first time
image = pipe(prompt).images[0]

pipe.load_lora_weights("multimodalart/flux-tarot-v1")
prompt = "a photo of an astronaut riding a horse on mars, tarot card"
#This will re-compile
image = pipe(prompt).images[0]

pipe.unload_lora_weights()
pipe.load_lora_weights("davisbro/half_illustration")
prompt = "In the style of TOK, a photo of an astronaut riding a horse on mars"
#This will re-compile again
image = pipe(prompt).images[0]
sayakpaul commented 2 weeks ago

Oh okay. Let's perhaps discuss this with the PyTorch team.

anijain2305 commented 2 weeks ago

Recompiles are coming from this

V0827 10:44:01.036000 3046403 torch/_dynamo/guards.py:2796] [0/1] [__recompiles]     triggered by the following guard failure(s):
V0827 10:44:01.036000 3046403 torch/_dynamo/guards.py:2796] [0/1] [__recompiles]     - 0/0: ___check_type_id(L['self']._modules['transformer_blocks']._modules['0']._modules['norm1']._modules['linear'], 81088688)  # emb = self.linear(self.silu(emb))  # diffusers/src/diffusers/models/normalization.py:137 in forward

On the first compile, TorchDynamo guarded on the id of the linear layer's type (torch.nn.Linear). But loading weights has somehow changed the type of the linear layer (or at least the id of its type). Do you know if loading the linear's weights changes its class type (or if we are dynamically creating a new class)?

sayakpaul commented 2 weeks ago

Yeah it does.

When we're hitting https://github.com/huggingface/diffusers/blob/4cfb2164fb05d54dd594373b4bd1fbb101fef70c/src/diffusers/loaders/lora_pipeline.py#L1217

The codepath in peft is

https://github.com/huggingface/peft/blob/850eeb5c3a5cf692f5612c7c733b13fde184e05d/src/peft/mapping.py#L223

To get around the problem, we could likely just fuse the adapter modules into the respective base model modules to prevent recompilation, but this is a bit restrictive in terms of a smoother UX. Perhaps @apolinario can explain this a bit better. So, we wanted to see if there's any alternative we could try here.

Cc: @BenjaminBossan
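
A quick way to see the class swap (illustrative; the module path is one of the layers this LoRA targets, taken from the guard log above):

```python
# Illustrative check: PEFT replaces the targeted nn.Linear with a LoRA layer,
# so type(module), and hence id(type(module)), changes after loading.
base = getattr(pipe.transformer, "_orig_mod", pipe.transformer)
print(type(base.transformer_blocks[0].norm1.linear))  # torch.nn.Linear
pipe.load_lora_weights("multimodalart/flux-tarot-v1")
print(type(base.transformer_blocks[0].norm1.linear))  # peft lora.Linear (expected)
```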

anijain2305 commented 2 weeks ago

No easy alternative comes to mind right now. torch.compile adds a guard on the type. In this case, the type change seems benign, but in general a type change requires a recompilation.

Is it possible to somehow ensure that the type of the class is the same as before once all the weights have been loaded? If we can do that, there won't be any recompiles. A bad way would be to monkeypatch the __class__ itself back to the old type.

Chillee commented 2 weeks ago

If you're loading a LoRA, it's very reasonable to have a recompilation, no? The actual operations are different. The question here is whether it needs a recompilation upon swapping in a new LoRA.

The other thing is that if you call model.compile() (as opposed to torch.compile(model)), the state dict won't be modified.
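
For context, a small illustration of that difference (a sketch; nn.Module.compile() compiles in place, so the module is not wrapped in an OptimizedModule and the state-dict keys keep their original names):

```python
import torch
import torch.nn as nn

m1 = nn.Linear(4, 4)
m1_compiled = torch.compile(m1)              # returns an OptimizedModule wrapper
print(next(iter(m1_compiled.state_dict())))  # '_orig_mod.weight'

m2 = nn.Linear(4, 4)
m2.compile()                                 # compiles in place, no wrapper
print(next(iter(m2.state_dict())))           # 'weight'
```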

apolinario commented 2 weeks ago

If you're loading a LoRA, it's very reasonable to have a recompilation, no? The actual operations are different. The question here is whether it needs a recompilation upon swapping in a new LoRA.

Agreed. I think it's probably fine for loading the first LoRA to trigger recompilation. However, hot-swapping different LoRAs without recompiling would be very valuable: it would let dynamic apps that change LoRAs benefit from compilation performance (waiting for a recompile every time a LoRA is swapped would not work for a live-swapping application).

anijain2305 commented 2 weeks ago

I see. In that case, can someone run with TORCH_LOGS="guards,recompiles" and share the log? For some reason, my run fails at


pipe.unload_lora_weights()
pipe.load_lora_weights("davisbro/half_illustration")
sayakpaul commented 2 weeks ago

I am getting you the log but recompilation upon swapping or loading a new LoRA also seems reasonable to me:

Both of these combined should lead to different operations, I'd imagine.

sayakpaul commented 2 weeks ago

@anijain2305 here you go: https://huggingface.co/datasets/sayakpaul/torchao-diffusers/blob/main/traces/regular_lora_compile.txt

I ran @apolinario's code here from the lora-compile branch of diffusers. I am on torch 2.5.0.dev20240827+cu121. nvidia-smi:

+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.12             Driver Version: 535.104.12   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA H100 80GB HBM3          On  | 00000000:97:00.0 Off |                    0 |
| N/A   36C    P0              68W / 700W |      2MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|  No running processes found                                                           |
+---------------------------------------------------------------------------------------+

LMK if you need more information.

BenjaminBossan commented 2 weeks ago

Note that when PEFT loads the 2nd LoRA adapter, no layers should be swapped. Instead, PEFT detects that LoRA layers are already in place and updates the nn.ModuleDict of those LoRA layers to contain the newly loaded weights. However, the decomposed A and B LoRA weights are implemented as nn.Linear layers, so each newly loaded LoRA adds new layers to the nn.ModuleDict, not just new weights. I assume this is what trips the guard.

If this is indeed the cause, the whole loading procedure probably needs to be rewritten to directly overwrite the weight.data of the first LoRA weight, without going through PEFT, in order to avoid what I just mentioned.
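
A rough sketch of that direction (hypothetical; `hotswap_lora_` and `new_lora_sd` are made-up names, `new_lora_sd` is assumed to already use PEFT-style parameter names for the existing adapter, and rank/alpha mismatches between checkpoints are ignored):

```python
import torch

# Hypothetical hot-swap sketch: copy the new checkpoint's tensors into the
# already-existing lora_A / lora_B parameters instead of registering a second
# adapter, so no new modules or dict entries are created.
def hotswap_lora_(transformer: torch.nn.Module, new_lora_sd: dict) -> None:
    params = dict(transformer.named_parameters())
    with torch.no_grad():
        for name, tensor in new_lora_sd.items():
            if name in params:
                params[name].copy_(tensor)
            # keys without a match (e.g. a different rank) would need real handling
```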

anijain2305 commented 2 weeks ago

Thanks @sayakpaul

These are the recompilation reasons

 [0/2] [__recompiles]     triggered by the following guard failure(s):
 [0/2] [__recompiles]     - 0/1: len(L['self']._modules['transformer_blocks']._modules['0']._modules['norm1']._modules['linear'].scaling) == 1  # scaling = self.scaling[active_adapter]  # peft/tuners/lora/layer.py:505 in forward
 [0/2] [__recompiles]     - 0/0: ___check_type_id(L['self']._modules['transformer_blocks']._modules['0']._modules['norm1']._modules['linear'], 97167728)  # emb = self.linear(self.silu(emb))  # diffusers/src/diffusers/models/normalization.py:137 in forward

0/0 is for the first recompile. As we have been discussing, this is expected.

0/1 is the second recompile. And this is happening because the length of _modules['linear'].scaling has increased. Is scaling supposed to change?

More information

It seems scaling is a dictionary. On the first recompilation, its length is 1 and its key is default_0. On the second recompilation, a default_1 key has been added. From the guards, it seems default_1 is not used, because there is no guard on that key's value. But given how Dynamo handles dicts, we still guard on the length of the dictionary.

First recompilation

 | +- GuardManager: source=L['self']._modules['transformer_blocks']._modules['0']._modules['norm1']._modules['linear'].scaling, accessed_by=DictGetItemGuardAccessor(scaling)
 | | +- DICT_LENGTH: len(L['self']._modules['transformer_blocks']._modules['0']._modules['norm1']._modules['linear'].scaling) == 1  # scaling = self.scaling[active_adapter]  # peft/tuners/lora/layer.py:505 in forward
 | | +- GuardManager: source=L['self']._modules['transformer_blocks']._modules['0']._modules['norm1']._modules['linear'].scaling['default_0'], accessed_by=DictGetItemGuardAccessor(default_0)
 | | | +- EQUALS_MATCH: L['self']._modules['transformer_blocks']._modules['0']._modules['norm1']._modules['linear'].scaling['default_0'] == 1.0  # scaling = self.scaling[active_adapter]  # peft/tuners/lora/layer.py:505 in forward

Second recompilation - There is this default_1 key. It does not seem like it's used in the model.

 | +- GuardManager: source=L['self']._modules['single_transformer_blocks']._modules['37']._modules['norm']._modules['linear'].scaling, accessed_by=DictGetItemGuardAccessor(scaling)
 | | +- DICT_LENGTH: len(L['self']._modules['single_transformer_blocks']._modules['37']._modules['norm']._modules['linear'].scaling) == 2  # scaling = self.scaling[active_adapter]  # peft/tuners/lora/layer.py:505 in forward
 | | +- GuardManager: source=L['self']._modules['single_transformer_blocks']._modules['37']._modules['norm']._modules['linear'].scaling['default_0'], accessed_by=DictGetItemGuardAccessor(default_0)
 | | | +- EQUALS_MATCH: L['self']._modules['single_transformer_blocks']._modules['37']._modules['norm']._modules['linear'].scaling['default_0'] == 1.0  # scaling = self.scaling[active_adapter]  # peft/tuners/lora/layer.py:505 in forward
 | | +- GuardManager: source=L['self']._modules['single_transformer_blocks']._modules['37']._modules['norm']._modules['linear'].scaling['default_1'], accessed_by=DictGetItemGuardAccessor(default_1)

Is it possible to keep the dictionary the same on re-loading? I can investigate whether we can guard on the dict keys more lazily (i.e., only guard on the keys/values that are actually used in the model, eliminating the length guard), but it seems a little hard to do on the surface.

sayakpaul commented 2 weeks ago

@anijain2305, thanks!

0/1 is the second recompile. And this is happening because the length of _modules['linear'].scaling has increased. Is scaling supposed to change?

Yes, scaling can change depending on the alpha and rank values associated with a given LoRA checkpoint (PEFT computes each adapter's scaling as lora_alpha / r by default, so a different alpha or rank gives a different value).

Second recompilation - There is this default_1 key. It does not seem like it's used in the model.

Well, the default_1 key is the name of the adapter. All the parameters associated with that adapter will have that key. You can verify this by doing:

...

pipe.unload_lora_weights()
pipe.load_lora_weights("davisbro/half_illustration")
print(pipe.transformer.state_dict())

Is it possible to keep the dictionary the same on re-loading?

So, I am afraid this won't be possible because on a reload we're essentially adding things to the existing state dict of the modules affected by a given LoRA checkpoint.

Ccing @BenjaminBossan if I missed something.

BenjaminBossan commented 2 weeks ago

It seems scaling is a dictionary.

Out of curiosity, would this be different if scaling were a ModuleDict?

Is it possible to keep the dictionary the same on re-loading?

I think there could be a way to replace the data from the first LoRA adapter directly with the second, instead of updating the dicts to add a separate adapter. To try this, I wanted to pass the same adapter_name when calling load_lora_weights but this runs into a guard here:

https://github.com/huggingface/diffusers/blob/4f495b06dcbbc3437a598a20718fe74c29308756/src/diffusers/loaders/lora_pipeline.py#L1721-L1724

After removing the guard, it appears that I could load the second adapter without increasing the size of the dicts. However, I'm not sure if this prevents recompilation.
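
Concretely, the experiment was roughly this (a sketch; the adapter name "shared" is arbitrary, and it only works with the duplicate-adapter guard linked above removed locally):

```python
# Both checkpoints are loaded under the same adapter name, so the scaling dict
# and the LoRA ModuleDicts do not grow.
pipe.load_lora_weights("multimodalart/flux-tarot-v1", adapter_name="shared")
# ... run inference, then swap in the second checkpoint under the same name ...
pipe.load_lora_weights("davisbro/half_illustration", adapter_name="shared")
```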

BenjaminBossan commented 2 weeks ago

Looking more into this, I think that unloading for Flux models does not work correctly. Specifically, the (edit: compiled) FluxTransformer2DModel does not match in these lines:

https://github.com/huggingface/diffusers/blob/e417d028115e72b953a73e39d9687aa70ba3e37e/src/diffusers/loaders/lora_base.py#L371-L377

Therefore, it is kept as is. Maybe it would be better to check if hasattr(model, "unload_lora")?

However, I don't think that fixing this would solve the initial issue. If the LoRA layers are completely unloaded, it means they're removed and the second adapter will create completely new LoRA layers, which I guess would always trigger a recompilation. Maybe it's better to just offload the first adapter?

sayakpaul commented 2 weeks ago

Looking more into this, I think that unloading for Flux models does not work correctly. Specifically, the FluxTransformer2DModel does not match in these lines:

Could you explain more? FluxTransformer2DModel is a subclass of PeftAdapterMixin: https://github.com/huggingface/diffusers/blob/e417d028115e72b953a73e39d9687aa70ba3e37e/src/diffusers/models/transformers/transformer_flux.py#L206

And PeftAdapterMixin has: https://github.com/huggingface/diffusers/blob/e417d028115e72b953a73e39d9687aa70ba3e37e/src/diffusers/loaders/peft.py#L305

BenjaminBossan commented 2 weeks ago

It appears that the type check fails when the model is compiled. When I run this code:

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.transformer = pipe.transformer.to(memory_format=torch.channels_last)
pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead")

pipe.load_lora_weights("multimodalart/flux-tarot-v1", adapter_name="foobar")
lora_ids_0 = {name: id(module) for name, module in pipe.transformer.named_modules()}
pipe.unload_lora_weights()

The LoRA weights of the pipe.transformer are not unloaded. Checking in the debugger, unloading is not called on the transformer part.

When I remove the compilation step, the unloading works. Using if hasattr(model, "unload_lora") should still fix it. Alternatively, there needs to be a check for compiled models and then _orig_mod should be used. I could imagine that many more isinstance checks could be faulty with compiled models :-/
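
A sketch of what such a check could look like (illustrative only, not the actual diffusers code; it just unwraps the compiled wrapper before the existing checks):

```python
# Illustrative sketch: unwrap torch.compile's OptimizedModule before the
# isinstance / unload checks, so compiled components are not silently skipped.
def _unload_lora_from(model) -> None:
    unwrapped = getattr(model, "_orig_mod", model)  # no-op for uncompiled models
    if hasattr(unwrapped, "unload_lora"):
        unwrapped.unload_lora()
```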

sayakpaul commented 1 week ago

Sorry missed it.

When I remove the compilation step, the unloading works. Using if hasattr(model, "unload_lora") should still fix it.

Thanks for this info. But I don't understand why we would still need the hasattr fix for the non-compiled model, or am I missing something?

edit: you meant for compiled models.

Alternatively, there needs to be a check for compiled models and then _orig_mod should be used. I could imagine that many more isinstance checks could be faulty with compiled models :-/

Agreed. I can follow code trails in diffusers. Would you be able to check it for peft?

BenjaminBossan commented 1 week ago

Agreed. I can follow code trails in diffusers. Would you be able to check it for peft?

Not sure if something needs to be done on the PEFT side, or are you aware of something? If this is addressed on the diffusers side, we can check whether anything has to change for PEFT too and fix it then.

sayakpaul commented 1 week ago

I was mainly referring to the isinstance checks used in both codebases and the cases where we may have to include a check for compiled models too (or rejig the condition, like the one you mentioned with hasattr).

BenjaminBossan commented 1 week ago

I think on the level of PEFT layers, which diffusers is using, we should be good. Maybe there are other parts of PEFT where isinstance checks could be invalid for compiled models, but those should not block this issue.

sayakpaul commented 1 week ago

Fair. I will keep this thread posted!

sayakpaul commented 4 days ago

I have been testing a few options.

My testbed is the lora-compile branch in diffusers. The major changes are around what @BenjaminBossan and I discussed in the comments above.

I am using this code:

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe = pipe.to("cuda")
pipe.transformer = pipe.transformer.to(memory_format=torch.channels_last)
pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead")

prompt = "a photo of an astronaut riding a horse on mars"
# This will compile for the first time
image = pipe(prompt, num_inference_steps=5).images[0]

pipe.load_lora_weights("multimodalart/flux-tarot-v1", adapter_name="first")
prompt = "a photo of an astronaut riding a horse on mars, tarot card"
# This will re-compile
image = pipe(prompt, num_inference_steps=5).images[0]

# pipe.unload_lora_weights()
pipe.set_lora_device(adapter_names=["first"], device="cpu")
pipe.load_lora_weights("davisbro/half_illustration", adapter_name="second")
prompt = "In the style of TOK, a photo of an astronaut riding a horse on mars"
# This will re-compile again
image = pipe(prompt, num_inference_steps=5).images[0]

# pipe.unload_lora_weights()
pipe.set_lora_device(adapter_names=["second"], device="cpu")
pipe.load_lora_weights("davisbro/half_illustration")
prompt = "In the style of TOK, a photo of an astronaut riding a horse on mars"
# This will re-compile again
image = pipe(prompt, num_inference_steps=5).images[0]
```

I have tried two options:

  1. Unloading the LoRA weights by calling pipe.unload_lora_weights(). I also checked that the LoRA weights were getting unloaded correctly. Logs.
  2. Moving the currently loaded LoRA to CPU by calling pipe.set_lora_device(...). Logs.

There are still recompiles. However, with the second option I see more recompiles than with the first.

@anijain2305 any comments?

Cc: @BenjaminBossan since we were discussing this.

anijain2305 commented 2 days ago

These are the guards and recompilation reasons - https://gist.github.com/anijain2305/9f3654e3a25b38446d572cfe2f9b7566

So, I think the recompiles can't be avoided easily because the codepath truly changes.

From a torch.compile standpoint, after each load we are accessing a different key in the scaling dict. And in general, accessing a different key can lead to a totally different graph, so Dynamo guards on that key-value pair.

A small example that reproduces the above scenario:


import torch

scaling = {}

def fn(x, key):
    return x * scaling[key]

opt_fn = torch.compile(fn, backend="eager")

x = torch.rand(4)

scaling["first"] = 1
opt_fn(x, "first")

scaling["second"] = 1
opt_fn(x, "second")

scaling["third"] = 1
opt_fn(x, "third")

torch.compile will guard on the key-value pair and cause a recompile every time, because the accessed key-value pair is different in each invocation of fn.

BenjaminBossan commented 22 hours ago

@anijain2305 I'm working on a hot-swapping method. It's just a quick-and-dirty implementation for now, but I would like to get some early feedback on whether it's worth putting more time into. For this, could you please check the trace attached below, which I captured using TORCH_LOGS="guards,recompiles" TORCH_LOGS_OUT=traces.txt? It looks like there is no re-compilation.

traces.txt

anijain2305 commented 16 hours ago

@BenjaminBossan It looks good to me. As long as the scaling dictionary remains the same, we should be good on the recompilation front.
