dvlab-research / ControlNeXt

Controllable video and image Generation, SVD, Animate Anyone, ControlNet, ControlNeXt, LoRA
Apache License 2.0

controlnext-sdxl load safetensors error #5

Closed JohnsonXi closed 2 months ago

JohnsonXi commented 3 months ago
[image attachment]

I can run controlnext-sd1.5 successfully, but not SDXL. The model-loading code for SDXL is different from the SD1.5 code.

Pbihao commented 3 months ago

@Eugeoter

RobertLau666 commented 3 months ago

After updating to the latest code, new problems arise:

[Screenshot 2024-07-10 18 34 15]

Solved! You need to modify the following files in the "ControlNeXt-SDXL" folder:

  1. run_controlnext.py line 103
        for j in range(args.num_validation_images):
            with inference_ctx:
                image = pipeline.__call__(
                    prompt=validation_prompt,
                    controlnet_image=validation_image,
                    num_inference_steps=20,
                    generator=generator,
                    negative_prompt=negative_prompt,
                    width=args.resolution,
                    height=args.resolution,
                ).images[0]
            image.save(f"{j}.png")

    line 469

        ...
        def adjust_checkpoint_state_dict(state_dict, model_state_dict):
            # Debug output: list every tensor in the checkpoint with its shape.
            for key, value in state_dict.items():
                print(f"{key}: {value.shape}")
            # Keep only the parameters whose name and shape match the model.
            adjusted_state_dict = {}
            for key, value in state_dict.items():
                if key in model_state_dict:
                    if value.shape == model_state_dict[key].shape:
                        adjusted_state_dict[key] = value
                    else:
                        print(f"Shape mismatch for {key}, skipping this parameter.")
                else:
                    print(f"{key} not in model state dict, skipping this parameter.")
            return adjusted_state_dict

        def load_safetensors(model, safetensors_path, strict=True, load_weight_increasement=False):
            if not load_weight_increasement:
                state_dict = load_file(safetensors_path)
                state_dict = adjust_checkpoint_state_dict(state_dict, model.state_dict())
                # Note: with strict=True this still raises if the filtered
                # state dict lacks keys that the model expects.
                model.load_state_dict(state_dict, strict=strict)
            else:
                # The checkpoint stores weight increments: add them to the current weights.
                state_dict = load_file(safetensors_path)
                pretrained_state_dict = model.state_dict()
                for k in state_dict.keys():
                    state_dict[k] = state_dict[k] + pretrained_state_dict[k]
                model.load_state_dict(state_dict, strict=False)
        ...

  3. models/pipeline_controlnext.py line 1113
     You need to download the model "madebyollin_sdxl-vae-fp16-fix" from Hugging Face in advance.

        ...
        device = self._execution_device
        self.vae = AutoencoderKL.from_pretrained("madebyollin_sdxl-vae-fp16-fix", torch_dtype=torch.float16).to(device)
        # 3. Encode input prompt
        ...

The result is shown below:

<img width="858" alt="Screenshot 2024-07-12 19 20 38" src="https://github.com/user-attachments/assets/0c705f81-2122-42c2-a6a2-78e446022712">
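
The `adjust_checkpoint_state_dict` workaround above drops tensors whose name or shape does not match, but it does not report parameters that the model expects and the checkpoint lacks. Below is a minimal sketch of a fuller comparison. It is not part of the repository: `compare_state_dicts` and the parameter names are hypothetical, and the function works on plain name-to-shape mappings, which for real tensors could be built with `{k: tuple(v.shape) for k, v in state_dict.items()}`.

```python
from typing import Dict, List, Mapping, Tuple

Shape = Tuple[int, ...]


def compare_state_dicts(
    checkpoint_shapes: Mapping[str, Shape],
    model_shapes: Mapping[str, Shape],
) -> Dict[str, List[str]]:
    """Sort parameter names into four buckets by comparing names and shapes."""
    report: Dict[str, List[str]] = {
        "matched": [],         # same name, same shape: safe to load
        "shape_mismatch": [],  # same name, different shape
        "unexpected": [],      # present in the checkpoint only
        "missing": [],         # present in the model only
    }
    for name, shape in checkpoint_shapes.items():
        if name not in model_shapes:
            report["unexpected"].append(name)
        elif tuple(shape) == tuple(model_shapes[name]):
            report["matched"].append(name)
        else:
            report["shape_mismatch"].append(name)
    for name in model_shapes:
        if name not in checkpoint_shapes:
            report["missing"].append(name)
    return report


# Hypothetical parameter names and shapes, for illustration only.
checkpoint = {"conv.weight": (4, 3, 3, 3), "conv.bias": (4,), "old.scale": (1,)}
model = {"conv.weight": (8, 3, 3, 3), "conv.bias": (4,), "new.scale": (1,)}

report = compare_state_dicts(checkpoint, model)
for bucket, names in report.items():
    print(f"{bucket}: {names}")
```

Non-empty `shape_mismatch`, `unexpected`, or `missing` lists suggest that the code and the weight file come from different versions, in which case updating both is probably safer than skipping parameters.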

Eugeoter commented 3 months ago

@JohnsonXi Fixed. Sorry for the delay. @RobertLau666 My bad. We recently changed the model structure for SDXL to make it lighter. Please update the code to the latest version and re-download the weights from https://huggingface.co/Pbihao/ControlNeXt.

RobertLau666 commented 3 months ago

Nice, but the following line is still needed

self.vae = AutoencoderKL.from_pretrained("madebyollin_sdxl-vae-fp16-fix", torch_dtype=torch.float16).to(device)

The result is shown below, and it looks better:

[Screenshot 2024-07-12 23 16 41]

Eugeoter commented 3 months ago

I see. Thanks for the comment!