JohnsonXi closed this issue 2 months ago.
After updating to the latest code, new problems arose.
Solved! You need to modify the following files in the "ControlNeXt-SDXL" folder:
for j, _ in enumerate(range(args.num_validation_images)):
    with inference_ctx:
        image = pipeline.__call__(
line 469
from safetensors.torch import load_file  # skip if the file already imports it

def adjust_checkpoint_state_dict(state_dict, model_state_dict):
    # Log every checkpoint key and its shape for debugging
    for key, value in state_dict.items():
        print(f"{key}: {value.shape}")
    # Keep only parameters whose name and shape both match the current model
    adjusted_state_dict = {}
    for key, value in state_dict.items():
        if key in model_state_dict:
            if value.shape == model_state_dict[key].shape:
                adjusted_state_dict[key] = value
            else:
                print(f"Shape mismatch for {key}, skipping this parameter.")
        else:
            print(f"{key} not in model state dict, skipping this parameter.")
    return adjusted_state_dict

def load_safetensors(model, safetensors_path, strict=True, load_weight_increasement=False):
    if not load_weight_increasement:
        # Full weights: drop incompatible entries before loading
        state_dict = load_file(safetensors_path)
        state_dict = adjust_checkpoint_state_dict(state_dict, model.state_dict())
        model.load_state_dict(state_dict, strict=strict)
    else:
        # Weight increments: add the stored values to the current (pretrained) weights
        state_dict = load_file(safetensors_path)
        pretrained_state_dict = model.state_dict()
        for k in state_dict.keys():
            state_dict[k] = state_dict[k] + pretrained_state_dict[k]
        model.load_state_dict(state_dict, strict=False)
...
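The two loading paths above can be checked without PyTorch. Below is a minimal stdlib sketch, assuming each tensor is represented only by its shape tuple (for the filtering path) or by a single float (for the increment path); `partition_checkpoint_keys` and `apply_weight_increment` are hypothetical helper names for illustration, not part of ControlNeXt.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def partition_checkpoint_keys(
    ckpt_shapes: Dict[str, Shape], model_shapes: Dict[str, Shape]
) -> Dict[str, List[str]]:
    """Classify checkpoint keys the way adjust_checkpoint_state_dict does,
    and also list model keys the filtered checkpoint no longer covers."""
    report: Dict[str, List[str]] = {"kept": [], "shape_mismatch": [], "unexpected": []}
    for key, shape in ckpt_shapes.items():
        if key not in model_shapes:
            report["unexpected"].append(key)  # not a parameter of the model
        elif shape != model_shapes[key]:
            report["shape_mismatch"].append(key)  # same name, different shape
        else:
            report["kept"].append(key)
    # Model parameters left without a value after filtering
    report["missing"] = [k for k in model_shapes if k not in report["kept"]]
    return report


def apply_weight_increment(delta: Dict[str, float], pretrained: Dict[str, float]) -> Dict[str, float]:
    """Mirror the load_weight_increasement branch: stored values are added to pretrained ones."""
    return {k: delta[k] + pretrained[k] for k in delta}


if __name__ == "__main__":
    report = partition_checkpoint_keys(
        {"conv.weight": (320, 4, 3, 3), "conv.bias": (320,), "old.proj": (64,)},
        {"conv.weight": (320, 8, 3, 3), "conv.bias": (320,)},
    )
    print(report)  # conv.weight lands in both "shape_mismatch" and "missing"
    print(apply_weight_increment({"w": 0.5}, {"w": 1.0, "b": 2.0}))  # {'w': 1.5}
```

One consequence worth checking: a parameter skipped for a shape mismatch becomes a missing key, and PyTorch's `load_state_dict(..., strict=True)` raises on missing keys, so the non-increment path may still fail unless `strict=False` is passed.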
3. models/
You need to download the model "madebyollin_sdxl-vae-fp16-fix" from Hugging Face in advance.
line 1113
device = self._execution_device
self.vae = AutoencoderKL.from_pretrained("madebyollin_sdxl-vae-fp16-fix", torch_dtype=torch.float16).to(device)
# 3. Encode input prompt
And the result is as below:
<img width="858" alt="Screenshot 2024-07-12 19 20 38" src="">
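If the line 1113 change sits inside `__call__`, as the surrounding `# 3. Encode input prompt` comment suggests, the VAE is reloaded from disk on every generation. A minimal sketch of loading it once instead, assuming `torch`, `diffusers`, and `huggingface_hub` are installed and that the Hugging Face repo id is `madebyollin/sdxl-vae-fp16-fix` (the thread only gives the local folder name). Per its model card, this VAE is fine-tuned to run in float16 without producing NaNs. `load_fp16_fix_vae` is a hypothetical helper, not part of ControlNeXt.

```python
import os

# Assumed repo id; the local folder name matches the one used in the thread
DEFAULT_REPO_ID = "madebyollin/sdxl-vae-fp16-fix"
DEFAULT_LOCAL_DIR = "madebyollin_sdxl-vae-fp16-fix"


def load_fp16_fix_vae(local_dir: str = DEFAULT_LOCAL_DIR, repo_id: str = DEFAULT_REPO_ID, device: str = "cuda"):
    """Download the fp16-fix SDXL VAE if the folder is absent, then load it in float16.

    Heavy imports happen inside the function so the module imports without them.
    """
    import torch
    from diffusers import AutoencoderKL
    from huggingface_hub import snapshot_download

    if not os.path.isdir(local_dir):
        # Fetch the repository files into a plain local folder
        snapshot_download(repo_id=repo_id, local_dir=local_dir)
    return AutoencoderKL.from_pretrained(local_dir, torch_dtype=torch.float16).to(device)
```

Under these assumptions, you would call it once before the validation loop, for example `pipeline.vae = load_fp16_fix_vae()`, and drop the `self.vae = ...` line from `__call__`; this relies on the pipeline exposing a `vae` attribute, as the original line's use of `self.vae` implies.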
@JohnsonXi Fixed. Sorry for the delay. @RobertLau666 My bad. We recently changed the SDXL model structure to make it lighter. Please update the code to the newest version and re-download the weights from
Nice, but the following line is still needed:
self.vae = AutoencoderKL.from_pretrained("madebyollin_sdxl-vae-fp16-fix", torch_dtype=torch.float16).to(device)
The result, shown below, is better.
I see. Thanks for the comment!
I succeeded in running controlnext-sd1.5 but cannot run it on SDXL; the SDXL model-loading code differs from the SD1.5 code.