facebookresearch / stable_signature

Official implementation of the paper "The Stable Signature: Rooting Watermarks in Latent Diffusion Models"

Message decoding problem using provided weights #29

Open LiRunyi2001 opened 2 months ago

LiRunyi2001 commented 2 months ago

Hi there! I have tried the latent decoder weights you provided here (WM weights of latent decoder), and I generated an image using the code provided in README.md:

import torch
from diffusers import DiffusionPipeline
from omegaconf import OmegaConf
from utils_model import load_model_from_config

ldm_config = "/gdata/cold1/lirunyi/model-watermark/v2-inference.yaml"
ldm_ckpt = "/gdata/cold1/lirunyi/model-watermark/stable-diffusion-2-1-base/v2-1_512-ema-pruned.ckpt"

# Build the original LDM and keep only its autoencoder (first stage model)
print(f'>>> Building LDM model with config {ldm_config} and weights from {ldm_ckpt}...')
config = OmegaConf.load(ldm_config)
ldm_ae = load_model_from_config(config, ldm_ckpt)
ldm_aef = ldm_ae.first_stage_model
ldm_aef.eval()

# Overwrite the latent decoder with the watermarked weights
state_dict = torch.load("sd2_decoder.pth")
unexpected_keys = ldm_aef.load_state_dict(state_dict, strict=False)
print(unexpected_keys)
print("you should check that the decoder keys are correctly matched")

pipeline = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2")
pipeline = pipeline.to('cuda')
# Route the diffusers VAE decoding through the watermarked LDM decoder
pipeline.vae.decode = (lambda x, *args, **kwargs: ldm_aef.decode(x).unsqueeze(0))
# run inference
prompt = "a cat and a dog"
img = pipeline(prompt).images[0]
img.save(f"./{prompt}.png")

Then I used this image to try to extract the message with decoding.ipynb, but it cannot be extracted correctly: the bit accuracy is only about 50% to 60%. Is there anything wrong with my usage? Thanks a lot!
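If the image carries no watermark, or a watermark for a different message, the decoded bits have no particular relation to the key, so a bit accuracy around 50% is what one would expect. A minimal sketch to put a number on this, assuming both the extracted message and the key are available as strings of '0'/'1' characters (like the key printed in the training log further down); `bit_accuracy` and `chance_p_value` are hypothetical helper names, not functions from the repository:

```python
from math import comb


def bit_accuracy(decoded: str, key: str) -> float:
    """Fraction of positions where the decoded message matches the key."""
    if len(decoded) != len(key):
        raise ValueError("decoded message and key must have the same length")
    matches = sum(d == k for d, k in zip(decoded, key))
    return matches / len(key)


def chance_p_value(matches: int, num_bits: int) -> float:
    """Probability of at least `matches` correct bits if each bit were a fair coin flip."""
    tail = sum(comb(num_bits, k) for k in range(matches, num_bits + 1))
    return tail / 2 ** num_bits
```

For a 48-bit key, 29 correct bits (about 60%) gives a tail probability of roughly 0.1 under this coin-flip assumption, so it is not evidence that the watermark is present.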

pierrefdz commented 2 months ago

Hi, can you share the logs? My guess is that the keys do not match between the latent decoder and the one in the diffusers codebase.

fenghe12 commented 1 month ago

Hello! I ran into a similar problem. Have you solved it? Bit accuracy is nearly 100% during training, but when I use the fine-tuned LDM decoder weights to generate images, I only get about 50% accuracy. Stranger still, the extracted watermark is completely different from the one used during training: if I use the watermark extracted from one generated image as the key and compare it with the watermarks extracted from other generated images, the accuracy is about 95%.
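The symptom above (messages from different images agree with each other but not with the training key) can be measured directly. A minimal sketch, assuming the extracted messages and the training key are '0'/'1' strings; `agreement`, `majority_vote` and `diagnose` are hypothetical helpers. A high mean agreement with the consensus together with a consensus-vs-key value near 0.5 means the images share a consistent message that is not the training key; it does not by itself say why, which is what the hypotheses discussed later in this thread address.

```python
def agreement(a: str, b: str) -> float:
    """Fraction of positions where two equal-length bit strings agree."""
    if len(a) != len(b):
        raise ValueError("bit strings must have the same length")
    return sum(x == y for x, y in zip(a, b)) / len(a)


def majority_vote(messages: list[str]) -> str:
    """Per-position majority bit over several extracted messages (ties give '1')."""
    if not messages:
        raise ValueError("need at least one message")
    num_bits = len(messages[0])
    consensus = []
    for i in range(num_bits):
        ones = sum(m[i] == "1" for m in messages)
        consensus.append("1" if 2 * ones >= len(messages) else "0")
    return "".join(consensus)


def diagnose(messages: list[str], key: str) -> dict:
    """Compare extracted messages with their consensus, and the consensus with the key."""
    consensus = majority_vote(messages)
    mean_agreement = sum(agreement(m, consensus) for m in messages) / len(messages)
    return {
        "mean_agreement_with_consensus": mean_agreement,
        "consensus_vs_key": agreement(consensus, key),
    }
```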

pierrefdz commented 1 month ago

Hi, can you share the logs or code?

fenghe12 commented 1 month ago

Sorry, but I forgot to save the training log. I can share the generation and decoding code, though.

fenghe12 commented 1 month ago

import os
import random

import torch
from diffusers import StableDiffusionPipeline
from omegaconf import OmegaConf
from utils_model import load_model_from_config

device = torch.device("cuda")

ldm_config = "./stablediffusion/configs/stable-diffusion/v2-inference.yaml"
ldm_ckpt = "./stablediffusion/checkpoints-base/v2-1_512-nonema-pruned.ckpt"

print(f'>>> Building LDM model with config {ldm_config} and weights from {ldm_ckpt}...')
config = OmegaConf.load(ldm_config)
ldm_ae = load_model_from_config(config, ldm_ckpt)
ldm_aef = ldm_ae.first_stage_model
ldm_aef.eval()

# Load the fine-tuned (watermarked) latent decoder from the training checkpoint
state_dict = torch.load("./out_test_white_200/checkpoints_000.pth")["ldm_decoder"]
unexpected_keys = ldm_aef.load_state_dict(state_dict, strict=False)
print(unexpected_keys)
print("you should check that the decoder keys are correctly matched")

model = "stabilityai/stable-diffusion-2"
pipe = StableDiffusionPipeline.from_pretrained(model).to(device)
prompts = [
    "Professional picture of fishing kitten",
] * 50
seeds = [random.randint(0, 2**32 - 1) for _ in range(len(prompts))]
# Route the diffusers VAE decoding through the watermarked LDM decoder
pipe.vae.decode = (lambda x, *args, **kwargs: ldm_aef.decode(x).unsqueeze(0))

os.makedirs("./with_watermark_256", exist_ok=True)
for i, (prompt, seed) in enumerate(zip(prompts, seeds)):
    generator = torch.manual_seed(seed)
    img = pipe(prompt, generator=generator).images[0]
    img.save(f'./with_watermark_256/{i}.png')

fenghe12 commented 1 month ago

The decoding code is exactly the same as in decoding.ipynb.

GoooHi commented 4 weeks ago

Sorry, I ran into the same problem. How did you solve it in the end? @fenghe12

pierrefdz commented 3 weeks ago

Sorry if I'm not super active... Could you share your logs, if possible? Some hypotheses: (1) the model weights are not loaded properly; (2) the watermark message hidden at fine-tuning time is not the one you compute the bit accuracy against; (3) there is a mismatch between the watermark extractor used during fine-tuning and the one used at evaluation time.

GoooHi commented 3 weeks ago

During the training phase there seemed to be no errors, but the messages I decoded during the testing phase were almost completely incorrect. @pierrefdz

git:sha: 8958dc7e82164e4a3b525d4f51e5df04e06fa9ff, status: has uncommited changes, branch: main
log:{"train_dir": "./data/train", "val_dir": "./data/val", "ldm_config": "./stabilityai/stable-diffusion-2-1-base/v2-inference.yaml", "ldm_ckpt": "./stabilityai/stable-diffusion-2-1-base/v2-1_512-ema-pruned.ckpt", "msg_decoder_path": "./models/dec_48b_whit.torchscript.pt", "num_bits": 48, "redundancy": 1, "decoder_depth": 8, "decoder_channels": 64, "batch_size": 4, "img_size": 256, "loss_i": "watson-vgg", "loss_w": "bce", "lambda_i": 0.2, "lambda_w": 1.0, "optimizer": "AdamW,lr=5e-4", "steps": 100, "warmup_steps": 20, "log_freq": 10, "save_img_freq": 1000, "num_keys": 1, "output_dir": "output/", "seed": 0, "debug": false}

Building LDM model with config ./stabilityai/stable-diffusion-2-1-base/v2-inference.yaml and weights from ./stabilityai/stable-diffusion-2-1-base/v2-1_512-ema-pruned.ckpt...
Loading model from ./stabilityai/stable-diffusion-2-1-base/v2-1_512-ema-pruned.ckpt
Global Step: 220000
LatentDiffusion: Running in eps-prediction mode
DiffusionWrapper has 865.91 M params.
making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 4, 32, 32) = 4096 dimensions.
making attention of type 'vanilla' with 512 in_channels
Building hidden decoder with weights from ./models/dec_48b_whit.torchscript.pt...
Loading data from ./data/train and ./data/val...
Creating losses...
Losses: bce and watson-vgg...

Creating key with 48 bits...
Key: 111010110101000001010111010011010100010000100111
Training...
{"iteration": 0, "loss": 0.7272339463233948, "loss_w": 0.7272318601608276, "loss_i": 1.0295131687598769e-05, "psnr": Infinity, "bit_acc_avg": 0.515625, "word_acc_avg": 0.0, "lr": 0.0}
Train [ 0/100] eta: 0:02:30 iteration: 0.000000 (0.000000) loss: 0.727234 (0.727234) loss_w: 0.727232 (0.727232) loss_i: 0.000010 (0.000010) psnr: inf (inf) bit_acc_avg: 0.515625 (0.515625) word_acc_avg: 0.000000 (0.000000) lr: 0.000000 (0.000000) time: 1.500002 data: 0.438819 max mem: 10936
{"iteration": 10, "loss": 0.5056884288787842, "loss_w": 0.15685945749282837, "loss_i": 1.7441446781158447, "psnr": 34.489280700683594, "bit_acc_avg": 0.96875, "word_acc_avg": 0.0, "lr": 0.00025}
Train [ 10/100] eta: 0:00:58 iteration: 5.000000 (5.000000) loss: 0.604755 (0.636774) loss_w: 0.462247 (0.489823) loss_i: 0.712543 (0.734757) psnr: 43.147713 (inf) bit_acc_avg: 0.817708 (0.754261) word_acc_avg: 0.000000 (0.000000) lr: 0.000125 (0.000125) time: 0.652794 data: 0.039989 max mem: 11664
{"iteration": 20, "loss": 0.5863916873931885, "loss_w": 0.0927756056189537, "loss_i": 2.468080520629883, "psnr": 30.310420989990234, "bit_acc_avg": 0.984375, "word_acc_avg": 0.75, "lr": 0.0005}
Train [ 20/100] eta: 0:00:46 iteration: 10.000000 (10.000000) loss: 0.575908 (0.603986) loss_w: 0.228699 (0.323160) loss_i: 1.744145 (1.404130) psnr: 33.432957 (inf) bit_acc_avg: 0.932292 (0.858383) word_acc_avg: 0.000000 (0.202381) lr: 0.000250 (0.000250) time: 0.537721 data: 0.000102 max mem: 11664
{"iteration": 30, "loss": 0.5741378664970398, "loss_w": 0.0685807317495346, "loss_i": 2.527785539627075, "psnr": 30.195219039916992, "bit_acc_avg": 0.9947916865348816, "word_acc_avg": 0.75, "lr": 0.00048100794336156604}
Train [ 30/100] eta: 0:00:39 iteration: 20.000000 (15.000000) loss: 0.570326 (0.593232) loss_w: 0.090491 (0.242183) loss_i: 2.362381 (1.755245) psnr: 30.266314 (inf) bit_acc_avg: 0.989583 (0.901714) word_acc_avg: 0.750000 (0.379032) lr: 0.000481 (0.000328) time: 0.507474 data: 0.000098 max mem: 11664
{"iteration": 40, "loss": 0.5790627002716064, "loss_w": 0.018820516765117645, "loss_i": 2.801210880279541, "psnr": 27.312376022338867, "bit_acc_avg": 1.0, "word_acc_avg": 1.0, "lr": 0.0004269231419060436}
Train [ 40/100] eta: 0:00:32 iteration: 30.000000 (20.000000) loss: 0.574138 (0.592107) loss_w: 0.065551 (0.202381) loss_i: 2.500486 (1.948630) psnr: 29.710548 (inf) bit_acc_avg: 0.994792 (0.922891) word_acc_avg: 0.750000 (0.481707) lr: 0.000477 (0.000359) time: 0.508035 data: 0.000098 max mem: 11664
{"iteration": 50, "loss": 0.5642515420913696, "loss_w": 0.06884497404098511, "loss_i": 2.4770328998565674, "psnr": 29.902061462402344, "bit_acc_avg": 1.0, "word_acc_avg": 1.0, "lr": 0.00034597951637508993}
Train [ 50/100] eta: 0:00:26 iteration: 40.000000 (25.000000) loss: 0.564252 (0.584697) loss_w: 0.060810 (0.172872) loss_i: 2.478716 (2.059127) psnr: 29.392599 (inf) bit_acc_avg: 0.994792 (0.937398) word_acc_avg: 0.750000 (0.553922) lr: 0.000420 (0.000364) time: 0.508685 data: 0.000103 max mem: 11664
{"iteration": 60, "loss": 0.5306546688079834, "loss_w": 0.036135606467723846, "loss_i": 2.47259521484375, "psnr": 29.56548500061035, "bit_acc_avg": 0.9895833730697632, "word_acc_avg": 0.75, "lr": 0.0002505}
Train [ 60/100] eta: 0:00:21 iteration: 50.000000 (30.000000) loss: 0.550219 (0.578414) loss_w: 0.047642 (0.154611) loss_i: 2.477033 (2.119012) psnr: 29.838663 (inf) bit_acc_avg: 0.994792 (0.946380) word_acc_avg: 0.750000 (0.598361) lr: 0.000337 (0.000352) time: 0.509140 data: 0.000102 max mem: 11664
{"iteration": 70, "loss": 0.4845236539840698, "loss_w": 0.05966611206531525, "loss_i": 2.1242876052856445, "psnr": 31.549930572509766, "bit_acc_avg": 1.0, "word_acc_avg": 1.0, "lr": 0.0001550204836249101}
Train [ 70/100] eta: 0:00:15 iteration: 60.000000 (35.000000) loss: 0.525561 (0.570636) loss_w: 0.043449 (0.139244) loss_i: 2.408679 (2.156957) psnr: 29.971382 (inf) bit_acc_avg: 1.000000 (0.953859) word_acc_avg: 1.000000 (0.651408) lr: 0.000241 (0.000331) time: 0.508935 data: 0.000096 max mem: 11664
{"iteration": 80, "loss": 0.4934779107570648, "loss_w": 0.02895699068903923, "loss_i": 2.3226046562194824, "psnr": 29.44066047668457, "bit_acc_avg": 0.9947916865348816, "word_acc_avg": 0.75, "lr": 7.40768580939564e-05}
Train [ 80/100] eta: 0:00:10 iteration: 70.000000 (40.000000) loss: 0.493478 (0.559191) loss_w: 0.032891 (0.127513) loss_i: 2.172009 (2.158390) psnr: 29.971382 (inf) bit_acc_avg: 1.000000 (0.959105) word_acc_avg: 1.000000 (0.682099) lr: 0.000146 (0.000303) time: 0.508779 data: 0.000095 max mem: 11664
{"iteration": 90, "loss": 0.42957162857055664, "loss_w": 0.05145969241857529, "loss_i": 1.8905595541000366, "psnr": 31.834016799926758, "bit_acc_avg": 1.0, "word_acc_avg": 1.0, "lr": 1.9992056638433958e-05}
Train [ 90/100] eta: 0:00:05 iteration: 80.000000 (45.000000) loss: 0.452601 (0.546342) loss_w: 0.041486 (0.118391) loss_i: 2.026620 (2.139756) psnr: 30.715746 (inf) bit_acc_avg: 1.000000 (0.963427) word_acc_avg: 1.000000 (0.708791) lr: 0.000067 (0.000274) time: 0.509235 data: 0.000096 max mem: 11664
Train [ 99/100] eta: 0:00:00 iteration: 89.000000 (49.500000) loss: 0.444447 (0.537739) loss_w: 0.041486 (0.112332) loss_i: 2.015063 (2.127031) psnr: 30.794491 (inf) bit_acc_avg: 1.000000 (0.966563) word_acc_avg: 1.000000 (0.727500) lr: 0.000020 (0.000250) time: 0.509741 data: 0.000095 max mem: 11664
Train Total time: 0:00:52 (0.520741 s / it)
Averaged train stats: iteration: 89.000000 (49.500000) loss: 0.444447 (0.537739) loss_w: 0.041486 (0.112332) loss_i: 2.015063 (2.127031) psnr: 30.794491 (inf) bit_acc_avg: 1.000000 (0.966563) word_acc_avg: 1.000000 (0.727500) lr: 0.000020 (0.000250)
torch.Size([16, 3, 256, 256])
Eval [0/7] eta: 0:00:23 iteration: 0.000000 (0.000000) psnr: 30.536995 (30.536995) bit_acc_none: 0.996094 (0.996094) word_acc_none: 0.812500 (0.812500) bit_acc_crop_01: 0.936198 (0.936198) word_acc_crop_01: 0.375000 (0.375000) bit_acc_crop_05: 0.990885 (0.990885) word_acc_crop_05: 0.812500 (0.812500) bit_acc_rot_25: 0.656250 (0.656250) word_acc_rot_25: 0.000000 (0.000000) bit_acc_rot_90: 0.483073 (0.483073) word_acc_rot_90: 0.000000 (0.000000) bit_acc_resize_03: 0.744792 (0.744792) word_acc_resize_03: 0.000000 (0.000000) bit_acc_resize_07: 0.993490 (0.993490) word_acc_resize_07: 0.812500 (0.812500) bit_acc_brightness_1p5: 0.994792 (0.994792) word_acc_brightness_1p5: 0.812500 (0.812500) bit_acc_brightness_2: 0.981771 (0.981771) word_acc_brightness_2: 0.375000 (0.375000) bit_acc_jpeg_80: 0.908854 (0.908854) word_acc_jpeg_80: 0.000000 (0.000000) bit_acc_jpeg_50: 0.861979 (0.861979) word_acc_jpeg_50: 0.000000 (0.000000) time: 3.353976 data: 0.500687 max mem: 11664

Eval [6/7] eta: 0:00:01 iteration: 3.000000 (3.000000) psnr: 30.536995 (30.408210) bit_acc_none: 0.998698 (0.998512) word_acc_none: 0.937500 (0.937500) bit_acc_crop_01: 0.940104 (0.939360) word_acc_crop_01: 0.250000 (0.232143) bit_acc_crop_05: 0.994792 (0.993862) word_acc_crop_05: 0.812500 (0.821429) bit_acc_rot_25: 0.652344 (0.653460) word_acc_rot_25: 0.000000 (0.000000) bit_acc_rot_90: 0.479167 (0.481213) word_acc_rot_90: 0.000000 (0.000000) bit_acc_resize_03: 0.746094 (0.745350) word_acc_resize_03: 0.000000 (0.000000) bit_acc_resize_07: 0.992188 (0.991257) word_acc_resize_07: 0.750000 (0.705357) bit_acc_brightness_1p5: 0.994792 (0.992188) word_acc_brightness_1p5: 0.750000 (0.723214) bit_acc_brightness_2: 0.976562 (0.974144) word_acc_brightness_2: 0.250000 (0.294643) bit_acc_jpeg_80: 0.923177 (0.924293) word_acc_jpeg_80: 0.000000 (0.008929) bit_acc_jpeg_50: 0.861979 (0.870908) word_acc_jpeg_50: 0.000000 (0.000000) time: 1.366670 data: 0.071614 max mem: 11664
Eval Total time: 0:00:09 (1.219035 s / it)
Averaged eval stats: iteration: 3.000000 (3.000000) psnr: 30.536995 (30.408210) bit_acc_none: 0.998698 (0.998512) word_acc_none: 0.937500 (0.937500) bit_acc_crop_01: 0.940104 (0.939360) word_acc_crop_01: 0.250000 (0.232143) bit_acc_crop_05: 0.994792 (0.993862) word_acc_crop_05: 0.812500 (0.821429) bit_acc_rot_25: 0.652344 (0.653460) word_acc_rot_25: 0.000000 (0.000000) bit_acc_rot_90: 0.479167 (0.481213) word_acc_rot_90: 0.000000 (0.000000) bit_acc_resize_03: 0.746094 (0.745350) word_acc_resize_03: 0.000000 (0.000000) bit_acc_resize_07: 0.992188 (0.991257) word_acc_resize_07: 0.750000 (0.705357) bit_acc_brightness_1p5: 0.994792 (0.992188) word_acc_brightness_1p5: 0.750000 (0.723214) bit_acc_brightness_2: 0.976562 (0.974144) word_acc_brightness_2: 0.250000 (0.294643) bit_acc_jpeg_80: 0.923177 (0.924293) word_acc_jpeg_80: 0.000000 (0.008929) bit_acc_jpeg_50: 0.861979 (0.870908) word_acc_jpeg_50: 0.000000 (0.000000)
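In the eval log above, bit_acc_* stays high under several transformations (about 0.92 for jpeg_80) while the matching word_acc_* is at or near 0. That pattern is consistent with word accuracy counting only messages in which all 48 bits are correct, so a few wrong bits per image are enough to drive it to zero. A minimal sketch of the two metrics under that assumption; `bit_and_word_accuracy` is a hypothetical name and this illustrates the idea rather than reproducing the repository's code:

```python
def bit_and_word_accuracy(decoded: list[str], key: str) -> tuple[float, float]:
    """Average bit accuracy, and the fraction of messages with every bit correct."""
    if not decoded:
        raise ValueError("need at least one decoded message")
    # Per-image fraction of bits that match the key
    per_image = [sum(d == k for d, k in zip(msg, key)) / len(key) for msg in decoded]
    bit_acc = sum(per_image) / len(per_image)
    # A message only counts for word accuracy if it is decoded perfectly
    word_acc = sum(acc == 1.0 for acc in per_image) / len(per_image)
    return bit_acc, word_acc
```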

pierrefdz commented 3 weeks ago

And when you generate, what does print(unexpected_keys) give?

pierrefdz commented 3 weeks ago

Also, could you try changing

pipe.vae.decode = (lambda x, *args, **kwargs: ldm_aef.decode(x).unsqueeze(0))

into

pipe.vae.decode = (lambda x, *args, **kwargs: (
    print("Entering vae.decode"),
    ldm_aef.decode(x).unsqueeze(0)
)[-1])

to make sure that the new decoder is actually used.

PS: alternatively, in a form that is easier to read:

def vae_decode(x, *args, **kwargs):
    print("Entering vae.decode")
    return ldm_aef.decode(x).unsqueeze(0)
pipe.vae.decode = vae_decode
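The same check can be made into an assertion by counting calls instead of printing. A minimal sketch: `count_calls` and `FakeVAE` are hypothetical names, and `FakeVAE` only stands in for the diffusers VAE so the snippet runs without a GPU; the commented lines show how it might be applied to the pipeline from this thread.

```python
import functools


def count_calls(fn):
    """Wrap `fn` so each call increments `wrapper.calls` before delegating to it."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        wrapper.calls += 1
        return fn(*args, **kwargs)
    wrapper.calls = 0
    return wrapper


# Applied to the pipeline (not executed here):
# pipe.vae.decode = count_calls(lambda x, *args, **kwargs: ldm_aef.decode(x).unsqueeze(0))
# img = pipe(prompt).images[0]
# assert pipe.vae.decode.calls > 0, "the watermarked decoder was never called"


class FakeVAE:
    """Stand-in object with a decode method, used only for this demo."""
    def decode(self, x):
        return ("original", x)


vae = FakeVAE()
# Assigning to the instance attribute shadows the class method, as in the thread
vae.decode = count_calls(lambda x, *args, **kwargs: ("patched", x))
result = vae.decode(3)
```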

GoooHi commented 3 weeks ago

_IncompatibleKeys(missing_keys=['encoder.conv_in.weight', 'encoder.conv_in.bias', 'encoder.down.0.block.0.norm1.weight', 'encoder.down.0.block.0.norm1.bias', 'encoder.down.0.block.0.conv1.weight', 'encoder.down.0.block.0.conv1.bias', 'encoder.down.0.block.0.norm2.weight', 'encoder.down.0.block.0.norm2.bias', 'encoder.down.0.block.0.conv2.weight', 'encoder.down.0.block.0.conv2.bias', 'encoder.down.0.block.1.norm1.weight', 'encoder.down.0.block.1.norm1.bias', 'encoder.down.0.block.1.conv1.weight', 'encoder.down.0.block.1.conv1.bias', 'encoder.down.0.block.1.norm2.weight', 'encoder.down.0.block.1.norm2.bias', 'encoder.down.0.block.1.conv2.weight', 'encoder.down.0.block.1.conv2.bias', 'encoder.down.0.downsample.conv.weight', 'encoder.down.0.downsample.conv.bias', 'encoder.down.1.block.0.norm1.weight', 'encoder.down.1.block.0.norm1.bias', 'encoder.down.1.block.0.conv1.weight', 'encoder.down.1.block.0.conv1.bias', 'encoder.down.1.block.0.norm2.weight', 'encoder.down.1.block.0.norm2.bias', 'encoder.down.1.block.0.conv2.weight', 'encoder.down.1.block.0.conv2.bias', 'encoder.down.1.block.0.nin_shortcut.weight', 'encoder.down.1.block.0.nin_shortcut.bias', 'encoder.down.1.block.1.norm1.weight', 'encoder.down.1.block.1.norm1.bias', 'encoder.down.1.block.1.conv1.weight', 'encoder.down.1.block.1.conv1.bias', 'encoder.down.1.block.1.norm2.weight', 'encoder.down.1.block.1.norm2.bias', 'encoder.down.1.block.1.conv2.weight', 'encoder.down.1.block.1.conv2.bias', 'encoder.down.1.downsample.conv.weight', 'encoder.down.1.downsample.conv.bias', 'encoder.down.2.block.0.norm1.weight', 'encoder.down.2.block.0.norm1.bias', 'encoder.down.2.block.0.conv1.weight', 'encoder.down.2.block.0.conv1.bias', 'encoder.down.2.block.0.norm2.weight', 'encoder.down.2.block.0.norm2.bias', 'encoder.down.2.block.0.conv2.weight', 'encoder.down.2.block.0.conv2.bias', 'encoder.down.2.block.0.nin_shortcut.weight', 'encoder.down.2.block.0.nin_shortcut.bias', 'encoder.down.2.block.1.norm1.weight', 
'encoder.down.2.block.1.norm1.bias', 'encoder.down.2.block.1.conv1.weight', 'encoder.down.2.block.1.conv1.bias', 'encoder.down.2.block.1.norm2.weight', 'encoder.down.2.block.1.norm2.bias', 'encoder.down.2.block.1.conv2.weight', 'encoder.down.2.block.1.conv2.bias', 'encoder.down.2.downsample.conv.weight', 'encoder.down.2.downsample.conv.bias', 'encoder.down.3.block.0.norm1.weight', 'encoder.down.3.block.0.norm1.bias', 'encoder.down.3.block.0.conv1.weight', 'encoder.down.3.block.0.conv1.bias', 'encoder.down.3.block.0.norm2.weight', 'encoder.down.3.block.0.norm2.bias', 'encoder.down.3.block.0.conv2.weight', 'encoder.down.3.block.0.conv2.bias', 'encoder.down.3.block.1.norm1.weight', 'encoder.down.3.block.1.norm1.bias', 'encoder.down.3.block.1.conv1.weight', 'encoder.down.3.block.1.conv1.bias', 'encoder.down.3.block.1.norm2.weight', 'encoder.down.3.block.1.norm2.bias', 'encoder.down.3.block.1.conv2.weight', 'encoder.down.3.block.1.conv2.bias', 'encoder.mid.block_1.norm1.weight', 'encoder.mid.block_1.norm1.bias', 'encoder.mid.block_1.conv1.weight', 'encoder.mid.block_1.conv1.bias', 'encoder.mid.block_1.norm2.weight', 'encoder.mid.block_1.norm2.bias', 'encoder.mid.block_1.conv2.weight', 'encoder.mid.block_1.conv2.bias', 'encoder.mid.attn_1.norm.weight', 'encoder.mid.attn_1.norm.bias', 'encoder.mid.attn_1.q.weight', 'encoder.mid.attn_1.q.bias', 'encoder.mid.attn_1.k.weight', 'encoder.mid.attn_1.k.bias', 'encoder.mid.attn_1.v.weight', 'encoder.mid.attn_1.v.bias', 'encoder.mid.attn_1.proj_out.weight', 'encoder.mid.attn_1.proj_out.bias', 'encoder.mid.block_2.norm1.weight', 'encoder.mid.block_2.norm1.bias', 'encoder.mid.block_2.conv1.weight', 'encoder.mid.block_2.conv1.bias', 'encoder.mid.block_2.norm2.weight', 'encoder.mid.block_2.norm2.bias', 'encoder.mid.block_2.conv2.weight', 'encoder.mid.block_2.conv2.bias', 'encoder.norm_out.weight', 'encoder.norm_out.bias', 'encoder.conv_out.weight', 'encoder.conv_out.bias', 'decoder.conv_in.weight', 'decoder.conv_in.bias', 
'decoder.mid.block_1.norm1.weight', 'decoder.mid.block_1.norm1.bias', 'decoder.mid.block_1.conv1.weight', 'decoder.mid.block_1.conv1.bias', 'decoder.mid.block_1.norm2.weight', 'decoder.mid.block_1.norm2.bias', 'decoder.mid.block_1.conv2.weight', 'decoder.mid.block_1.conv2.bias', 'decoder.mid.attn_1.norm.weight', 'decoder.mid.attn_1.norm.bias', 'decoder.mid.attn_1.q.weight', 'decoder.mid.attn_1.q.bias', 'decoder.mid.attn_1.k.weight', 'decoder.mid.attn_1.k.bias', 'decoder.mid.attn_1.v.weight', 'decoder.mid.attn_1.v.bias', 'decoder.mid.attn_1.proj_out.weight', 'decoder.mid.attn_1.proj_out.bias', 'decoder.mid.block_2.norm1.weight', 'decoder.mid.block_2.norm1.bias', 'decoder.mid.block_2.conv1.weight', 'decoder.mid.block_2.conv1.bias', 'decoder.mid.block_2.norm2.weight', 'decoder.mid.block_2.norm2.bias', 'decoder.mid.block_2.conv2.weight', 'decoder.mid.block_2.conv2.bias', 'decoder.up.0.block.0.norm1.weight', 'decoder.up.0.block.0.norm1.bias', 'decoder.up.0.block.0.conv1.weight', 'decoder.up.0.block.0.conv1.bias', 'decoder.up.0.block.0.norm2.weight', 'decoder.up.0.block.0.norm2.bias', 'decoder.up.0.block.0.conv2.weight', 'decoder.up.0.block.0.conv2.bias', 'decoder.up.0.block.0.nin_shortcut.weight', 'decoder.up.0.block.0.nin_shortcut.bias', 'decoder.up.0.block.1.norm1.weight', 'decoder.up.0.block.1.norm1.bias', 'decoder.up.0.block.1.conv1.weight', 'decoder.up.0.block.1.conv1.bias', 'decoder.up.0.block.1.norm2.weight', 'decoder.up.0.block.1.norm2.bias', 'decoder.up.0.block.1.conv2.weight', 'decoder.up.0.block.1.conv2.bias', 'decoder.up.0.block.2.norm1.weight', 'decoder.up.0.block.2.norm1.bias', 'decoder.up.0.block.2.conv1.weight', 'decoder.up.0.block.2.conv1.bias', 'decoder.up.0.block.2.norm2.weight', 'decoder.up.0.block.2.norm2.bias', 'decoder.up.0.block.2.conv2.weight', 'decoder.up.0.block.2.conv2.bias', 'decoder.up.1.block.0.norm1.weight', 'decoder.up.1.block.0.norm1.bias', 'decoder.up.1.block.0.conv1.weight', 'decoder.up.1.block.0.conv1.bias', 
'decoder.up.1.block.0.norm2.weight', 'decoder.up.1.block.0.norm2.bias', 'decoder.up.1.block.0.conv2.weight', 'decoder.up.1.block.0.conv2.bias', 'decoder.up.1.block.0.nin_shortcut.weight', 'decoder.up.1.block.0.nin_shortcut.bias', 'decoder.up.1.block.1.norm1.weight', 'decoder.up.1.block.1.norm1.bias', 'decoder.up.1.block.1.conv1.weight', 'decoder.up.1.block.1.conv1.bias', 'decoder.up.1.block.1.norm2.weight', 'decoder.up.1.block.1.norm2.bias', 'decoder.up.1.block.1.conv2.weight', 'decoder.up.1.block.1.conv2.bias', 'decoder.up.1.block.2.norm1.weight', 'decoder.up.1.block.2.norm1.bias', 'decoder.up.1.block.2.conv1.weight', 'decoder.up.1.block.2.conv1.bias', 'decoder.up.1.block.2.norm2.weight', 'decoder.up.1.block.2.norm2.bias', 'decoder.up.1.block.2.conv2.weight', 'decoder.up.1.block.2.conv2.bias', 'decoder.up.1.upsample.conv.weight', 'decoder.up.1.upsample.conv.bias', 'decoder.up.2.block.0.norm1.weight', 'decoder.up.2.block.0.norm1.bias', 'decoder.up.2.block.0.conv1.weight', 'decoder.up.2.block.0.conv1.bias', 'decoder.up.2.block.0.norm2.weight', 'decoder.up.2.block.0.norm2.bias', 'decoder.up.2.block.0.conv2.weight', 'decoder.up.2.block.0.conv2.bias', 'decoder.up.2.block.1.norm1.weight', 'decoder.up.2.block.1.norm1.bias', 'decoder.up.2.block.1.conv1.weight', 'decoder.up.2.block.1.conv1.bias', 'decoder.up.2.block.1.norm2.weight', 'decoder.up.2.block.1.norm2.bias', 'decoder.up.2.block.1.conv2.weight', 'decoder.up.2.block.1.conv2.bias', 'decoder.up.2.block.2.norm1.weight', 'decoder.up.2.block.2.norm1.bias', 'decoder.up.2.block.2.conv1.weight', 'decoder.up.2.block.2.conv1.bias', 'decoder.up.2.block.2.norm2.weight', 'decoder.up.2.block.2.norm2.bias', 'decoder.up.2.block.2.conv2.weight', 'decoder.up.2.block.2.conv2.bias', 'decoder.up.2.upsample.conv.weight', 'decoder.up.2.upsample.conv.bias', 'decoder.up.3.block.0.norm1.weight', 'decoder.up.3.block.0.norm1.bias', 'decoder.up.3.block.0.conv1.weight', 'decoder.up.3.block.0.conv1.bias', 'decoder.up.3.block.0.norm2.weight', 
'decoder.up.3.block.0.norm2.bias', 'decoder.up.3.block.0.conv2.weight', 'decoder.up.3.block.0.conv2.bias', 'decoder.up.3.block.1.norm1.weight', 'decoder.up.3.block.1.norm1.bias', 'decoder.up.3.block.1.conv1.weight', 'decoder.up.3.block.1.conv1.bias', 'decoder.up.3.block.1.norm2.weight', 'decoder.up.3.block.1.norm2.bias', 'decoder.up.3.block.1.conv2.weight', 'decoder.up.3.block.1.conv2.bias', 'decoder.up.3.block.2.norm1.weight', 'decoder.up.3.block.2.norm1.bias', 'decoder.up.3.block.2.conv1.weight', 'decoder.up.3.block.2.conv1.bias', 'decoder.up.3.block.2.norm2.weight', 'decoder.up.3.block.2.norm2.bias', 'decoder.up.3.block.2.conv2.weight', 'decoder.up.3.block.2.conv2.bias', 'decoder.up.3.upsample.conv.weight', 'decoder.up.3.upsample.conv.bias', 'decoder.norm_out.weight', 'decoder.norm_out.bias', 'decoder.conv_out.weight', 'decoder.conv_out.bias', 'quant_conv.weight', 'quant_conv.bias', 'post_quant_conv.weight', 'post_quant_conv.bias'], unexpected_keys=['ldm_decoder', 'optimizer', 'params']) you should check that the decoder keys are correctly matched

pierrefdz commented 3 weeks ago

So in your case it's (1). You need to make sure that no "decoder.*" keys appear among the missing keys. The unexpected keys ('ldm_decoder', 'optimizer', 'params') are the top-level keys of your checkpoint, so you should do unexpected_keys = ldm_aef.load_state_dict(state_dict["ldm_decoder"], strict=False)
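Following the reply above, the printed result can be checked automatically instead of by eye. `load_state_dict(..., strict=False)` in PyTorch returns a named tuple with `missing_keys` and `unexpected_keys` fields (the variable called `unexpected_keys` in this thread actually holds both). A minimal sketch, assuming that only "decoder.*" entries matter, as stated above; `summarize_missing` and `assert_decoder_loaded` are hypothetical helpers:

```python
from collections import Counter


def summarize_missing(missing_keys: list[str]) -> Counter:
    """Count missing keys by top-level module name (e.g. 'encoder', 'decoder')."""
    return Counter(key.split(".", 1)[0] for key in missing_keys)


def assert_decoder_loaded(missing_keys: list[str]) -> None:
    """Raise if any latent-decoder weight was not overwritten by the checkpoint."""
    missing_decoder = [k for k in missing_keys if k.startswith("decoder.")]
    if missing_decoder:
        raise RuntimeError(
            f"{len(missing_decoder)} decoder weights were not loaded, e.g. {missing_decoder[0]}"
        )


# Usage with the model from this thread (not executed here):
# result = ldm_aef.load_state_dict(state_dict, strict=False)
# print(summarize_missing(result.missing_keys))
# assert_decoder_loaded(result.missing_keys)
```

On the dump above, the summary would show entries under 'encoder', 'decoder', 'quant_conv' and 'post_quant_conv', and the check would fail because of the 'decoder' entries.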

See https://github.com/facebookresearch/stable_signature?tab=readme-ov-file#with-stability-ai-codebase

GoooHi commented 3 weeks ago

Thank you very much for your help! I solved my problem. Maybe you could change the code in the "Generate with Diffusers" section: state_dict = torch.load("sd2_decoder.pth") -> state_dict = torch.load("sd2_decoder.pth")['ldm_decoder']

pierrefdz commented 3 weeks ago

If you use the weights I provided, there is no need for this, since that state_dict is directly the ldm_decoder state dict.
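The two checkpoint layouts discussed here can be handled in one place: a training checkpoint wraps the decoder weights under "ldm_decoder" (next to "optimizer" and "params", as the unexpected keys above show), while the released weights are already a plain state dict. A minimal sketch; `decoder_state_dict` is a hypothetical helper, and it assumes a plain state dict never has a top-level key named "ldm_decoder":

```python
def decoder_state_dict(checkpoint: dict) -> dict:
    """Return the latent-decoder weights from either checkpoint layout."""
    if "ldm_decoder" in checkpoint:
        # Checkpoint written by fine-tuning: {'ldm_decoder': ..., 'optimizer': ..., 'params': ...}
        return checkpoint["ldm_decoder"]
    # Released weights such as sd2_decoder.pth: already the decoder state dict
    return checkpoint


# Usage (not executed here):
# state_dict = decoder_state_dict(torch.load(path))
# result = ldm_aef.load_state_dict(state_dict, strict=False)
```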

pierrefdz commented 3 weeks ago

I'll add something in the readme! Thx for the follow-up on this!

LiRunyi2001 commented 2 weeks ago

Sorry for the late reply :( I have also tried out the solution and solved the problem. Thank you all for your solutions and follow-up!