LiRunyi2001 opened this issue 2 months ago
Hi, can you share the logs? My guess is that the keys do not match between the latent decoder and the one from the diffusers codebase.
Hello! I ran into a similar problem. Have you solved it? Bit accuracy is nearly 100% during training, but when I use the fine-tuned LDM decoder weights to generate images, I only get about 50% accuracy. Even stranger, the extracted watermark is completely different from the watermark used during training. Yet if I take the watermark extracted from one generated image as the key and compare it with the watermarks extracted from the other generated images, the accuracy is about 95%.
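(For reference, the cross-image comparison described above amounts to something like the following sketch; `msgs` is assumed to be an (N, 48) tensor holding the raw extractor outputs, one row per generated image.)

```python
import torch

# Threshold the extractor logits to bits, take the message of the first image as the
# reference "key", and compute the bit accuracy of every other image against it.
bits = (msgs > 0).float()
ref = bits[0]
acc = (bits[1:] == ref).float().mean(dim=1)
print(acc.mean().item())
```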
Hi, can you share the logs or code?
Sorry, I forgot to save the training log, but I can share the generation and decoding code:
```python
import random

import torch
from omegaconf import OmegaConf
from diffusers import StableDiffusionPipeline

from utils_model import load_model_from_config

device = torch.device("cuda")

ldm_config = "./stablediffusion/configs/stable-diffusion/v2-inference.yaml"
ldm_ckpt = "./stablediffusion/checkpoints-base/v2-1_512-nonema-pruned.ckpt"

print(f'>>> Building LDM model with config {ldm_config} and weights from {ldm_ckpt}...')
config = OmegaConf.load(f"{ldm_config}")
ldm_ae = load_model_from_config(config, ldm_ckpt)
ldm_aef = ldm_ae.first_stage_model
ldm_aef.eval()

# load the fine-tuned latent decoder weights
state_dict = torch.load("./out_test_white_200/checkpoints_000.pth")["ldm_decoder"]
unexpected_keys = ldm_aef.load_state_dict(state_dict, strict=False)
print(unexpected_keys)
print("you should check that the decoder keys are correctly matched")

model = "stabilityai/stable-diffusion-2"
pipe = StableDiffusionPipeline.from_pretrained(model).to(device)

prompts = ["Professional picture of fishing kitten"] * 50
seeds = [random.randint(0, 2**32 - 1) for _ in range(len(prompts))]

# replace the diffusers VAE decoder with the fine-tuned one
pipe.vae.decode = (lambda x, *args, **kwargs: ldm_aef.decode(x).unsqueeze(0))

for i, (prompt, seed) in enumerate(zip(prompts, seeds)):
    generator = torch.manual_seed(seed)
    img = pipe(prompt, generator=generator).images[0]
    img.save(f'./with_watermark_256/{i}.png')
```
The decoding code is exactly the same as in decoding.ipynb.
Sorry, I met the same problem. How did you solve it in the end? @fenghe12
Sorry if I'm not super active... Could you share your logs if possible? Some hypotheses are (1) the weights of the model are not properly loaded (2) the watermark message that is hidden at fine-tuning time is not the one you compute the bit accuracy on (3) a mismatch between the watermark extractor used during fine-tuning and the one used at evaluation time.
During the training phase, there seemed to be no errors, but the results I decoded during the testing phase were almost completely incorrect. @pierrefdz
git:sha: 8958dc7e82164e4a3b525d4f51e5df04e06fa9ff, status: has uncommited changes, branch: main
log:{"train_dir": "./data/train", "val_dir": "./data/val", "ldm_config": "./stabilityai/stable-diffusion-2-1-base/v2-inference.yaml", "ldm_ckpt": "./stabilityai/stable-diffusion-2-1-base/v2-1_512-ema-pruned.ckpt", "msg_decoder_path": "./models/dec_48b_whit.torchscript.pt", "num_bits": 48, "redundancy": 1, "decoder_depth": 8, "decoder_channels": 64, "batch_size": 4, "img_size": 256, "loss_i": "watson-vgg", "loss_w": "bce", "lambda_i": 0.2, "lambda_w": 1.0, "optimizer": "AdamW,lr=5e-4", "steps": 100, "warmup_steps": 20, "log_freq": 10, "save_img_freq": 1000, "num_keys": 1, "output_dir": "output/", "seed": 0, "debug": false}
Building LDM model with config ./stabilityai/stable-diffusion-2-1-base/v2-inference.yaml and weights from ./stabilityai/stable-diffusion-2-1-base/v2-1_512-ema-pruned.ckpt...
Loading model from ./stabilityai/stable-diffusion-2-1-base/v2-1_512-ema-pruned.ckpt
Global Step: 220000
LatentDiffusion: Running in eps-prediction mode
DiffusionWrapper has 865.91 M params.
making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 4, 32, 32) = 4096 dimensions.
making attention of type 'vanilla' with 512 in_channels
Building hidden decoder with weights from ./models/dec_48b_whit.torchscript.pt...
Loading data from ./data/train and ./data/val...
Creating losses...
Losses: bce and watson-vgg...
Creating key with 48 bits... Key: 111010110101000001010111010011010100010000100111 Training... {"iteration": 0, "loss": 0.7272339463233948, "loss_w": 0.7272318601608276, "loss_i": 1.0295131687598769e-05, "psnr": Infinity, "bit_acc_avg": 0.515625, "word_acc_avg": 0.0, "lr": 0.0} Train [ 0/100] eta: 0:02:30 iteration: 0.000000 (0.000000) loss: 0.727234 (0.727234) loss_w: 0.727232 (0.727232) loss_i: 0.000010 (0.000010) psnr: inf (inf) bit_acc_avg: 0.515625 (0.515625) word_acc_avg: 0.000000 (0.000000) lr: 0.000000 (0.000000) time: 1.500002 data: 0.438819 max mem: 10936 {"iteration": 10, "loss": 0.5056884288787842, "loss_w": 0.15685945749282837, "loss_i": 1.7441446781158447, "psnr": 34.489280700683594, "bit_acc_avg": 0.96875, "word_acc_avg": 0.0, "lr": 0.00025} Train [ 10/100] eta: 0:00:58 iteration: 5.000000 (5.000000) loss: 0.604755 (0.636774) loss_w: 0.462247 (0.489823) loss_i: 0.712543 (0.734757) psnr: 43.147713 (inf) bit_acc_avg: 0.817708 (0.754261) word_acc_avg: 0.000000 (0.000000) lr: 0.000125 (0.000125) time: 0.652794 data: 0.039989 max mem: 11664 {"iteration": 20, "loss": 0.5863916873931885, "loss_w": 0.0927756056189537, "loss_i": 2.468080520629883, "psnr": 30.310420989990234, "bit_acc_avg": 0.984375, "word_acc_avg": 0.75, "lr": 0.0005} Train [ 20/100] eta: 0:00:46 iteration: 10.000000 (10.000000) loss: 0.575908 (0.603986) loss_w: 0.228699 (0.323160) loss_i: 1.744145 (1.404130) psnr: 33.432957 (inf) bit_acc_avg: 0.932292 (0.858383) word_acc_avg: 0.000000 (0.202381) lr: 0.000250 (0.000250) time: 0.537721 data: 0.000102 max mem: 11664 {"iteration": 30, "loss": 0.5741378664970398, "loss_w": 0.0685807317495346, "loss_i": 2.527785539627075, "psnr": 30.195219039916992, "bit_acc_avg": 0.9947916865348816, "word_acc_avg": 0.75, "lr": 0.00048100794336156604} Train [ 30/100] eta: 0:00:39 iteration: 20.000000 (15.000000) loss: 0.570326 (0.593232) loss_w: 0.090491 (0.242183) loss_i: 2.362381 (1.755245) psnr: 30.266314 (inf) bit_acc_avg: 0.989583 (0.901714) word_acc_avg: 0.750000 (0.379032) lr: 0.000481 (0.000328) time: 0.507474 data: 0.000098 max mem: 11664 {"iteration": 40, "loss": 0.5790627002716064, "loss_w": 0.018820516765117645, "loss_i": 2.801210880279541, "psnr": 27.312376022338867, "bit_acc_avg": 1.0, "word_acc_avg": 1.0, "lr": 0.0004269231419060436} Train [ 40/100] eta: 0:00:32 iteration: 30.000000 (20.000000) loss: 0.574138 (0.592107) loss_w: 0.065551 (0.202381) loss_i: 2.500486 (1.948630) psnr: 29.710548 (inf) bit_acc_avg: 0.994792 (0.922891) word_acc_avg: 0.750000 (0.481707) lr: 0.000477 (0.000359) time: 0.508035 data: 0.000098 max mem: 11664 {"iteration": 50, "loss": 0.5642515420913696, "loss_w": 0.06884497404098511, "loss_i": 2.4770328998565674, "psnr": 29.902061462402344, "bit_acc_avg": 1.0, "word_acc_avg": 1.0, "lr": 0.00034597951637508993} Train [ 50/100] eta: 0:00:26 iteration: 40.000000 (25.000000) loss: 0.564252 (0.584697) loss_w: 0.060810 (0.172872) loss_i: 2.478716 (2.059127) psnr: 29.392599 (inf) bit_acc_avg: 0.994792 (0.937398) word_acc_avg: 0.750000 (0.553922) lr: 0.000420 (0.000364) time: 0.508685 data: 0.000103 max mem: 11664 {"iteration": 60, "loss": 0.5306546688079834, "loss_w": 0.036135606467723846, "loss_i": 2.47259521484375, "psnr": 29.56548500061035, "bit_acc_avg": 0.9895833730697632, "word_acc_avg": 0.75, "lr": 0.0002505} Train [ 60/100] eta: 0:00:21 iteration: 50.000000 (30.000000) loss: 0.550219 (0.578414) loss_w: 0.047642 (0.154611) loss_i: 2.477033 (2.119012) psnr: 29.838663 (inf) bit_acc_avg: 0.994792 (0.946380) word_acc_avg: 0.750000 (0.598361) lr: 0.000337 
(0.000352) time: 0.509140 data: 0.000102 max mem: 11664 {"iteration": 70, "loss": 0.4845236539840698, "loss_w": 0.05966611206531525, "loss_i": 2.1242876052856445, "psnr": 31.549930572509766, "bit_acc_avg": 1.0, "word_acc_avg": 1.0, "lr": 0.0001550204836249101} Train [ 70/100] eta: 0:00:15 iteration: 60.000000 (35.000000) loss: 0.525561 (0.570636) loss_w: 0.043449 (0.139244) loss_i: 2.408679 (2.156957) psnr: 29.971382 (inf) bit_acc_avg: 1.000000 (0.953859) word_acc_avg: 1.000000 (0.651408) lr: 0.000241 (0.000331) time: 0.508935 data: 0.000096 max mem: 11664 {"iteration": 80, "loss": 0.4934779107570648, "loss_w": 0.02895699068903923, "loss_i": 2.3226046562194824, "psnr": 29.44066047668457, "bit_acc_avg": 0.9947916865348816, "word_acc_avg": 0.75, "lr": 7.40768580939564e-05} Train [ 80/100] eta: 0:00:10 iteration: 70.000000 (40.000000) loss: 0.493478 (0.559191) loss_w: 0.032891 (0.127513) loss_i: 2.172009 (2.158390) psnr: 29.971382 (inf) bit_acc_avg: 1.000000 (0.959105) word_acc_avg: 1.000000 (0.682099) lr: 0.000146 (0.000303) time: 0.508779 data: 0.000095 max mem: 11664 {"iteration": 90, "loss": 0.42957162857055664, "loss_w": 0.05145969241857529, "loss_i": 1.8905595541000366, "psnr": 31.834016799926758, "bit_acc_avg": 1.0, "word_acc_avg": 1.0, "lr": 1.9992056638433958e-05} Train [ 90/100] eta: 0:00:05 iteration: 80.000000 (45.000000) loss: 0.452601 (0.546342) loss_w: 0.041486 (0.118391) loss_i: 2.026620 (2.139756) psnr: 30.715746 (inf) bit_acc_avg: 1.000000 (0.963427) word_acc_avg: 1.000000 (0.708791) lr: 0.000067 (0.000274) time: 0.509235 data: 0.000096 max mem: 11664 Train [ 99/100] eta: 0:00:00 iteration: 89.000000 (49.500000) loss: 0.444447 (0.537739) loss_w: 0.041486 (0.112332) loss_i: 2.015063 (2.127031) psnr: 30.794491 (inf) bit_acc_avg: 1.000000 (0.966563) word_acc_avg: 1.000000 (0.727500) lr: 0.000020 (0.000250) time: 0.509741 data: 0.000095 max mem: 11664 Train Total time: 0:00:52 (0.520741 s / it) Averaged train stats: iteration: 89.000000 (49.500000) loss: 0.444447 (0.537739) loss_w: 0.041486 (0.112332) loss_i: 2.015063 (2.127031) psnr: 30.794491 (inf) bit_acc_avg: 1.000000 (0.966563) word_acc_avg: 1.000000 (0.727500) lr: 0.000020 (0.000250) torch.Size([16, 3, 256, 256]) Eval [0/7] eta: 0:00:23 iteration: 0.000000 (0.000000) psnr: 30.536995 (30.536995) bit_acc_none: 0.996094 (0.996094) word_acc_none: 0.812500 (0.812500) bit_acc_crop_01: 0.936198 (0.936198) word_acc_crop_01: 0.375000 (0.375000) bit_acc_crop_05: 0.990885 (0.990885) word_acc_crop_05: 0.812500 (0.812500) bit_acc_rot_25: 0.656250 (0.656250) word_acc_rot_25: 0.000000 (0.000000) bit_acc_rot_90: 0.483073 (0.483073) word_acc_rot_90: 0.000000 (0.000000) bit_acc_resize_03: 0.744792 (0.744792) word_acc_resize_03: 0.000000 (0.000000) bit_acc_resize_07: 0.993490 (0.993490) word_acc_resize_07: 0.812500 (0.812500) bit_acc_brightness_1p5: 0.994792 (0.994792) word_acc_brightness_1p5: 0.812500 (0.812500) bit_acc_brightness_2: 0.981771 (0.981771) word_acc_brightness_2: 0.375000 (0.375000) bit_acc_jpeg_80: 0.908854 (0.908854) word_acc_jpeg_80: 0.000000 (0.000000) bit_acc_jpeg_50: 0.861979 (0.861979) word_acc_jpeg_50: 0.000000 (0.000000) time: 3.353976 data: 0.500687 max mem: 11664
Eval [6/7] eta: 0:00:01 iteration: 3.000000 (3.000000) psnr: 30.536995 (30.408210) bit_acc_none: 0.998698 (0.998512) word_acc_none: 0.937500 (0.937500) bit_acc_crop_01: 0.940104 (0.939360) word_acc_crop_01: 0.250000 (0.232143) bit_acc_crop_05: 0.994792 (0.993862) word_acc_crop_05: 0.812500 (0.821429) bit_acc_rot_25: 0.652344 (0.653460) word_acc_rot_25: 0.000000 (0.000000) bit_acc_rot_90: 0.479167 (0.481213) word_acc_rot_90: 0.000000 (0.000000) bit_acc_resize_03: 0.746094 (0.745350) word_acc_resize_03: 0.000000 (0.000000) bit_acc_resize_07: 0.992188 (0.991257) word_acc_resize_07: 0.750000 (0.705357) bit_acc_brightness_1p5: 0.994792 (0.992188) word_acc_brightness_1p5: 0.750000 (0.723214) bit_acc_brightness_2: 0.976562 (0.974144) word_acc_brightness_2: 0.250000 (0.294643) bit_acc_jpeg_80: 0.923177 (0.924293) word_acc_jpeg_80: 0.000000 (0.008929) bit_acc_jpeg_50: 0.861979 (0.870908) word_acc_jpeg_50: 0.000000 (0.000000) time: 1.366670 data: 0.071614 max mem: 11664 Eval Total time: 0:00:09 (1.219035 s / it) Averaged eval stats: iteration: 3.000000 (3.000000) psnr: 30.536995 (30.408210) bit_acc_none: 0.998698 (0.998512) word_acc_none: 0.937500 (0.937500) bit_acc_crop_01: 0.940104 (0.939360) word_acc_crop_01: 0.250000 (0.232143) bit_acc_crop_05: 0.994792 (0.993862) word_acc_crop_05: 0.812500 (0.821429) bit_acc_rot_25: 0.652344 (0.653460) word_acc_rot_25: 0.000000 (0.000000) bit_acc_rot_90: 0.479167 (0.481213) word_acc_rot_90: 0.000000 (0.000000) bit_acc_resize_03: 0.746094 (0.745350) word_acc_resize_03: 0.000000 (0.000000) bit_acc_resize_07: 0.992188 (0.991257) word_acc_resize_07: 0.750000 (0.705357) bit_acc_brightness_1p5: 0.994792 (0.992188) word_acc_brightness_1p5: 0.750000 (0.723214) bit_acc_brightness_2: 0.976562 (0.974144) word_acc_brightness_2: 0.250000 (0.294643) bit_acc_jpeg_80: 0.923177 (0.924293) word_acc_jpeg_80: 0.000000 (0.008929) bit_acc_jpeg_50: 0.861979 (0.870908) word_acc_jpeg_50: 0.000000 (0.000000)
And when you generate, what does print(unexpected_keys) give?
And, could you try changing

```python
pipe.vae.decode = (lambda x, *args, **kwargs: ldm_aef.decode(x).unsqueeze(0))
```

into

```python
pipe.vae.decode = (lambda x, *args, **kwargs: (
    print("Entering vae.decode"),
    ldm_aef.decode(x).unsqueeze(0)
)[-1])
```

to make sure that the new decoding is actually used.
PS: or alternatively, which is easier to read:

```python
def vae_decode(x, *args, **kwargs):
    print("Entering vae.decode")
    return ldm_aef.decode(x).unsqueeze(0)

pipe.vae.decode = vae_decode
```
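(A further variant of the same check, just a sketch that reuses the variables from the generation script above and introduces a throwaway `calls` counter: count how often the patched decoder is actually hit during a generation.)

```python
# Count how many times the patched decoder is called; it should be hit once per pipe(...) call.
calls = {"n": 0}

def counting_vae_decode(x, *args, **kwargs):
    calls["n"] += 1
    return ldm_aef.decode(x).unsqueeze(0)

pipe.vae.decode = counting_vae_decode
img = pipe(prompts[0]).images[0]  # run one generation
print(f"vae.decode was called {calls['n']} times")
```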
_IncompatibleKeys(missing_keys=['encoder.conv_in.weight', 'encoder.conv_in.bias', 'encoder.down.0.block.0.norm1.weight', 'encoder.down.0.block.0.norm1.bias', 'encoder.down.0.block.0.conv1.weight', 'encoder.down.0.block.0.conv1.bias', 'encoder.down.0.block.0.norm2.weight', 'encoder.down.0.block.0.norm2.bias', 'encoder.down.0.block.0.conv2.weight', 'encoder.down.0.block.0.conv2.bias', 'encoder.down.0.block.1.norm1.weight', 'encoder.down.0.block.1.norm1.bias', 'encoder.down.0.block.1.conv1.weight', 'encoder.down.0.block.1.conv1.bias', 'encoder.down.0.block.1.norm2.weight', 'encoder.down.0.block.1.norm2.bias', 'encoder.down.0.block.1.conv2.weight', 'encoder.down.0.block.1.conv2.bias', 'encoder.down.0.downsample.conv.weight', 'encoder.down.0.downsample.conv.bias', 'encoder.down.1.block.0.norm1.weight', 'encoder.down.1.block.0.norm1.bias', 'encoder.down.1.block.0.conv1.weight', 'encoder.down.1.block.0.conv1.bias', 'encoder.down.1.block.0.norm2.weight', 'encoder.down.1.block.0.norm2.bias', 'encoder.down.1.block.0.conv2.weight', 'encoder.down.1.block.0.conv2.bias', 'encoder.down.1.block.0.nin_shortcut.weight', 'encoder.down.1.block.0.nin_shortcut.bias', 'encoder.down.1.block.1.norm1.weight', 'encoder.down.1.block.1.norm1.bias', 'encoder.down.1.block.1.conv1.weight', 'encoder.down.1.block.1.conv1.bias', 'encoder.down.1.block.1.norm2.weight', 'encoder.down.1.block.1.norm2.bias', 'encoder.down.1.block.1.conv2.weight', 'encoder.down.1.block.1.conv2.bias', 'encoder.down.1.downsample.conv.weight', 'encoder.down.1.downsample.conv.bias', 'encoder.down.2.block.0.norm1.weight', 'encoder.down.2.block.0.norm1.bias', 'encoder.down.2.block.0.conv1.weight', 'encoder.down.2.block.0.conv1.bias', 'encoder.down.2.block.0.norm2.weight', 'encoder.down.2.block.0.norm2.bias', 'encoder.down.2.block.0.conv2.weight', 'encoder.down.2.block.0.conv2.bias', 'encoder.down.2.block.0.nin_shortcut.weight', 'encoder.down.2.block.0.nin_shortcut.bias', 'encoder.down.2.block.1.norm1.weight', 'encoder.down.2.block.1.norm1.bias', 'encoder.down.2.block.1.conv1.weight', 'encoder.down.2.block.1.conv1.bias', 'encoder.down.2.block.1.norm2.weight', 'encoder.down.2.block.1.norm2.bias', 'encoder.down.2.block.1.conv2.weight', 'encoder.down.2.block.1.conv2.bias', 'encoder.down.2.downsample.conv.weight', 'encoder.down.2.downsample.conv.bias', 'encoder.down.3.block.0.norm1.weight', 'encoder.down.3.block.0.norm1.bias', 'encoder.down.3.block.0.conv1.weight', 'encoder.down.3.block.0.conv1.bias', 'encoder.down.3.block.0.norm2.weight', 'encoder.down.3.block.0.norm2.bias', 'encoder.down.3.block.0.conv2.weight', 'encoder.down.3.block.0.conv2.bias', 'encoder.down.3.block.1.norm1.weight', 'encoder.down.3.block.1.norm1.bias', 'encoder.down.3.block.1.conv1.weight', 'encoder.down.3.block.1.conv1.bias', 'encoder.down.3.block.1.norm2.weight', 'encoder.down.3.block.1.norm2.bias', 'encoder.down.3.block.1.conv2.weight', 'encoder.down.3.block.1.conv2.bias', 'encoder.mid.block_1.norm1.weight', 'encoder.mid.block_1.norm1.bias', 'encoder.mid.block_1.conv1.weight', 'encoder.mid.block_1.conv1.bias', 'encoder.mid.block_1.norm2.weight', 'encoder.mid.block_1.norm2.bias', 'encoder.mid.block_1.conv2.weight', 'encoder.mid.block_1.conv2.bias', 'encoder.mid.attn_1.norm.weight', 'encoder.mid.attn_1.norm.bias', 'encoder.mid.attn_1.q.weight', 'encoder.mid.attn_1.q.bias', 'encoder.mid.attn_1.k.weight', 'encoder.mid.attn_1.k.bias', 'encoder.mid.attn_1.v.weight', 'encoder.mid.attn_1.v.bias', 'encoder.mid.attn_1.proj_out.weight', 'encoder.mid.attn_1.proj_out.bias', 
'encoder.mid.block_2.norm1.weight', 'encoder.mid.block_2.norm1.bias', 'encoder.mid.block_2.conv1.weight', 'encoder.mid.block_2.conv1.bias', 'encoder.mid.block_2.norm2.weight', 'encoder.mid.block_2.norm2.bias', 'encoder.mid.block_2.conv2.weight', 'encoder.mid.block_2.conv2.bias', 'encoder.norm_out.weight', 'encoder.norm_out.bias', 'encoder.conv_out.weight', 'encoder.conv_out.bias', 'decoder.conv_in.weight', 'decoder.conv_in.bias', 'decoder.mid.block_1.norm1.weight', 'decoder.mid.block_1.norm1.bias', 'decoder.mid.block_1.conv1.weight', 'decoder.mid.block_1.conv1.bias', 'decoder.mid.block_1.norm2.weight', 'decoder.mid.block_1.norm2.bias', 'decoder.mid.block_1.conv2.weight', 'decoder.mid.block_1.conv2.bias', 'decoder.mid.attn_1.norm.weight', 'decoder.mid.attn_1.norm.bias', 'decoder.mid.attn_1.q.weight', 'decoder.mid.attn_1.q.bias', 'decoder.mid.attn_1.k.weight', 'decoder.mid.attn_1.k.bias', 'decoder.mid.attn_1.v.weight', 'decoder.mid.attn_1.v.bias', 'decoder.mid.attn_1.proj_out.weight', 'decoder.mid.attn_1.proj_out.bias', 'decoder.mid.block_2.norm1.weight', 'decoder.mid.block_2.norm1.bias', 'decoder.mid.block_2.conv1.weight', 'decoder.mid.block_2.conv1.bias', 'decoder.mid.block_2.norm2.weight', 'decoder.mid.block_2.norm2.bias', 'decoder.mid.block_2.conv2.weight', 'decoder.mid.block_2.conv2.bias', 'decoder.up.0.block.0.norm1.weight', 'decoder.up.0.block.0.norm1.bias', 'decoder.up.0.block.0.conv1.weight', 'decoder.up.0.block.0.conv1.bias', 'decoder.up.0.block.0.norm2.weight', 'decoder.up.0.block.0.norm2.bias', 'decoder.up.0.block.0.conv2.weight', 'decoder.up.0.block.0.conv2.bias', 'decoder.up.0.block.0.nin_shortcut.weight', 'decoder.up.0.block.0.nin_shortcut.bias', 'decoder.up.0.block.1.norm1.weight', 'decoder.up.0.block.1.norm1.bias', 'decoder.up.0.block.1.conv1.weight', 'decoder.up.0.block.1.conv1.bias', 'decoder.up.0.block.1.norm2.weight', 'decoder.up.0.block.1.norm2.bias', 'decoder.up.0.block.1.conv2.weight', 'decoder.up.0.block.1.conv2.bias', 'decoder.up.0.block.2.norm1.weight', 'decoder.up.0.block.2.norm1.bias', 'decoder.up.0.block.2.conv1.weight', 'decoder.up.0.block.2.conv1.bias', 'decoder.up.0.block.2.norm2.weight', 'decoder.up.0.block.2.norm2.bias', 'decoder.up.0.block.2.conv2.weight', 'decoder.up.0.block.2.conv2.bias', 'decoder.up.1.block.0.norm1.weight', 'decoder.up.1.block.0.norm1.bias', 'decoder.up.1.block.0.conv1.weight', 'decoder.up.1.block.0.conv1.bias', 'decoder.up.1.block.0.norm2.weight', 'decoder.up.1.block.0.norm2.bias', 'decoder.up.1.block.0.conv2.weight', 'decoder.up.1.block.0.conv2.bias', 'decoder.up.1.block.0.nin_shortcut.weight', 'decoder.up.1.block.0.nin_shortcut.bias', 'decoder.up.1.block.1.norm1.weight', 'decoder.up.1.block.1.norm1.bias', 'decoder.up.1.block.1.conv1.weight', 'decoder.up.1.block.1.conv1.bias', 'decoder.up.1.block.1.norm2.weight', 'decoder.up.1.block.1.norm2.bias', 'decoder.up.1.block.1.conv2.weight', 'decoder.up.1.block.1.conv2.bias', 'decoder.up.1.block.2.norm1.weight', 'decoder.up.1.block.2.norm1.bias', 'decoder.up.1.block.2.conv1.weight', 'decoder.up.1.block.2.conv1.bias', 'decoder.up.1.block.2.norm2.weight', 'decoder.up.1.block.2.norm2.bias', 'decoder.up.1.block.2.conv2.weight', 'decoder.up.1.block.2.conv2.bias', 'decoder.up.1.upsample.conv.weight', 'decoder.up.1.upsample.conv.bias', 'decoder.up.2.block.0.norm1.weight', 'decoder.up.2.block.0.norm1.bias', 'decoder.up.2.block.0.conv1.weight', 'decoder.up.2.block.0.conv1.bias', 'decoder.up.2.block.0.norm2.weight', 'decoder.up.2.block.0.norm2.bias', 'decoder.up.2.block.0.conv2.weight', 
'decoder.up.2.block.0.conv2.bias', 'decoder.up.2.block.1.norm1.weight', 'decoder.up.2.block.1.norm1.bias', 'decoder.up.2.block.1.conv1.weight', 'decoder.up.2.block.1.conv1.bias', 'decoder.up.2.block.1.norm2.weight', 'decoder.up.2.block.1.norm2.bias', 'decoder.up.2.block.1.conv2.weight', 'decoder.up.2.block.1.conv2.bias', 'decoder.up.2.block.2.norm1.weight', 'decoder.up.2.block.2.norm1.bias', 'decoder.up.2.block.2.conv1.weight', 'decoder.up.2.block.2.conv1.bias', 'decoder.up.2.block.2.norm2.weight', 'decoder.up.2.block.2.norm2.bias', 'decoder.up.2.block.2.conv2.weight', 'decoder.up.2.block.2.conv2.bias', 'decoder.up.2.upsample.conv.weight', 'decoder.up.2.upsample.conv.bias', 'decoder.up.3.block.0.norm1.weight', 'decoder.up.3.block.0.norm1.bias', 'decoder.up.3.block.0.conv1.weight', 'decoder.up.3.block.0.conv1.bias', 'decoder.up.3.block.0.norm2.weight', 'decoder.up.3.block.0.norm2.bias', 'decoder.up.3.block.0.conv2.weight', 'decoder.up.3.block.0.conv2.bias', 'decoder.up.3.block.1.norm1.weight', 'decoder.up.3.block.1.norm1.bias', 'decoder.up.3.block.1.conv1.weight', 'decoder.up.3.block.1.conv1.bias', 'decoder.up.3.block.1.norm2.weight', 'decoder.up.3.block.1.norm2.bias', 'decoder.up.3.block.1.conv2.weight', 'decoder.up.3.block.1.conv2.bias', 'decoder.up.3.block.2.norm1.weight', 'decoder.up.3.block.2.norm1.bias', 'decoder.up.3.block.2.conv1.weight', 'decoder.up.3.block.2.conv1.bias', 'decoder.up.3.block.2.norm2.weight', 'decoder.up.3.block.2.norm2.bias', 'decoder.up.3.block.2.conv2.weight', 'decoder.up.3.block.2.conv2.bias', 'decoder.up.3.upsample.conv.weight', 'decoder.up.3.upsample.conv.bias', 'decoder.norm_out.weight', 'decoder.norm_out.bias', 'decoder.conv_out.weight', 'decoder.conv_out.bias', 'quant_conv.weight', 'quant_conv.bias', 'post_quant_conv.weight', 'post_quant_conv.bias'], unexpected_keys=['ldm_decoder', 'optimizer', 'params']) you should check that the decoder keys are correctly matched
So in your case it's (1). You need to make sure that no "decoder.*" keys are reported as missing. The unexpected keys are the top-level keys of your checkpoint (it still contains `ldm_decoder`, `optimizer` and `params`), so you should index it first: `unexpected_keys = ldm_aef.load_state_dict(state_dict["ldm_decoder"], strict=False)`
See https://github.com/facebookresearch/stable_signature?tab=readme-ov-file#with-stability-ai-codebase
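Putting that together, a minimal sketch (the path and variable names just follow the script shared earlier in this thread):

```python
import torch

# Load the training checkpoint, take the fine-tuned decoder weights from "ldm_decoder",
# and check that they actually overwrite the decoder of the latent autoencoder.
ckpt = torch.load("./out_test_white_200/checkpoints_000.pth", map_location="cpu")
state_dict = ckpt["ldm_decoder"]

before = ldm_aef.decoder.conv_out.weight.detach().clone()
info = ldm_aef.load_state_dict(state_dict, strict=False)
after = ldm_aef.decoder.conv_out.weight.detach()

# No "decoder.*" key should be reported as missing, and the weights should have changed.
print("missing decoder.* keys:", [k for k in info.missing_keys if k.startswith("decoder.")])
print("decoder weights updated:", not torch.equal(before, after))
```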
Thank you very much for your help! I solved my problem. Maybe you could change the code in "Generate with Diffusers": `state_dict = torch.load("sd2_decoder.pth")` -> `state_dict = torch.load("sd2_decoder.pth")['ldm_decoder']`
If you do it with the weights I provided, there is no need to do this, since the `state_dict` is the `ldm_decoder` directly.
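In other words, the loading code can handle both layouts (a small sketch; `ckpt_path` is a placeholder for whichever file you use):

```python
import torch

# sd2_decoder.pth is already a plain state_dict, whereas a checkpoint saved during
# fine-tuning nests the weights under "ldm_decoder" (next to "optimizer" and "params").
ckpt = torch.load(ckpt_path, map_location="cpu")
state_dict = ckpt["ldm_decoder"] if "ldm_decoder" in ckpt else ckpt
print(ldm_aef.load_state_dict(state_dict, strict=False))
```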
I'll add something in the readme! Thx for the follow-up on this!
Sorry for the late reply :( I have also tried out the solution and solved the problem. Thank you all for your solutions and follow-up!
Hi there! I have tried the latent decoder weights you provided here (WM weights of latent decoder) and generated an image using the code provided in README.md. Then I used this image to extract the message with decoding.ipynb; however, it turns out that it cannot be extracted correctly, and the bit accuracy is only about 50% to 60%. I am wondering whether there is anything wrong with my usage? Thanks a lot!
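(For anyone debugging the same thing, a minimal decoding sanity check; just a sketch: it assumes the whitened TorchScript extractor and ImageNet normalization as in decoding.ipynb, uses a placeholder image path, and takes the 48-bit key from the training log above as an example, so replace it with the key that goes with the weights you actually use.)

```python
import torch
from PIL import Image
from torchvision import transforms

device = torch.device("cuda")
msg_extractor = torch.jit.load("models/dec_48b_whit.torchscript.pt").to(device)

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

key = "111010110101000001010111010011010100010000100111"  # example key, see training log above
key_bits = torch.tensor([int(b) for b in key], dtype=torch.float32, device=device)

img = transform(Image.open("generated.png").convert("RGB")).unsqueeze(0).to(device)  # placeholder path
bits = (msg_extractor(img) > 0).float().squeeze(0)  # 48 extracted bits
print("bit accuracy:", (bits == key_bits).float().mean().item())
```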