CompVis / latent-diffusion

High-Resolution Image Synthesis with Latent Diffusion Models
MIT License

The problem with the checkpoint finetuning. #24

Open fellow-tom opened 2 years ago

fellow-tom commented 2 years ago

Hi. First of all, thank you for this wonderful repository. I am trying to run training and have the following problem:

I downloaded a small part of the ImageNet dataset (2 GB) and unzipped it. It contained only images, so I had to change "./ldm/data/imagenet.py" a bit to be able to load my dataset. The output gives example["image"] and example["LR_image"] as required.
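Roughly, what I ended up with looks like this (class name, file extension, and normalization are just illustrative; the repo's ImageNetSR class does more preprocessing):

```python
import glob
import numpy as np
from PIL import Image
from torch.utils.data import Dataset

class MyImageFolderSR(Dataset):
    """Minimal sketch: returns the HR image and a downsampled LR version,
    both as float arrays in [-1, 1], as the bsr_sr config expects."""
    def __init__(self, root, size=256, downscale_f=4):
        self.paths = sorted(glob.glob(f"{root}/**/*.JPEG", recursive=True))
        self.size = size
        self.lr_size = size // downscale_f

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, i):
        img = Image.open(self.paths[i]).convert("RGB")
        hr = img.resize((self.size, self.size), Image.BICUBIC)
        lr = hr.resize((self.lr_size, self.lr_size), Image.BICUBIC)
        to_arr = lambda x: (np.array(x).astype(np.float32) / 127.5) - 1.0
        return {"image": to_arr(hr), "LR_image": to_arr(lr)}
```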

Then I changed a few lines in "./models/ldm/bsr_sr/config.yaml": in the train and validation sections I pointed target to my dataset class in imagenet.py.

Then I downloaded your ckpt file referenced in notebook_helpers.py and decided to try to finetune the weights:

```bash
CUDA_VISIBLE_DEVICES=0 python main.py \
    --base "./models/ldm/bsr_sr/config.yaml" \
    --name "test" \
    --resume_from_checkpoint "./logs/diffusion/superresolution_bsr/last.yaml/?dl=1" \
    -t --gpus=0
```

But I got an error:

RuntimeError: Error(s) in loading state_dict for LatentDiffusion: Unexpected key(s) in state_dict: "ddim_sigmas", "ddim_alphas", "ddim_alphas_prev", "ddim_sqrt_one_minus_alphas".

If I load the weights, delete those 4 keys, and write them to a new file, the training starts fine. Do I understand correctly that without them the training will not work well? If I start training from scratch, the resulting checkpoints do not contain these 4 keys at all. Can you tell me what I'm doing wrong?
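For reference, this is roughly how I inspected the checkpoint (the path is a placeholder); as far as I can tell, these four entries look like DDIM sampling-schedule buffers rather than trainable weights:

```python
import torch

# load the pretrained checkpoint on CPU and look at its keys (path is a placeholder)
ckpt = torch.load("last.ckpt", map_location="cpu")
sd = ckpt["state_dict"]

# the four entries the error complains about all share the ddim_ prefix
print([k for k in sd if k.startswith("ddim_")])
# prints: ['ddim_sigmas', 'ddim_alphas', 'ddim_alphas_prev', 'ddim_sqrt_one_minus_alphas']

# when loading manually, strict=False would simply ignore these extra keys:
#   missing, unexpected = model.load_state_dict(sd, strict=False)
```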


And another small question: I trained the autoencoder (first_stage_models) separately and got a checkpoint, but I can't find where to specify it when training the diffusion model (ldm). Perhaps the autoencoder is not involved in this step, but then where do I specify it if I want to run inference with my weights?

sangyun884 commented 2 years ago

Did you resolve the issue?

gigadeplex commented 2 years ago

did you solve this?

samuelesabella commented 1 year ago

Hi. Any news?

SantiUsma commented 1 year ago

Hi everyone. I had the same problem, so I did something a bit hacky. The error says that "ddim_sigmas", "ddim_alphas", "ddim_alphas_prev", and "ddim_sqrt_one_minus_alphas" are unexpected keys in the checkpoint, so they are not needed for training. I used the following code to delete these keys from the state dict:

```python
import torch

# load the pretrained checkpoint on CPU
dictionary = torch.load("model.ckpt", map_location='cpu')

new_dict = {}
skip = ('ddim_sigmas', 'ddim_alphas', 'ddim_alphas_prev', 'ddim_sqrt_one_minus_alphas')

# copy every parameter except the four ddim_* buffers and the conditioning model
for params in dictionary['state_dict'].keys():
    if params not in skip and 'cond_stage_model' not in params:
        new_dict[params] = dictionary['state_dict'][params]
```

If you also want to swap the conditioning keys of the pretrained model for a personal conditional model with random weights, you can additionally run this optional block:

```python
# optional: replace the conditioning model with a randomly initialized one
model = torch.nn.Linear(2048, 640)

for param in model.parameters():
    param.data = torch.randn(param.data.size())

s_d = model.state_dict()
for params in s_d.keys():
    new_dict['cond_stage_model.model.' + params] = s_d[params]
```

```python
# write the filtered state dict back and save it as a new checkpoint
dictionary['state_dict'] = new_dict
torch.save(dictionary, "new_model.ckpt")
```

This code copies all the parameters of the checkpoint except those that cause the problem. I hope this helps you!

SantiUsma commented 1 year ago

And about your first_stage_models question: you can specify the first-stage autoencoder checkpoint path in the yaml config. Under model.params.first_stage_config you add ckpt_path: "path_to_model.ckpt". Something like this:

```yaml
first_stage_config:
  target: ldm.models.autoencoder.VQModelInterface
  params:
    embed_dim: 4
    n_embed: 16384
    ckpt_path: "/media/SSD3/asusma/Tesis/latent-diffusion/logs/VQ8/pretrained/model.ckpt"
    ddconfig:
      ...
```
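If you want to double-check that the autoencoder weights were actually picked up, a quick sanity check along these lines should work (paths and the key name are placeholders; it compares one tensor from the standalone autoencoder checkpoint with the same tensor inside a saved LatentDiffusion checkpoint):

```python
import torch

# standalone first-stage checkpoint and a checkpoint written during LDM training (placeholder paths)
ae_sd = torch.load("autoencoder.ckpt", map_location="cpu")["state_dict"]
ldm_sd = torch.load("logs/.../checkpoints/last.ckpt", map_location="cpu")["state_dict"]

key = "encoder.conv_in.weight"  # any autoencoder parameter name
same = torch.equal(ae_sd[key], ldm_sd["first_stage_model." + key])
print("first-stage weights loaded:", same)
```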

CreamyLong commented 1 year ago

I am finetuning the model on layout2img-openimages256. The model is:

```
  | Name              | Type             | Params
--------------------------------------------------
0 | model             | DiffusionWrapper | 246 M
1 | model_ema         | LitEma           | 0
2 | first_stage_model | VQModelInterface | 55.3 M
3 | cond_stage_model  | BERTEmbedder     | 58.9 M
```

and I got the error: RuntimeError: Error(s) in loading state_dict for Layout2ImgDiffusion: Missing key(s) in state_dict: xxxxxxxxxxxxxxxxxxx

Do you have any idea how I can deal with it? @SantiUsma

SantiUsma commented 1 year ago

Yes! Remove the "and 'cond_stage_model' not in params" condition from the filtering loop above, so the conditioning weights are kept, and try again.
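In other words, the filter becomes something like this (a sketch against the snippet above; the checkpoint path is a placeholder):

```python
import torch

dictionary = torch.load("model.ckpt", map_location='cpu')

# keep everything except the four ddim_* buffers; the cond_stage_model weights stay in
skip = ('ddim_sigmas', 'ddim_alphas', 'ddim_alphas_prev', 'ddim_sqrt_one_minus_alphas')
new_dict = {k: v for k, v in dictionary['state_dict'].items() if k not in skip}

dictionary['state_dict'] = new_dict
torch.save(dictionary, "new_model.ckpt")
```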

omrastogi commented 5 months ago

@SantiUsma I am getting the following error after following this: RuntimeError: Error(s) in loading state_dict for LatentDiffusion: Missing key(s) in state_dict: "cond_stage_model.encoder.conv_in.weight", "cond_stage_model.encoder.conv_in.bias", "cond_stage_model.encoder.down.0.block.0.norm1.weight", "cond_stage_model.encoder.down.0.block.0.norm1.bias", "cond_stage_model.encoder.down.0.block.0.conv1.weight", "cond_stage_model.encoder.down.0.block.0.conv1.bias", "cond_stage_model.encoder.down.0.block.0.norm2.weight", "cond_stage_model.encoder.down.0.block.0.norm2.bias", "cond_stage_model.encoder.down.0.block.0.conv2.weight", "cond_stage_model.encoder.down.0.block.0.conv2.bias", "cond_stage_model.encoder.down.0.block.1.norm1.weight", "cond_stage_model.encoder.down.0.block.1.norm1.bias", "cond_stage_model.encoder.down.0.block.1.conv1.weight", "cond_stage_model.encoder.down.0.block.1.conv1.bias", "cond_stage_model.encoder.down.0.block.1.norm2.weight", "cond_stage_model.encoder.down.0.block.1.norm2.bias", "cond_stage_model.encoder.down.0.block.1.conv2.weight", "cond_stage_model.encoder.down.0.block.1.conv2.bias", "cond_stage_model.encoder.down.0.downsample.conv.weight", "cond_stage_model.encoder.down.0.downsample.conv.bias", "cond_stage_model.encoder.down.1.block.0.norm1.weight", "cond_stage_model.encoder.down.1.block.0.norm1.bias", "cond_stage_model.encoder.down.1.block.0.conv1.weight", "cond_stage_model.encoder.down.1.block.0.conv1.bias", "cond_stage_model.encoder.down.1.block.0.norm2.weight", "cond_stage_model.encoder.down.1.block.0.norm2.bias", "cond_stage_model.encoder.down.1.block.0.conv2.weight", "cond_stage_model.encoder.down.1.block.0.conv2.bias", "cond_stage_model.encoder.down.1.block.0.nin_shortcut.weight", "cond_stage_model.encoder.down.1.block.0.nin_shortcut.bias", "cond_stage_model.encoder.down.1.block.1.norm1.weight", "cond_stage_model.encoder.down.1.block.1.norm1.bias", "cond_stage_model.encoder.down.1.block.1.conv1.weight", "cond_stage_model.encoder.down.1.block.1.conv1.bias", "cond_stage_model.encoder.down.1.block.1.norm2.weight", "cond_stage_model.encoder.down.1.block.1.norm2.bias", "cond_stage_model.encoder.down.1.block.1.conv2.weight", "cond_stage_model.encoder.down.1.block.1.conv2.bias", "cond_stage_model.encoder.down.1.downsample.conv.weight", "cond_stage_model.encoder.down.1.downsample.conv.bias", "cond_stage_model.encoder.down.2.block.0.norm1.weight", "cond_stage_model.encoder.down.2.block.0.norm1.bias", "cond_stage_model.encoder.down.2.block.0.conv1.weight", "cond_stage_model.encoder.down.2.block.0.conv1.bias", "cond_stage_model.encoder.down.2.block.0.norm2.weight", "cond_stage_model.encoder.down.2.block.0.norm2.bias", "cond_stage_model.encoder.down.2.block.0.conv2.weight", "cond_stage_model.encoder.down.2.block.0.conv2.bias", "cond_stage_model.encoder.down.2.block.0.nin_shortcut.weight", "cond_stage_model.encoder.down.2.block.0.nin_shortcut.bias", "cond_stage_model.encoder.down.2.block.1.norm1.weight", "cond_stage_model.encoder.down.2.block.1.norm1.bias", "cond_stage_model.encoder.down.2.block.1.conv1.weight", "cond_stage_model.encoder.down.2.block.1.conv1.bias", "cond_stage_model.encoder.down.2.block.1.norm2.weight", "cond_stage_model.encoder.down.2.block.1.norm2.bias", "cond_stage_model.encoder.down.2.block.1.conv2.weight", "cond_stage_model.encoder.down.2.block.1.conv2.bias", "cond_stage_model.encoder.mid.block_1.norm1.weight", "cond_stage_model.encoder.mid.block_1.norm1.bias", "cond_stage_model.encoder.mid.block_1.conv1.weight", 
"cond_stage_model.encoder.mid.block_1.conv1.bias", "cond_stage_model.encoder.mid.block_1.norm2.weight", "cond_stage_model.encoder.mid.block_1.norm2.bias", "cond_stage_model.encoder.mid.block_1.conv2.weight", "cond_stage_model.encoder.mid.block_1.conv2.bias", "cond_stage_model.encoder.mid.block_2.norm1.weight", "cond_stage_model.encoder.mid.block_2.norm1.bias", "cond_stage_model.encoder.mid.block_2.conv1.weight", "cond_stage_model.encoder.mid.block_2.conv1.bias", "cond_stage_model.encoder.mid.block_2.norm2.weight", "cond_stage_model.encoder.mid.block_2.norm2.bias", "cond_stage_model.encoder.mid.block_2.conv2.weight", "cond_stage_model.encoder.mid.block_2.conv2.bias", "cond_stage_model.encoder.norm_out.weight", "cond_stage_model.encoder.norm_out.bias", "cond_stage_model.encoder.conv_out.weight", "cond_stage_model.encoder.conv_out.bias", "cond_stage_model.decoder.conv_in.weight", "cond_stage_model.decoder.conv_in.bias", "cond_stage_model.decoder.mid.block_1.norm1.weight", "cond_stage_model.decoder.mid.block_1.norm1.bias", "cond_stage_model.decoder.mid.block_1.conv1.weight", "cond_stage_model.decoder.mid.block_1.conv1.bias", "cond_stage_model.decoder.mid.block_1.norm2.weight", "cond_stage_model.decoder.mid.block_1.norm2.bias", "cond_stage_model.decoder.mid.block_1.conv2.weight", "cond_stage_model.decoder.mid.block_1.conv2.bias", "cond_stage_model.decoder.mid.block_2.norm1.weight", "cond_stage_model.decoder.mid.block_2.norm1.bias", "cond_stage_model.decoder.mid.block_2.conv1.weight", "cond_stage_model.decoder.mid.block_2.conv1.bias", "cond_stage_model.decoder.mid.block_2.norm2.weight", "cond_stage_model.decoder.mid.block_2.norm2.bias", "cond_stage_model.decoder.mid.block_2.conv2.weight", "cond_stage_model.decoder.mid.block_2.conv2.bias", "cond_stage_model.decoder.up.0.block.0.norm1.weight", "cond_stage_model.decoder.up.0.block.0.norm1.bias", "cond_stage_model.decoder.up.0.block.0.conv1.weight", "cond_stage_model.decoder.up.0.block.0.conv1.bias", "cond_stage_model.decoder.up.0.block.0.norm2.weight", "cond_stage_model.decoder.up.0.block.0.norm2.bias", "cond_stage_model.decoder.up.0.block.0.conv2.weight", "cond_stage_model.decoder.up.0.block.0.conv2.bias", "cond_stage_model.decoder.up.0.block.0.nin_shortcut.weight", "cond_stage_model.decoder.up.0.block.0.nin_shortcut.bias", "cond_stage_model.decoder.up.0.block.1.norm1.weight", "cond_stage_model.decoder.up.0.block.1.norm1.bias", "cond_stage_model.decoder.up.0.block.1.conv1.weight", "cond_stage_model.decoder.up.0.block.1.conv1.bias", "cond_stage_model.decoder.up.0.block.1.norm2.weight", "cond_stage_model.decoder.up.0.block.1.norm2.bias", "cond_stage_model.decoder.up.0.block.1.conv2.weight", "cond_stage_model.decoder.up.0.block.1.conv2.bias", "cond_stage_model.decoder.up.0.block.2.norm1.weight", "cond_stage_model.decoder.up.0.block.2.norm1.bias", "cond_stage_model.decoder.up.0.block.2.conv1.weight", "cond_stage_model.decoder.up.0.block.2.conv1.bias", "cond_stage_model.decoder.up.0.block.2.norm2.weight", "cond_stage_model.decoder.up.0.block.2.norm2.bias", "cond_stage_model.decoder.up.0.block.2.conv2.weight", "cond_stage_model.decoder.up.0.block.2.conv2.bias", "cond_stage_model.decoder.up.1.block.0.norm1.weight", "cond_stage_model.decoder.up.1.block.0.norm1.bias", "cond_stage_model.decoder.up.1.block.0.conv1.weight", "cond_stage_model.decoder.up.1.block.0.conv1.bias", "cond_stage_model.decoder.up.1.block.0.norm2.weight", "cond_stage_model.decoder.up.1.block.0.norm2.bias", "cond_stage_model.decoder.up.1.block.0.conv2.weight", 
"cond_stage_model.decoder.up.1.block.0.conv2.bias", "cond_stage_model.decoder.up.1.block.0.nin_shortcut.weight", "cond_stage_model.decoder.up.1.block.0.nin_shortcut.bias", "cond_stage_model.decoder.up.1.block.1.norm1.weight", "cond_stage_model.decoder.up.1.block.1.norm1.bias", "cond_stage_model.decoder.up.1.block.1.conv1.weight", "cond_stage_model.decoder.up.1.block.1.conv1.bias", "cond_stage_model.decoder.up.1.block.1.norm2.weight", "cond_stage_model.decoder.up.1.block.1.norm2.bias", "cond_stage_model.decoder.up.1.block.1.conv2.weight", "cond_stage_model.decoder.up.1.block.1.conv2.bias", "cond_stage_model.decoder.up.1.block.2.norm1.weight", "cond_stage_model.decoder.up.1.block.2.norm1.bias", "cond_stage_model.decoder.up.1.block.2.conv1.weight", "cond_stage_model.decoder.up.1.block.2.conv1.bias", "cond_stage_model.decoder.up.1.block.2.norm2.weight", "cond_stage_model.decoder.up.1.block.2.norm2.bias", "cond_stage_model.decoder.up.1.block.2.conv2.weight", "cond_stage_model.decoder.up.1.block.2.conv2.bias", "cond_stage_model.decoder.up.1.upsample.conv.weight", "cond_stage_model.decoder.up.1.upsample.conv.bias", "cond_stage_model.decoder.up.2.block.0.norm1.weight", "cond_stage_model.decoder.up.2.block.0.norm1.bias", "cond_stage_model.decoder.up.2.block.0.conv1.weight", "cond_stage_model.decoder.up.2.block.0.conv1.bias", "cond_stage_model.decoder.up.2.block.0.norm2.weight", "cond_stage_model.decoder.up.2.block.0.norm2.bias", "cond_stage_model.decoder.up.2.block.0.conv2.weight", "cond_stage_model.decoder.up.2.block.0.conv2.bias", "cond_stage_model.decoder.up.2.block.1.norm1.weight", "cond_stage_model.decoder.up.2.block.1.norm1.bias", "cond_stage_model.decoder.up.2.block.1.conv1.weight", "cond_stage_model.decoder.up.2.block.1.conv1.bias", "cond_stage_model.decoder.up.2.block.1.norm2.weight", "cond_stage_model.decoder.up.2.block.1.norm2.bias", "cond_stage_model.decoder.up.2.block.1.conv2.weight", "cond_stage_model.decoder.up.2.block.1.conv2.bias", "cond_stage_model.decoder.up.2.block.2.norm1.weight", "cond_stage_model.decoder.up.2.block.2.norm1.bias", "cond_stage_model.decoder.up.2.block.2.conv1.weight", "cond_stage_model.decoder.up.2.block.2.conv1.bias", "cond_stage_model.decoder.up.2.block.2.norm2.weight", "cond_stage_model.decoder.up.2.block.2.norm2.bias", "cond_stage_model.decoder.up.2.block.2.conv2.weight", "cond_stage_model.decoder.up.2.block.2.conv2.bias", "cond_stage_model.decoder.up.2.upsample.conv.weight", "cond_stage_model.decoder.up.2.upsample.conv.bias", "cond_stage_model.decoder.norm_out.weight", "cond_stage_model.decoder.norm_out.bias", "cond_stage_model.decoder.conv_out.weight", "cond_stage_model.decoder.conv_out.bias", "cond_stage_model.quantize.embedding.weight", "cond_stage_model.quant_conv.weight", "cond_stage_model.quant_conv.bias", "cond_stage_model.post_quant_conv.weight", "cond_stage_model.post_quant_conv.bias".

ultiwinter commented 2 weeks ago

Hello @omrastogi, thanks for putting the question out there. Have you solved it in the meantime? I would be grateful if you could enlighten me.

Best