open-mmlab / PowerPaint

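[ECCV 2024] PowerPaint, a versatile image inpainting model that supports text-guided object inpainting, object removal, image outpainting and shape-guided object inpainting with only a single model. (A high-quality, versatile image inpainting model that simultaneously supports object insertion, object removal, image outpainting, and shape-controllable object generation, all with a single model.)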
https://powerpaint.github.io/
MIT License

pipeline at dev branch #98

Open SmileTAT opened 1 month ago

SmileTAT commented 1 month ago

On the dev branch, I initialize the pipeline with the following code, but the output image is covered with a red layer:

```python
import os

import torch
from diffusers import UniPCMultistepScheduler
from safetensors.torch import load_model
from transformers import CLIPTextModel

# UNet2DConditionModel, BrushNetModel and StableDiffusionPowerPaintBrushNetPipeline
# come from the PowerPaint repo; their import paths depend on the branch.

weight_dtype = torch.float16  # assumed; the dtype is not shown in the report

# brushnet-based version
unet = UNet2DConditionModel.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    subfolder="unet",
    revision=None,
    torch_dtype=weight_dtype,
)
text_encoder = CLIPTextModel.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    subfolder="text_encoder",
    revision=None,
    torch_dtype=weight_dtype,
)
brushnet = BrushNetModel.from_unet(unet)

checkpoint_dir = "cache/huggingface/hub/models--JunhaoZhuang--PowerPaint-v2-1/snapshots/5ae2be3ac38b162df209b7ad5de036d339081e33"
base_model_path = os.path.join(checkpoint_dir, "realisticVisionV60B1_v51VAE")

pipe = StableDiffusionPowerPaintBrushNetPipeline.from_pretrained(
    base_model_path,
    brushnet=brushnet,
    text_encoder=text_encoder,
    torch_dtype=weight_dtype,
    low_cpu_mem_usage=False,
    safety_checker=None,
)
pipe.unet = UNet2DConditionModel.from_pretrained(
    base_model_path,
    subfolder="unet",
    revision=None,
    torch_dtype=weight_dtype,
)
# Load the PowerPaint BrushNet weights into the pipeline's BrushNet
load_model(
    pipe.brushnet,
    os.path.join(checkpoint_dir, "PowerPaint_Brushnet/diffusion_pytorch_model.safetensors"),
)

# IMPORTANT: add learnable tokens for task prompts into tokenizer
placeholder_tokens = ["P_ctxt", "P_shape", "P_obj"]  # [v.placeholder_tokens for k, v in args.task_prompt.items()]
initializer_token = ["a", "a", "a"]  # [v.initializer_token for k, v in args.task_prompt.items()]
num_vectors_per_token = [10, 10, 10]  # [v.num_vectors_per_token for k, v in args.task_prompt.items()]
placeholder_token_ids = pipe.add_tokens(
    placeholder_tokens, initializer_token, num_vectors_per_token, initialize_parameters=True
)

text_state_dict = torch.load(os.path.join(checkpoint_dir, "PowerPaint_Brushnet/pytorch_model.bin"))

# Pull out the base vocabulary and the learned task-prompt embeddings
P_obj_weight = text_state_dict['text_model.embeddings.token_embedding.trainable_embeddings.P_obj']
P_ctxt_weight = text_state_dict['text_model.embeddings.token_embedding.trainable_embeddings.P_ctxt']
P_shape_weight = text_state_dict['text_model.embeddings.token_embedding.trainable_embeddings.P_shape']
wrapped_weight = text_state_dict['text_model.embeddings.token_embedding.wrapped.weight']

text_state_dict.pop('text_model.embeddings.token_embedding.trainable_embeddings.P_obj')
text_state_dict.pop('text_model.embeddings.token_embedding.trainable_embeddings.P_ctxt')
text_state_dict.pop('text_model.embeddings.token_embedding.trainable_embeddings.P_shape')
text_state_dict.pop('text_model.embeddings.token_embedding.wrapped.weight')

# Rebuild one embedding matrix, in the order the placeholder tokens were added
text_state_dict['text_model.embeddings.token_embedding.weight'] = torch.cat([
    wrapped_weight, P_ctxt_weight, P_shape_weight, P_obj_weight])

msg = pipe.text_encoder.load_state_dict(text_state_dict, strict=False)
print(f'text load sd: {msg}')

pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
```
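The `torch.cat` near the end of the snippet assumes the new token rows sit directly after the base vocabulary, in the same order the placeholders were added (`P_ctxt`, `P_shape`, `P_obj`). Below is a minimal NumPy sketch of that invariant with toy sizes. `merge_token_embeddings` is a hypothetical helper, not part of PowerPaint, and the layout it builds is an assumption carried over from the snippet rather than a documented format.

```python
import numpy as np


def merge_token_embeddings(wrapped_weight, trainable, order):
    """Hypothetical helper: stack base-vocabulary rows, then each placeholder's rows.

    `order` must match the order in which the placeholder tokens were added,
    because their token ids are assumed to follow the base vocabulary.
    """
    return np.concatenate([wrapped_weight] + [trainable[name] for name in order], axis=0)


hidden = 8        # toy size; the SD 1.5 CLIP text encoder uses 768
base_vocab = 5    # toy size; the CLIP tokenizer has 49408 entries
num_vectors = 10  # matches num_vectors_per_token in the snippet
order = ["P_ctxt", "P_shape", "P_obj"]

wrapped = np.zeros((base_vocab, hidden), dtype=np.float32)
# Fill each placeholder's rows with a distinct value so the layout is visible.
trainable = {
    name: np.full((num_vectors, hidden), i + 1, dtype=np.float32)
    for i, name in enumerate(order)
}

merged = merge_token_embeddings(wrapped, trainable, order)

# 5 base rows + 3 placeholders x 10 vectors each.
assert merged.shape == (base_vocab + len(order) * num_vectors, hidden)
# P_shape's rows start right after the 10 P_ctxt rows.
start = base_vocab + num_vectors
assert (merged[start:start + num_vectors] == 2).all()
```

Because `load_state_dict(..., strict=False)` silently skips keys that do not match the model, the printed `msg` (its `missing_keys` and `unexpected_keys` fields) is the place to confirm that `text_model.embeddings.token_embedding.weight` was actually loaded.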
zengyh1900 commented 1 month ago

Hi @SmileTAT, what do you mean by "red layer"? Would you mind sharing some screenshots?

SmileTAT commented 1 month ago

[Image: IMG_0426] Input and output images. @zengyh1900

SmileTAT commented 1 month ago

Differences between the dev and main branches:

1. Channel order of the conditioning latents:
   - dev: `conditioning_latents = torch.concat([mask, conditioning_latents], 1)`
   - main: `conditioning_latents = torch.concat([conditioning_latents, mask], 1)`
2. Sign of the mask threshold:
   - dev: `original_mask = (original_mask.sum(1)[:, None, :, :] > 0).to(image.dtype)`
   - main: `original_mask = (original_mask.sum(1)[:, None, :, :] < 0).to(image.dtype)`
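To illustrate why these two differences matter for pretrained weights, here is a toy NumPy sketch; the tensors and values are illustrative and not taken from PowerPaint. The `np.concatenate(..., axis=1)` calls mirror the `torch.concat(..., 1)` calls listed above. The threshold part assumes the mask image is normalized to [-1, 1] like the input image; under that assumption the two branches mark complementary pixels as masked.

```python
import numpy as np

# Toy NCHW tensors: 4 latent channels plus 1 mask channel on a 1x1 grid.
latents = np.arange(4, dtype=np.float32).reshape(1, 4, 1, 1)  # values 0, 1, 2, 3
mask = np.full((1, 1, 1, 1), 9.0, dtype=np.float32)           # marker value 9

# Difference 1: where the mask channel goes in the conditioning input.
dev_cond = np.concatenate([mask, latents], axis=1)   # dev: mask is channel 0
main_cond = np.concatenate([latents, mask], axis=1)  # main: mask is channel 4

# Same shape, different layout: a conv trained on one layout would read the
# mask as a latent channel (and vice versa) under the other.
assert dev_cond.shape == main_cond.shape == (1, 5, 1, 1)
assert dev_cond[0, 0, 0, 0] == 9.0 and main_cond[0, 4, 0, 0] == 9.0

# Difference 2: sign of the threshold that collapses the RGB mask to 1 channel.
# Assumption: mask pixels are normalized to [-1, 1]; pixel 0 is +1, pixel 1 is -1.
original_mask = np.stack([np.array([[1.0, -1.0]])] * 3)[None]  # shape (1, 3, 1, 2)
channel_sum = original_mask.sum(1)[:, None, :, :]              # shape (1, 1, 1, 2)

dev_mask = (channel_sum > 0).astype(np.float32)   # selects pixel 0
main_mask = (channel_sum < 0).astype(np.float32)  # selects pixel 1

# The two conventions mark complementary pixels as masked.
assert ((dev_mask + main_mask) == 1).all()
```

Either change on its own means weights trained under one convention receive inputs in a layout they never saw during training when run under the other.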

zengyh1900 commented 1 month ago

Oh, I see. I have refactored the dev branch, so running our pretrained weights on it will probably cause problems. Please run the app on the dev branch only with weights you have trained yourself.