wooyeolBaek / attention-map

🚀 Cross attention map tools for huggingface/diffusers
https://huggingface.co/spaces/We-Want-GPU/diffusers-cross-attention-map-SDXL-t2i
MIT License

Image to Image #1

Closed — miriam-horovicz-ni closed this issue 9 months ago

miriam-horovicz-ni commented 9 months ago

Thanks for your work! I am trying it with image-to-image inpainting, but it fails:

```python
import torch
from diffusers import AutoPipelineForImage2Image, ControlNetModel
from diffusers.utils import load_image
# assumed imports: the attention-map helpers come from this repo
from utils import (
    cross_attn_init,
    register_cross_attention_hook,
    get_net_attn_map,
    resize_net_attn_map,
)

cross_attn_init()

controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/control_v11f1p_sd15_depth",
    torch_dtype=torch.float16, variant="fp16", use_safetensors=True,
)
pipe = AutoPipelineForImage2Image.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    controlnet=controlnet,
    torch_dtype=torch.float16, variant="fp16", use_safetensors=True,
)
pipe.safety_checker = None
pipe.unet = register_cross_attention_hook(pipe.unet)
pipe = pipe.to("cuda")

device = "cuda"
init_img_path = "cat.png"
strength = 0.5
guidance_scale = 7.5
generator = torch.Generator(device=device).manual_seed(1024)
prompt = "cat"

# both inputs resized to 512x512, i.e. square
init_img = load_image(f"Images/{init_img_path}").resize((512, 512))
init_mask = load_image(f"Images/{init_img_path.split('.')[0]}_mask.png").resize((512, 512))

image = pipe(
    prompt=prompt,
    image=init_img,
    control_image=init_mask,
    strength=strength,
    guidance_scale=guidance_scale,
    generator=generator,
).images[0]
image.save("test.png")

dir_name = "attn_maps"  # defined for saving the maps; saving call omitted here
net_attn_maps = get_net_attn_map(image.size)
net_attn_maps = resize_net_attn_map(net_attn_maps, image.size)
```

wooyeolBaek commented 9 months ago

Hi, thanks for using my code :) Since I assumed square images, reshaping the attention map works for square images but not for rectangular ones, so I left fixing that as a TODO.
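To illustrate the square-image assumption: when a cross-attention map is stored with its spatial dimension flattened, recovering the 2-D grid typically takes `h = w = sqrt(seq_len)`, which only works when the latent grid is square. This is a minimal sketch with a hypothetical `reshape_attn_map` helper (not the repo's actual function), just to show where rectangular inputs break:

```python
import math
import torch

def reshape_attn_map(attn: torch.Tensor) -> torch.Tensor:
    # attn: (heads, seq_len, tokens), where seq_len is the flattened
    # spatial grid of the latent (e.g. 64*64 for a 512x512 image).
    seq_len = attn.shape[1]
    # Square-image assumption: h = w = sqrt(seq_len).
    h = w = math.isqrt(seq_len)
    if h * w != seq_len:
        # A rectangular image gives a non-square seq_len; the true (h, w)
        # would have to be derived from the image's aspect ratio instead.
        raise ValueError(f"seq_len={seq_len} is not a perfect square")
    return attn.reshape(attn.shape[0], h, w, attn.shape[2])

# 512x512 image -> 64x64 latent grid -> seq_len 4096: works.
maps = reshape_attn_map(torch.rand(8, 4096, 77))  # shape (8, 64, 64, 77)
```

A 512x768 image would instead give `seq_len = 64 * 96 = 6144`, which is not a perfect square, so the reshape above fails, which is why the fix is left as a TODO.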

miriam-horovicz-ni commented 9 months ago

Thanks for the update. I am resizing my image to 512x512, so it is square, as you can see in the code above, but I am still getting the error :(