huggingface / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
https://huggingface.co/docs/diffusers
Apache License 2.0

DoRA loading does not load all keys from the state_dict #7592

Open RyanJDick opened 3 months ago

RyanJDick commented 3 months ago

Describe the bug

When loading a DoRA model from a kohya state_dict, some keys in the state_dict are silently skipped.

DoRA loading was added in https://github.com/huggingface/diffusers/pull/7371. This feature has not been released yet, so I am encountering this issue on main (commit: 6133d98ff70eafad7b9f65da50a450a965d1957f)

Reproduction

In this script, I try to load the same test DoRA that was used in the original DoRA PR (https://github.com/huggingface/diffusers/pull/7371).

from diffusers import DiffusionPipeline
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    variant="fp16",
).to("cuda")
pipe.load_lora_weights("streamize/test-dora")

# Load the raw kohya state_dict from the same repo (pinned to the snapshot from the
# original report) so it can be compared with what ended up in the model.
state_dict = load_file(
    hf_hub_download(
        "streamize/test-dora",
        "pytorch_lora_weights.safetensors",
        revision="2c73f1cccb75b19c0b597f7ebadb10624966cd3f",
    )
)

key = "lora_te1_text_model_encoder_layers_0_mlp_fc1.dora_scale"
print(f"State dict value at key: {key}")
print("-----")
val = state_dict[key]
print(f"val.shape: {val.shape}")
print(f"val[0, :5]: {val[0, :5]}")

print(f"\nModel tensor at key: {key}")
print("-----")
val = pipe.text_encoder.text_model.encoder.layers[0].mlp.fc1.lora_magnitude_vector["default_0"]
print(f"val.shape: {val.shape}")
print(f"val[:5]: {val[:5]}")

Output:

State dict value at key: lora_te1_text_model_encoder_layers_0_mlp_fc1.dora_scale
-----
val.shape: torch.Size([1, 768])
val[0, :5]: tensor([ 0.0029, -0.0030,  0.0007, -0.0010, -0.0026], dtype=torch.float16)

Model tensor at key: lora_te1_text_model_encoder_layers_0_mlp_fc1.dora_scale
-----
val.shape: torch.Size([3072])
val[:5]: tensor([0.4485, 0.4538, 0.4752, 0.4901, 0.4194], device='cuda:0',
       grad_fn=<SliceBackward0>)

I am using "lora_te1_text_model_encoder_layers_0_mlp_fc1.dora_scale" as an example, but the same behaviour is observed for many keys. The state_dict value does not get injected into the model. In fact, its shape is not even compatible with the target tensor where I'd expect it to be injected.
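One possible reading of the two shapes (an assumption, not confirmed in this thread): the kohya dora_scale holds one norm per input feature of fc1, while PEFT's lora_magnitude_vector holds one norm per output feature. The sketch below uses a random stand-in for the CLIP text encoder's fc1 weight (nn.Linear(768, 3072), so a [3072, 768] weight) to show that reducing over different axes produces exactly the two shapes printed above, and that they cannot be matched directly.

```python
import torch

# Random stand-in for fc1's weight. nn.Linear(768, 3072) stores it as
# [out_features, in_features] = [3072, 768]. Values are irrelevant; only shapes matter.
weight = torch.randn(3072, 768)

# One norm per output feature (reduce over the input axis) -> shape [3072].
per_output_norm = torch.linalg.norm(weight, dim=1)

# One norm per input feature (reduce over the output axis, keep a leading axis) -> [1, 768].
per_input_norm = torch.linalg.norm(weight, dim=0, keepdim=True)

print(per_output_norm.shape)  # torch.Size([3072])
print(per_input_norm.shape)   # torch.Size([1, 768])

# Even after squeezing, 768 != 3072, so a reshape alone cannot make them fit.
compatible = per_input_norm.squeeze(0).shape == per_output_norm.shape
print("shapes compatible:", compatible)  # shapes compatible: False
```

If that assumption holds, renaming the key alone would not be enough: a per-input norm vector carries different information from a per-output one.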

From the digging I have done so far, I currently suspect 2 issues:

To understand the problem better, I recommend setting a breakpoint here: https://github.com/huggingface/peft/blob/26726bf1ddee6ca75ed4e1bfd292094526707a78/src/peft/utils/save_and_load.py#L249
Inspecting the state before and after load_state_dict() makes it easy to see which state_dict keys diffusers is trying to inject, and which ones are not being applied.

Beware of this issue with load_state_dict()'s handling of unexpected_keys: https://github.com/pytorch/pytorch/issues/123510. This threw me off when I was debugging.
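The before/after inspection suggested above can also be done without a debugger. A minimal sketch on a toy module (ToyLayer, applied_and_skipped and the keys are hypothetical, chosen to mirror the names in this thread): snapshot every tensor, call load_state_dict(strict=False), then check which incoming keys actually changed a model tensor. Because it compares values, this check does not depend on how load_state_dict() reports unexpected_keys.

```python
import torch
from torch import nn


class ToyLayer(nn.Module):
    """Hypothetical stand-in for a PEFT-wrapped fc1 layer with a DoRA magnitude vector."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 8)
        # Keyed by adapter name, like lora_magnitude_vector["default_0"] in the repro script.
        self.lora_magnitude_vector = nn.ParameterDict(
            {"default_0": nn.Parameter(torch.ones(8))}
        )


def applied_and_skipped(model, incoming):
    """Load `incoming` non-strictly and report which of its keys changed a model tensor."""
    before = {k: v.detach().clone() for k, v in model.state_dict().items()}
    model.load_state_dict(incoming, strict=False)
    after = model.state_dict()
    applied = sorted(
        k for k in incoming if k in before and not torch.equal(before[k], after[k])
    )
    skipped = sorted(k for k in incoming if k not in applied)
    return applied, skipped


model = ToyLayer()
incoming = {
    "lora_magnitude_vector.default_0": torch.full((8,), 2.0),  # matches a real parameter
    "lora_magnitude_vector.default_0.down.weight": torch.zeros(1, 4),  # no such parameter
}
applied, skipped = applied_and_skipped(model, incoming)
print("applied:", applied)
print("skipped:", skipped)
```

A key whose incoming value happens to equal the existing value would also be listed as skipped, so treat the result as a hint rather than proof.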

Logs

No response

System Info

Testing with diffusers commit: 6133d98ff70eafad7b9f65da50a450a965d1957f

Who can help?

@sayakpaul

sayakpaul commented 3 months ago

Cc @BenjaminBossan, since the problem seems to stem from the state dict injection step.

> Inspecting the state before and after load_state_dict() makes it easy to see which state_dict keys diffusers is trying to inject, and which ones are not being applied.

Which load_state_dict() function are you referring to here?

RyanJDick commented 3 months ago

> Which load_state_dict() function are you referring to here?

This one:

> To understand the problem better, I recommend setting a breakpoint here: https://github.com/huggingface/peft/blob/26726bf1ddee6ca75ed4e1bfd292094526707a78/src/peft/utils/save_and_load.py#L249

sayakpaul commented 3 months ago

@yiyixuxu as per my understanding, the issue has its roots in peft. Hence I cc'd @BenjaminBossan.

RyanJDick commented 3 months ago

> @yiyixuxu as per my understanding, the issue has its roots in peft. Hence I cc'd @BenjaminBossan.

I think the main issue is the kohya key conversion in diffusers - it produces keys that do not exist.

sayakpaul commented 3 months ago

> I think the main issue is the kohya key conversion in diffusers - it produces keys that do not exist.

Do not exist where? Could you give an example?

RyanJDick commented 3 months ago

> > I think the main issue is the kohya key conversion in diffusers - it produces keys that do not exist.
>
> Do not exist where? Could you give an example?

Using the example from the reproduction script, "lora_te1_text_model_encoder_layers_0_mlp_fc1.dora_scale" from the kohya state_dict gets converted by diffusers to "text_model.encoder.layers.0.mlp.fc1.lora_magnitude_vector.default_0.down.weight". There is no such module in the peft model, so it silently gets skipped.

I am just using this key as an example, but the same is true for many keys. I have not gone to the effort of checking them all.
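A quick way to list converted keys that have no target, without loading anything, is to compare them against the model's state_dict keys and, for each miss, find the longest dotted prefix that does exist. A plain-Python sketch; find_unmatched and longest_existing_prefix are hypothetical helpers, and the model-side key names are assumptions modelled on the lora_magnitude_vector["default_0"] access in the repro script (exact prefixes depend on where in the loading path the dict is captured).

```python
def find_unmatched(converted_keys, model_keys):
    """Return converted keys that have no exact counterpart among the model keys."""
    model_key_set = set(model_keys)
    return sorted(k for k in converted_keys if k not in model_key_set)


def longest_existing_prefix(key, model_keys):
    """Strip trailing dotted components until the prefix names (or contains) a model key."""
    parts = key.split(".")
    for n in range(len(parts) - 1, 0, -1):
        prefix = ".".join(parts[:n])
        hits = [m for m in model_keys if m == prefix or m.startswith(prefix + ".")]
        if hits:
            return prefix, hits
    return "", []


# Assumed model-side names for layer 0's fc1 (illustrative only).
model_keys = [
    "text_model.encoder.layers.0.mlp.fc1.base_layer.weight",
    "text_model.encoder.layers.0.mlp.fc1.lora_magnitude_vector.default_0",
]
# The converted key reported in this thread.
converted_keys = [
    "text_model.encoder.layers.0.mlp.fc1.lora_magnitude_vector.default_0.down.weight",
]

unmatched = find_unmatched(converted_keys, model_keys)
for key in unmatched:
    prefix, hits = longest_existing_prefix(key, model_keys)
    print(f"no target for {key}\n  closest existing prefix: {prefix} -> {hits}")
```

With the real objects, model_keys would come from something like pipe.text_encoder.state_dict().keys(); here the check points at the trailing ".down.weight" as the part with no counterpart.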

sayakpaul commented 3 months ago

I see many occurrences of lora_magnitude_vector here: https://github.com/search?q=repo%3Ahuggingface%2Fpeft%20lora_magnitude_vector&type=code. Perhaps @BenjaminBossan could help clarify this.

> I am just using this key as an example, but the same is true for many keys. I have not gone to the effort of checking them all.

Additional distinct examples would be appreciated.

BenjaminBossan commented 3 months ago

It's hard for me to understand what is going on here.

From the PEFT side of things, we don't really do anything special with the DoRA parameters, so treating them in the same fashion as the other LoRA parameters should be correct. What makes this difficult is that the adapters were trained with another LoRA/DoRA implementation (LyCORIS, I assume), not with PEFT, so there could be differences that make it hard to load their weights onto a PEFT model. We don't have any control over that; we don't even know whether it is stable over time (short of tracking all the code changes there).

To get to the bottom of this, we would need to understand what differentiates this checkpoint from the previous ones that seemed to work correctly once #7371 was added.
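One way to start comparing checkpoints is to diff their key sets and tensor shapes. A plain-Python sketch; diff_checkpoints and suffix_counts are hypothetical helpers, and the toy shapes below are made up for illustration, not taken from any real checkpoint.

```python
from collections import Counter


def diff_checkpoints(shapes_a, shapes_b):
    """Compare two {key: shape} maps: keys only in A, only in B, and shared keys whose shapes differ."""
    keys_a, keys_b = set(shapes_a), set(shapes_b)
    only_a = sorted(keys_a - keys_b)
    only_b = sorted(keys_b - keys_a)
    mismatched = sorted(
        k for k in keys_a & keys_b if tuple(shapes_a[k]) != tuple(shapes_b[k])
    )
    return only_a, only_b, mismatched


def suffix_counts(shapes):
    """Count keys by the part after the module name, e.g. 'dora_scale' or 'lora_down.weight'."""
    return Counter(k.split(".", 1)[1] for k in shapes if "." in k)


# Toy shapes for a single kohya-style module (made up, not from a real file).
ckpt_a = {
    "lora_te1_x.dora_scale": (1, 768),
    "lora_te1_x.lora_down.weight": (4, 768),
    "lora_te1_x.lora_up.weight": (3072, 4),
}
ckpt_b = {
    "lora_te1_x.dora_scale": (3072, 1),
    "lora_te1_x.lora_down.weight": (4, 768),
    "lora_te1_x.lora_up.weight": (3072, 4),
    "lora_te1_x.alpha": (),
}

only_a, only_b, mismatched = diff_checkpoints(ckpt_a, ckpt_b)
print("only in A:", only_a)
print("only in B:", only_b)
print("shape differs:", mismatched)
print("suffixes in B:", dict(suffix_counts(ckpt_b)))
```

For real files, each map could be built with {k: tuple(v.shape) for k, v in load_file(path).items()}, using the safetensors load_file from the repro script.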

github-actions[bot] commented 2 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.