apolinario opened this issue 2 weeks ago
Some references:
_orig_mod prefix to all keys:
odict_keys(['_orig_mod.time_text_embed.timestep_embedder.linear_1.weight', '_orig_mod.time_text_embed.timestep_embedder.linear_1.bias'...
@apolinario thanks for the detailed thread. It would also be nice to include a small snippet that we could use to quickly verify this. And yes, ccing @BenjaminBossan for his comments here.
I assume we could just add checks in load_lora_weights() just before we pass the adapter weights to the underlying model:
https://github.com/huggingface/diffusers/blob/4cfb2164fb05d54dd594373b4bd1fbb101fef70c/src/diffusers/loaders/lora_pipeline.py#L1218
I have not dealt with hot-swapping LoRA weights on compiled models in PEFT, so I'm not surprised that it doesn't work out of the box. The PEFT experiments with compiled models all apply the compilation step after loading the PEFT weights. Maybe it's as easy as remapping the state dict and then calling set_peft_model_state_dict, but I wouldn't be surprised if there are more pitfalls related to torch.compile implementation details.
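A minimal sketch of the "remap the state dict" idea, assuming the only mismatch is the _orig_mod. prefix seen in the odict_keys output above. The helper names strip_compile_prefix and add_compile_prefix are hypothetical (not diffusers or PEFT APIs), plain strings stand in for tensors so the key handling can be checked in isolation, and the sketch does not call set_peft_model_state_dict itself.

```python
from typing import Any, Dict, Mapping

# Prefix that torch.compile's wrapper puts in front of every state-dict key.
COMPILED_PREFIX = "_orig_mod."


def strip_compile_prefix(state_dict: Mapping[str, Any]) -> Dict[str, Any]:
    """Return a copy of state_dict with a leading '_orig_mod.' removed from each key."""
    remapped: Dict[str, Any] = {}
    for key, value in state_dict.items():
        if key.startswith(COMPILED_PREFIX):
            key = key[len(COMPILED_PREFIX):]
        remapped[key] = value
    return remapped


def add_compile_prefix(state_dict: Mapping[str, Any]) -> Dict[str, Any]:
    """Return a copy of state_dict whose keys all carry the '_orig_mod.' prefix (added once)."""
    return {
        key if key.startswith(COMPILED_PREFIX) else COMPILED_PREFIX + key: value
        for key, value in state_dict.items()
    }


compiled_keys = {
    "_orig_mod.time_text_embed.timestep_embedder.linear_1.weight": "w",
    "_orig_mod.time_text_embed.timestep_embedder.linear_1.bias": "b",
}
plain = strip_compile_prefix(compiled_keys)
```

Which direction is needed depends on whether the LoRA keys are matched against the compiled wrapper or against the original module it wraps.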
I'll make a note to look into this on the PEFT side when I have a bit of extra time. But do let me know if you make any progress.
Would be also nice to include a small snippet that we could quickly verify
Added!
Applying this change seems to work: https://github.com/huggingface/diffusers/commit/0e7204abcdd1418645e85c541c15da0dfbbbd410
However, I get the following when running torch.compile() with mode="max-autotune", fullgraph=True:
skipping cudagraphs due to skipping cudagraphs due to cpu device (arg25_1). Found from :
File "/fsx/sayak/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 494, in forward
encoder_hidden_states, hidden_states = block(
File "/fsx/sayak/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 165, in forward
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
File "/fsx/sayak/diffusers/src/diffusers/models/normalization.py", line 137, in forward
emb = self.linear(self.silu(emb))
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/peft/tuners/lora/layer.py", line 509, in forward
result = result + lora_B(lora_A(dropout(x))) * scaling
Cool, https://github.com/huggingface/diffusers/commit/0e7204abcdd1418645e85c541c15da0dfbbbd410 works here too with mode="reduce-overhead".
However, it triggers a recompilation when you load/swap a LoRA:
import torch
from diffusers import DiffusionPipeline
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe = pipe.to(device)
pipe.transformer = pipe.transformer.to(memory_format=torch.channels_last)
pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead")
prompt = "a photo of an astronaut riding a horse on mars"
# This will compile for the first time
image = pipe(prompt).images[0]
pipe.load_lora_weights("multimodalart/flux-tarot-v1")
prompt = "a photo of an astronaut riding a horse on mars, tarot card"
# This will re-compile
image = pipe(prompt).images[0]
pipe.unload_lora_weights()
pipe.load_lora_weights("davisbro/half_illustration")
prompt = "In the style of TOK, a photo of an astronaut riding a horse on mars"
# This will re-compile again
image = pipe(prompt).images[0]
Oh okay. Let's perhaps discuss this with the PyTorch team.
The recompiles are coming from this:
V0827 10:44:01.036000 3046403 torch/_dynamo/guards.py:2796] [0/1] [__recompiles] triggered by the following guard failure(s):
V0827 10:44:01.036000 3046403 torch/_dynamo/guards.py:2796] [0/1] [__recompiles] - 0/0: ___check_type_id(L['self']._modules['transformer_blocks']._modules['0']._modules['norm1']._modules['linear'], 81088688) # emb = self.linear(self.silu(emb)) # diffusers/src/diffusers/models/normalization.py:137 in forward
On the first compile, TorchDynamo guarded on the type id of the linear layer (id(type(module)), i.e. the id of torch.nn.Linear). But loading weights has somehow changed the type of the linear layer (or at least the id of its class). Do you know if loading the weights changes the class of the linear layer (or if we are dynamically creating a new class)?
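To illustrate the guard that fails here, this is a pure-Python toy, not Dynamo's actual guard code: it records id(type(module)) the way ___check_type_id does, then re-checks it after the layer is replaced by a wrapper. Linear, LoraLinear, Block and type_id_guard are hypothetical stand-ins; structurally, PEFT does something similar when it replaces a targeted nn.Linear with its own LoRA layer class that keeps the original as base_layer.

```python
from typing import Callable


class Linear:
    """Toy stand-in for torch.nn.Linear (no tensors involved)."""

    def __init__(self, in_features: int, out_features: int) -> None:
        self.in_features = in_features
        self.out_features = out_features


class LoraLinear:
    """Hypothetical stand-in for a LoRA layer that wraps the original layer."""

    def __init__(self, base_layer: Linear) -> None:
        self.base_layer = base_layer


class Block:
    def __init__(self) -> None:
        self.linear = Linear(8, 8)


def type_id_guard(module: object) -> Callable[[object], bool]:
    """Capture id(type(module)) now; the returned check compares against it later."""
    expected = id(type(module))
    return lambda m: id(type(m)) == expected


block = Block()
guard = type_id_guard(block.linear)
passes_before = guard(block.linear)
block.linear = LoraLinear(block.linear)  # structurally what injecting a LoRA layer does
passes_after = guard(block.linear)
```

The attribute name is unchanged, but the object behind it now has a different class, which is exactly what a type-id guard detects.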
Yeah it does.
When we're hitting https://github.com/huggingface/diffusers/blob/4cfb2164fb05d54dd594373b4bd1fbb101fef70c/src/diffusers/loaders/lora_pipeline.py#L1217, the codepath in peft is:
To get around the problem, we could likely just fuse the adapter modules into the respective base model modules to prevent recompilation, but this is somewhat restrictive in terms of UX. Perhaps @apolinario can explain this a bit better. So, we wanted to see if there's any alternative we could try here.
Cc: @BenjaminBossan
No easy alternative I can think of as of now. torch.compile is adding a guard on the type. In this case, it seems that the type change is benign, but in general a type change requires a recompilation.
Is it possible to somehow ensure that the type of the class is the same as before after all the weights have been loaded? If we can do that, there won't be any recompiles. A bad way would be to monkeypatch __class__ itself back to the old type.
If you're loading a Lora, it's very reasonable to have a recompilation, no? The actual operations are different. The question here is whether it needs a recompilation upon swapping a new LoRA.
The other thing is that if you call model.compile() (as opposed to torch.compile(model)), the state dict won't be modified.
If you're loading a Lora, it's very reasonable to have a recompilation, no? The actual operations are different. The question here is whether it needs a recompilation upon swapping a new LoRA.
Agreed. I think it's probably fine for loading the first LoRA to trigger a recompilation. However, hot-swapping different LoRAs without recompiling would be very valuable: it would let dynamic apps that change LoRAs benefit from the compilation speedups (waiting for a recompile every time LoRAs are swapped rules out live-swapping applications).
I see. In that case, can someone run with TORCH_LOGS="guards,recompiles" and share the log? For some reason, my run fails at:
pipe.unload_lora_weights()
pipe.load_lora_weights("davisbro/half_illustration")
I'm getting the log for you, but recompilation upon swapping or loading a new LoRA also seems reasonable to me:
Both of these combined should lead to different operations, I'd imagine.
@anijain2305 here you go: https://huggingface.co/datasets/sayakpaul/torchao-diffusers/blob/main/traces/regular_lora_compile.txt
I ran @apolinario's code here from the lora-compile branch of diffusers. I am on torch 2.5.0.dev20240827+cu121. nvidia-smi:
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.12 Driver Version: 535.104.12 CUDA Version: 12.2 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 NVIDIA H100 80GB HBM3 On | 00000000:97:00.0 Off | 0 |
| N/A 36C P0 68W / 700W | 2MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
+---------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=======================================================================================|
| No running processes found |
+---------------------------------------------------------------------------------------+
LMK if you need more information.
Note that when PEFT loads the 2nd LoRA adapter, no layers should be swapped. Instead, PEFT detects that LoRA layers are already in place and updates the nn.ModuleDict of those LoRA layers to contain the newly loaded weights. However, the decomposed A and B LoRA weights are implemented as nn.Linear layers, so each newly loaded LoRA adds new layers to the nn.ModuleDict, not just new weights. I assume that this is what trips the guard.
If this is indeed the cause, the whole loading procedure probably needs to be rewritten to directly overwrite the weight.data of the first LoRA weight, without going through PEFT, in order to avoid what I just mentioned.
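A toy sketch of the difference between loading a second adapter and overwriting the first one in place, assuming per-adapter weights and scalings live in dicts keyed by adapter name, as in PEFT's LoRA layers. ToyLoraLayer and hotswap_adapter are hypothetical names, and Python lists stand in for tensors. With real parameters, the in-place step would be something like param.copy_(new_weight) inside torch.no_grad(), which requires the new weights to have the same shape (i.e. the same rank).

```python
from typing import Dict, List


class ToyLoraLayer:
    """Toy LoRA layer: weights and scalings stored per adapter name, like PEFT's dicts."""

    def __init__(self) -> None:
        self.lora_A: Dict[str, List[float]] = {}
        self.lora_B: Dict[str, List[float]] = {}
        self.scaling: Dict[str, float] = {}

    def add_adapter(self, name: str, A: List[float], B: List[float], scaling: float) -> None:
        # Loading under a new adapter name adds entries, so every dict grows by one.
        self.lora_A[name] = list(A)
        self.lora_B[name] = list(B)
        self.scaling[name] = scaling

    def hotswap_adapter(self, name: str, A: List[float], B: List[float], scaling: float) -> None:
        # Overwrite the existing storage in place; keys and dict sizes stay unchanged.
        if name not in self.lora_A:
            raise KeyError(f"adapter {name!r} must already be loaded")
        if len(A) != len(self.lora_A[name]) or len(B) != len(self.lora_B[name]):
            raise ValueError("hot-swapped weights must match the existing shapes")
        self.lora_A[name][:] = A  # slice assignment keeps the same list object
        self.lora_B[name][:] = B
        self.scaling[name] = scaling


layer = ToyLoraLayer()
layer.add_adapter("default_0", [1.0, 2.0], [3.0, 4.0], scaling=1.0)
storage = layer.lora_A["default_0"]
layer.hotswap_adapter("default_0", [5.0, 6.0], [7.0, 8.0], scaling=1.0)
size_after_hotswap = len(layer.scaling)
layer.add_adapter("default_1", [0.5, 0.5], [0.5, 0.5], scaling=1.0)
size_after_second_load = len(layer.scaling)
```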
Thanks @sayakpaul
These are the recompilation reasons:
[0/2] [__recompiles] triggered by the following guard failure(s):
[0/2] [__recompiles] - 0/1: len(L['self']._modules['transformer_blocks']._modules['0']._modules['norm1']._modules['linear'].scaling) == 1 # scaling = self.scaling[active_adapter] # peft/tuners/lora/layer.py:505 in forward
[0/2] [__recompiles] - 0/0: ___check_type_id(L['self']._modules['transformer_blocks']._modules['0']._modules['norm1']._modules['linear'], 97167728) # emb = self.linear(self.silu(emb)) # diffusers/src/diffusers/models/normalization.py:137 in forward
0/0 is for the first recompile. As we have been discussing, this is expected.
0/1 is the second recompile, and it happens because the length of _modules['linear'].scaling has increased. Is scaling supposed to change?
It seems scaling is a dictionary. On the first recompilation, its length is 1 and its only key is default_0. On the second recompilation, a default_1 key has been added. From the guards, it seems default_1 is not used, because there is no guard on the value for that key. But given how Dynamo handles dicts, we still guard on the length of the dictionary.
First recompilation:
| +- GuardManager: source=L['self']._modules['transformer_blocks']._modules['0']._modules['norm1']._modules['linear'].scaling, accessed_by=DictGetItemGuardAccessor(scaling)
| | +- DICT_LENGTH: len(L['self']._modules['transformer_blocks']._modules['0']._modules['norm1']._modules['linear'].scaling) == 1 # scaling = self.scaling[active_adapter] # peft/tuners/lora/layer.py:505 in forward
| | +- GuardManager: source=L['self']._modules['transformer_blocks']._modules['0']._modules['norm1']._modules['linear'].scaling['default_0'], accessed_by=DictGetItemGuardAccessor(default_0)
| | | +- EQUALS_MATCH: L['self']._modules['transformer_blocks']._modules['0']._modules['norm1']._modules['linear'].scaling['default_0'] == 1.0 # scaling = self.scaling[active_adapter] # peft/tuners/lora/layer.py:505 in forward
Second recompilation: there is this default_1 key. It does not seem like it's used in the model.
| +- GuardManager: source=L['self']._modules['single_transformer_blocks']._modules['37']._modules['norm']._modules['linear'].scaling, accessed_by=DictGetItemGuardAccessor(scaling)
| | +- DICT_LENGTH: len(L['self']._modules['single_transformer_blocks']._modules['37']._modules['norm']._modules['linear'].scaling) == 2 # scaling = self.scaling[active_adapter] # peft/tuners/lora/layer.py:505 in forward
| | +- GuardManager: source=L['self']._modules['single_transformer_blocks']._modules['37']._modules['norm']._modules['linear'].scaling['default_0'], accessed_by=DictGetItemGuardAccessor(default_0)
| | | +- EQUALS_MATCH: L['self']._modules['single_transformer_blocks']._modules['37']._modules['norm']._modules['linear'].scaling['default_0'] == 1.0 # scaling = self.scaling[active_adapter] # peft/tuners/lora/layer.py:505 in forward
| | +- GuardManager: source=L['self']._modules['single_transformer_blocks']._modules['37']._modules['norm']._modules['linear'].scaling['default_1'], accessed_by=DictGetItemGuardAccessor(default_1)
Is it possible to keep the dictionary the same on reloading? I can investigate whether we can guard on the dict keys more lazily (i.e. only guard on the keys/values that are actually used in the model, eliminating the length guard), but on the surface it seems a little hard to do.
@anijain2305, thanks!
0/1 is the second recompile. And this is happening because length of _modules['linear'].scaling has increased. Is scaling supposed to change?
Yes, scaling can change depending on the alpha and rank values associated with a given LoRA checkpoint.
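For reference, to my understanding PEFT sets each adapter's scaling to lora_alpha / r, or lora_alpha / sqrt(r) when use_rslora is enabled, so two checkpoints with different alpha or rank end up with different float values. Since the trace guards on that value with EQUALS_MATCH (== 1.0), a different value would by itself fail the guard. lora_scaling is a hypothetical helper written only to show the arithmetic.

```python
import math


def lora_scaling(lora_alpha: float, r: int, use_rslora: bool = False) -> float:
    """Per-adapter LoRA multiplier: alpha / r, or alpha / sqrt(r) for rank-stabilized LoRA."""
    if r <= 0:
        raise ValueError("rank r must be positive")
    return lora_alpha / math.sqrt(r) if use_rslora else lora_alpha / r


same = lora_scaling(16, 16)                     # 1.0, matches the "== 1.0" guard in the trace
different = lora_scaling(32, 16)                # 2.0, would fail that guard
rslora = lora_scaling(16, 4, use_rslora=True)   # 16 / sqrt(4) = 8.0
```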
Second recompilation - There is this default_1 key. It does not seem like its used in the model.
Well, default_1 is the name of the adapter. All the parameters associated with that adapter will have it in their keys. You can verify this by doing:
...
pipe.unload_lora_weights()
pipe.load_lora_weights("davisbro/half_illustration")
print(pipe.transformer.state_dict().keys())
Is it possible to keep the dictionary same on re-loading?
So, I am afraid this won't be possible because on a reload we're essentially adding things to the existing state dict of the modules affected by a given LoRA checkpoint.
Ccing @BenjaminBossan if I missed something.
It seems scaling is a dictionary.
Out of curiosity, would this be different if scaling were a ModuleDict?
Is it possible to keep the dictionary same on re-loading?
I think there could be a way to replace the data from the first LoRA adapter directly with the second, instead of updating the dicts to add a separate adapter. To try this, I wanted to pass the same adapter_name when calling load_lora_weights, but this runs into a guard here:
After removing the guard, it appears that I could load the second adapter without increasing the size of the dicts. However, I'm not sure if this prevents recompilation.
Looking more into this, I think that unloading for Flux models does not work correctly. Specifically, the (edit: compiled) FluxTransformer2DModel does not match in these lines:
Therefore, it is kept as is. Maybe it would be better to check hasattr(model, "unload_lora")?
However, I don't think that fixing this would solve the initial issue. If the LoRA layers are completely unloaded, it means they're removed and the second adapter will create completely new LoRA layers, which I guess would always trigger a recompilation. Maybe it's better to just offload the first adapter?
Looking more into this, I think that unloading for Flux models does not work correctly. Specifically, the FluxTransformer2DModel does not match in these lines:
Could you explain more? FluxTransformer2DModel is a subclass of PeftAdapterMixin:
https://github.com/huggingface/diffusers/blob/e417d028115e72b953a73e39d9687aa70ba3e37e/src/diffusers/models/transformers/transformer_flux.py#L206
And PeftAdapterMixin has:
https://github.com/huggingface/diffusers/blob/e417d028115e72b953a73e39d9687aa70ba3e37e/src/diffusers/loaders/peft.py#L305
It appears that the type check fails when the model is compiled. When I run this code:
import torch
from diffusers import DiffusionPipeline
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.transformer = pipe.transformer.to(memory_format=torch.channels_last)
pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead")
pipe.load_lora_weights("multimodalart/flux-tarot-v1", adapter_name="foobar")
lora_ids_0 = {name: id(module) for name, module in pipe.transformer.named_modules()}
pipe.unload_lora_weights()
The LoRA weights of pipe.transformer are not unloaded. Checking in the debugger, unloading is not called on the transformer part.
When I remove the compilation step, the unloading works. Using if hasattr(model, "unload_lora") should still fix it. Alternatively, there needs to be a check for compiled models, and then _orig_mod should be used. I could imagine that many more isinstance checks could be faulty with compiled models :-/
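A toy sketch of the two fixes being discussed, assuming the compiled wrapper behaves like torch.compile's OptimizedModule: it stores the original module as _orig_mod (hence the key prefix) and, to my understanding, forwards attribute lookups it cannot resolve to that module. That is why isinstance fails on the wrapper while hasattr still succeeds. All class names and unwrap_compiled are hypothetical stand-ins, not diffusers or PEFT code.

```python
class PeftAdapterMixinStandIn:
    """Hypothetical stand-in for a mixin that provides unload_lora()."""

    def unload_lora(self) -> None:
        self.lora_loaded = False


class ToyTransformer(PeftAdapterMixinStandIn):
    def __init__(self) -> None:
        self.lora_loaded = True


class ToyOptimizedModule:
    """Toy stand-in for the torch.compile wrapper: keeps the original as _orig_mod
    and forwards attribute lookups it cannot resolve itself."""

    def __init__(self, mod: object) -> None:
        self._orig_mod = mod

    def __getattr__(self, name: str):
        # Only reached when normal lookup fails; avoid recursion before __init__ runs.
        if name == "_orig_mod":
            raise AttributeError(name)
        return getattr(self._orig_mod, name)


def unwrap_compiled(model: object) -> object:
    """Return the wrapped module for a compiled model, or the model itself otherwise."""
    return getattr(model, "_orig_mod", model)


compiled = ToyOptimizedModule(ToyTransformer())
isinstance_on_wrapper = isinstance(compiled, PeftAdapterMixinStandIn)
hasattr_on_wrapper = hasattr(compiled, "unload_lora")
isinstance_on_unwrapped = isinstance(unwrap_compiled(compiled), PeftAdapterMixinStandIn)
unwrap_compiled(compiled).unload_lora()
```

Unwrapping first keeps the existing isinstance checks meaningful, while the hasattr variant avoids depending on the wrapper's internals.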
Sorry, I missed that.
When I remove the compilation step, the unloading works. Using if hasattr(model, "unload_lora") should still fix it.
Thanks for this info. But I don't understand why we would still need the hasattr fix for the non-compiled model, or am I missing something?
edit: you meant for compiled models.
Alternatively, there needs to be a check for compiled models and then _orig_mod should be used. I could imagine that many more isinstance checks could be faulty with compiled models :-/
Agreed. I can follow code trails in diffusers. Would you be able to check it for peft?
Agreed. I can follow code trails in diffusers. Would you be able to check it for peft?
Not sure if something needs to be done on the PEFT side, or are you aware of something? If this is addressed on the diffusers side, we can check whether anything has to change in PEFT too and fix it then.
I was mainly referring to the isinstance checks used in both codebases and the cases where we may have to include a check for compiled models too (or rejig the condition like the one you mentioned with hasattr).
I think on the level of PEFT layers, which diffusers is using, we should be good. Maybe there are other parts of PEFT where isinstance checks could be invalid for compiled models, but those should not block this issue.
Fair. I will keep this thread posted!
I have been testing a few options.
My testbed is the lora-compile branch in diffusers. The major changes are around what @BenjaminBossan and I discussed in the comments above.
I am using this codebase:
I have tried two options:
1. pipe.unload_lora_weights(). I also checked that the LoRA weights were getting unloaded correctly. Logs.
2. pipe.set_lora_device(...). Logs.
There are still recompiles in both cases. However, with the second option I see more recompiles than with the first.
@anijain2305 any comments?
Cc: @BenjaminBossan since we were discussing this.
These are the guards and recompilation reasons - https://gist.github.com/anijain2305/9f3654e3a25b38446d572cfe2f9b7566
So, I think the recompiles can't be avoided easily because the codepath truly changes.
From the torch.compile standpoint, after each load we are accessing a different key in the scaling dict. In general, accessing a different key can lead to a totally different graph, so Dynamo guards on that key-value pair.
A small example that reproduces the above scenario:
import torch
scaling = {}
def fn(x, key):
    return x * scaling[key]
opt_fn = torch.compile(fn, backend="eager")
x = torch.rand(4)
scaling["first"] = 1
opt_fn(x, "first")  # first compile
scaling["second"] = 1
opt_fn(x, "second")  # recompile: a different key is read
scaling["third"] = 1
opt_fn(x, "third")  # recompile again
torch.compile will guard on the key-value pair and recompile every time, because the accessed key-value pair is different in each invocation of fn.
@anijain2305 I'm working on a hot-swapping method. It's just a quick-and-dirty implementation for now, but I would like some early feedback on whether it's worth putting more time into. Could you please check the trace I got using TORCH_LOGS="guards,recompiles" TORCH_LOGS_OUT=traces.txt? It looks like there is no recompilation.
@BenjaminBossan It looks good to me. As long as the scaling dictionary remains the same, we should be safe from the recompilation issue.
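One way to check that condition, sketched under the assumption that the per-layer scaling dicts can be collected into a plain mapping (with a real model, for example from the modules in named_modules() that have a scaling attribute). snapshot_scalings and scaling_changes are hypothetical helpers. The comparison is deliberately conservative: any changed key or value is reported, even though only changes to the dict length or to values that are actually read would trip guards like those in the traces.

```python
from typing import Dict, List, Mapping

Scalings = Dict[str, Dict[str, float]]


def snapshot_scalings(layers: Mapping[str, Mapping[str, float]]) -> Scalings:
    """Copy each layer's scaling dict (layer name -> {adapter name -> scaling})."""
    return {name: dict(scaling) for name, scaling in layers.items()}


def scaling_changes(before: Scalings, after: Scalings) -> List[str]:
    """Return the names of layers whose scaling dict gained, lost or changed any entry."""
    return [
        name
        for name in sorted(set(before) | set(after))
        if before.get(name) != after.get(name)
    ]


layers = {
    "transformer_blocks.0.norm1.linear": {"default_0": 1.0},
    "single_transformer_blocks.37.norm.linear": {"default_0": 1.0},
}
before = snapshot_scalings(layers)
layers["transformer_blocks.0.norm1.linear"]["default_0"] = 1.0  # hot swap, same alpha/rank
unchanged = scaling_changes(before, snapshot_scalings(layers))
layers["single_transformer_blocks.37.norm.linear"]["default_1"] = 1.0  # extra adapter added
changed = scaling_changes(before, snapshot_scalings(layers))
```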
Is your feature request related to a problem? Please describe. It would be great to be able to load a LoRA into a model compiled with torch.compile.
Describe the solution you'd like. Call load_lora_weights with a compiled pipe (ideally without triggering recompilation).
Currently, running this code:
It errors:
When compiled, the state dict of a model seems to add a _orig_mod prefix to all keys.
Describe alternatives you've considered. An alternative is to fuse the LoRA into the model and then compile; however, this does not allow for hot-swapping LoRAs (a new pipeline and a new compilation are needed for every LoRA).
Additional context. This seems to have been achieved by @chengzeyi, author of the now-paused https://github.com/chengzeyi/stable-fast; however, it seems to be part of FAL's non-open-source optimized inference (if you'd like to contribute this upstream, feel free!).