vislearn / ControlNet-XS

Apache License 2.0

open_clip.create_model_and_transforms function problem when running inference #27

Open lavinal712 opened 3 months ago

lavinal712 commented 3 months ago

When I ran the example code from README.md, I ran into a strange problem.

import scripts.control_utils as cu
import torch
from PIL import Image

path_to_config = 'configs/inference/sdxl/sdxl_encD_canny_48m.yaml'
model = cu.create_model(path_to_config).to('cuda')

image_path = 'IMAGES/00007.png'

canny_high_th = 250
canny_low_th = 100
size = 768
num_samples = 2

image = cu.get_image(image_path, size=size)
edges = cu.get_canny_edges(image, low_th=canny_low_th, high_th=canny_high_th)

samples, controls = cu.get_sdxl_sample(
    guidance=edges,
    ddim_steps=10,
    num_samples=num_samples,
    model=model,
    shape=[4, size // 8, size // 8],
    control_scale=0.95,
    prompt='cinematic, shoe in the streets, made from meat, photorealistic shoe, highly detailed',
    n_prompt='lowres, bad anatomy, worst quality, low quality',
)

Image.fromarray(cu.create_image_grid(samples)).save('00007.png')

The error occurred at the following line, in "ControlNet-XS\sgm\modules\encoders\modules.py", line 428:

model, _, _ = open_clip.create_model_and_transforms(

To work around it, I manually downloaded the LAION CLIP-ViT-H-14-laion2B-s32B-b79K model, put it in a local directory, and pointed the call at that checkpoint:

model, _, _ = open_clip.create_model_and_transforms(
    arch,
    device=torch.device("cpu"),
    # pretrained=version,
    pretrained="laion/CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin",
)

Then I hit an error that I could not fix:

RuntimeError: Error(s) in loading state_dict for CLIP:
    Missing key(s) in state_dict: "visual.transformer.resblocks.32.ln_1.weight", "visual.transformer.resblocks.32.ln_1.bias", "visual.transformer.resblocks.32.attn.in_proj_weight", "visual.transformer.resblocks.32.attn.in_proj_bias", "visual.transformer.resblocks.32.attn.out_proj.weight", "visual.transformer.resblocks.32.attn.out_proj.bias", 
......
"transformer.resblocks.31.mlp.c_proj.weight", "transformer.resblocks.31.mlp.c_proj.bias". 
    size mismatch for positional_embedding: copying a param with shape torch.Size([77, 1024]) from checkpoint, the shape in current model is torch.Size([77, 1280]).
    size mismatch for text_projection: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([1280, 1280]).
    ......

I wonder whether the problem lies with the model or with something else.