Closed JohnsonXi closed 2 months ago
@Eugeoter
After updating to the latest code, new problems arise.
Solved! You need to modify the following files in the "ControlNeXt-SDXL" folder:
for j, _ in enumerate(range(args.num_validation_images)):
    with inference_ctx:
        image = pipeline.__call__(
            prompt=validation_prompt,
            controlnet_image=validation_image,
            num_inference_steps=20,
            generator=generator,
            negative_prompt=negative_prompt,
            width=args.resolution,
            height=args.resolution,
        ).images[0]
    image.save(f"{j}.png")
line 469
...
def adjust_checkpoint_state_dict(state_dict, model_state_dict):
    for key, value in state_dict.items():
        print(f"{key}: {value.shape}")
    adjusted_state_dict = {}
    for key, value in state_dict.items():
        if key in model_state_dict:
            if value.shape == model_state_dict[key].shape:
                adjusted_state_dict[key] = value
            else:
                print(f"Shape mismatch for {key}, skipping this parameter.")
        else:
            print(f"{key} not in model state dict, skipping this parameter.")
    return adjusted_state_dict
def load_safetensors(model, safetensors_path, strict=True, load_weight_increasement=False):
    if not load_weight_increasement:
        state_dict = load_file(safetensors_path)
        state_dict = adjust_checkpoint_state_dict(state_dict, model.state_dict())
        model.load_state_dict(state_dict, strict=strict)
    else:
        state_dict = load_file(safetensors_path)
        pretrained_state_dict = model.state_dict()
        for k in state_dict.keys():
            state_dict[k] = state_dict[k] + pretrained_state_dict[k]
        model.load_state_dict(state_dict, strict=False)
...
3. models/pipeline_controlnext.py
You need to download the model "madebyollin_sdxl-vae-fp16-fix" from Hugging Face in advance.
line 1113
...
device = self._execution_device
self.vae = AutoencoderKL.from_pretrained("madebyollin_sdxl-vae-fp16-fix", torch_dtype=torch.float16).to(device)
# 3. Encode input prompt
...
The result is shown below:
<img width="858" alt="Screenshot 2024-07-12 19 20 38" src="https://github.com/user-attachments/assets/0c705f81-2122-42c2-a6a2-78e446022712">
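For readers who want to see what the key-and-shape filtering in adjust_checkpoint_state_dict does without loading a real checkpoint, here is a minimal, dependency-free sketch. FakeTensor and filter_state_dict are hypothetical names introduced only for this illustration; FakeTensor stands in for torch.Tensor and carries nothing but a shape. This is not the repository's code.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor: it only carries a shape."""
    shape: tuple


def filter_state_dict(checkpoint, model_state):
    """Keep entries whose key exists in the model and whose shape matches."""
    kept, skipped = {}, []
    for key, value in checkpoint.items():
        if key in model_state and value.shape == model_state[key].shape:
            kept[key] = value
        else:
            skipped.append(key)
    return kept, skipped


# Illustrative shapes only; these are not taken from any real model.
model_state = {
    "conv.weight": FakeTensor((4, 3, 3, 3)),
    "conv.bias": FakeTensor((4,)),
}
checkpoint = {
    "conv.weight": FakeTensor((4, 3, 3, 3)),  # key and shape match: kept
    "conv.bias": FakeTensor((8,)),            # shape mismatch: skipped
    "extra.weight": FakeTensor((2, 2)),       # key unknown to the model: skipped
}

kept, skipped = filter_state_dict(checkpoint, model_state)
print("kept:", sorted(kept))
print("skipped:", sorted(skipped))
```

Note that once any parameter has been skipped, the filtered dictionary no longer covers every model key, so a PyTorch load_state_dict call with strict=True would be expected to raise on the missing keys; passing strict=False is the usual way around that.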
@JohnsonXi Fixed. Sorry for the delay. @RobertLau666 My bad. We recently changed the model structure for SDXL to make it lighter. Please update the code to the newest version and re-download the weights from https://huggingface.co/Pbihao/ControlNeXt.
Nice, but the following line is still needed:
self.vae = AutoencoderKL.from_pretrained("madebyollin_sdxl-vae-fp16-fix", torch_dtype=torch.float16).to(device)
With it, the result (below) is better.
I see. Thanks for the comment!
I succeeded in running controlnext-sd1.5, but I cannot run SDXL; the model-loading code for SDXL is different from the one for SD1.5.