Did you solve the problem? I had the same problem but didn't know how to fix it.
Hey :wave: Maintainer of the diffusers library here - should we try to add a prompt-to-prompt pipeline to diffusers to make sure things are actively maintained?
@HosamGen This error happens because you are using a newer version of the diffusers library. If you downgrade to diffusers==0.3.0, it should solve the problem. But when I downgrade and run this line:

```python
ldm_stable = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=MY_TOKEN, scheduler=scheduler).to(device)
```

it does not proceed and gets stuck there.
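For reference, this is the kind of setup that line belongs to, as a minimal sketch assuming diffusers==0.3.0; the DDIMScheduler arguments are illustrative values of the sort the prompt-to-prompt notebooks use, and MY_TOKEN is a placeholder for a Hugging Face access token:

```python
import torch
from diffusers import DDIMScheduler, StableDiffusionPipeline

MY_TOKEN = "<your Hugging Face access token>"  # placeholder
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Illustrative scheduler settings; the notebook's own values may differ.
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012,
                          beta_schedule="scaled_linear",
                          clip_sample=False, set_alpha_to_one=False)

ldm_stable = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    use_auth_token=MY_TOKEN,
    scheduler=scheduler,
).to(device)
```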
@aliasgharkhani For the latent diffusion notebook this fixes it, but the stable diffusion notebook does not proceed, as I face the same issue you mentioned.
@patrickvonplaten Sure, that would be great.
I managed to find a workaround using diffusers==0.14.0. Reason for the failure: in the newer diffusers library, the signature of the CrossAttention class's forward function has changed. In the old 0.3.0 version, the default forward function does not use the mask input (which Prompt2Prompt requires), so Prompt2Prompt replaces the forward function in ptp_utils.py via the register_attention_control function. This replacement forward conflicts with the updated forward signature in the newer diffusers library (0.14.0).
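To check which forward signature your installed diffusers actually exposes, one option is the snippet below (the module path is valid for diffusers around 0.14.x; the class moved and was renamed in later releases):

```python
import inspect
from diffusers.models.cross_attention import CrossAttention

# Prints the signature the UNet will call the attention block with, e.g.
# (self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs)
print(inspect.signature(CrossAttention.forward))
```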
The fix: in `def register_attention_control(model, controller)`, change `def forward(x, context=None, mask=None):` to `def forward(hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):`. Then add `x = hidden_states`, `context = encoder_hidden_states` and `mask = attention_mask` inside the function.
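Concretely, the top of the patched forward then looks roughly like this (a sketch; the `...` stands for the rest of the original forward body in ptp_utils.py, which is left unchanged):

```python
def forward(hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
    # Alias the new argument names back to the names the original patched body uses.
    x = hidden_states
    context = encoder_hidden_states
    mask = attention_mask
    ...  # original forward body from ptp_utils.py, unchanged
```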
In `diffusers/models/cross_attention.py`, add the following functions to the `CrossAttention` class. This is needed because the patched forward calls these two functions, yet they were removed from the `CrossAttention` class in the newer diffusers library.
```python
def reshape_heads_to_batch_dim(self, tensor):
    batch_size, seq_len, dim = tensor.shape
    head_size = self.heads
    tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
    tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
    return tensor

def reshape_batch_dim_to_heads(self, tensor):
    batch_size, seq_len, dim = tensor.shape
    head_size = self.heads
    tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
    tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
    return tensor
```
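As a quick sanity check of what these helpers do (an illustrative example, assuming the two functions above are defined at module level; the shapes are made up): with 8 heads, a (2, 77, 320) tensor is folded to (16, 77, 40) and then restored.

```python
import torch

class TinyAttention:
    """Hypothetical stand-in exposing only what the two helpers need."""
    def __init__(self, heads):
        self.heads = heads

    # Reuse the two module-level functions above as methods (they only touch self.heads).
    reshape_heads_to_batch_dim = reshape_heads_to_batch_dim
    reshape_batch_dim_to_heads = reshape_batch_dim_to_heads

attn = TinyAttention(heads=8)
x = torch.randn(2, 77, 320)              # (batch, seq_len, dim)
y = attn.reshape_heads_to_batch_dim(x)   # -> (16, 77, 40): heads folded into the batch dim
z = attn.reshape_batch_dim_to_heads(y)   # -> (2, 77, 320): round trip restores the shape
assert z.shape == x.shape
```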
Alternatively, you can change the patched forward in register_attention_control to call the newer built-in equivalents instead.

@yue-zhongqi This works, thank you very much.
That works.
The suggestions from @yue-zhongqi work. But instead of directly changing the `CrossAttention` class in `diffusers/models/cross_attention.py`, one can also reuse the official functions `head_to_batch_dim` and `batch_to_head_dim` from the newer version of diffusers in place of `reshape_heads_to_batch_dim` and `reshape_batch_dim_to_heads`. In short, one can directly replace the original `def forward(x, context=None, mask=None)` function in `def register_attention_control(model, controller)` of `ptp_utils.py` with the following code:
```python
def forward(hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
    x = hidden_states
    context = encoder_hidden_states
    mask = attention_mask

    batch_size, sequence_length, dim = x.shape
    h = self.heads
    q = self.to_q(x)
    is_cross = context is not None
    context = context if is_cross else x
    k = self.to_k(context)
    v = self.to_v(context)
    q = self.head_to_batch_dim(q)
    k = self.head_to_batch_dim(k)
    v = self.head_to_batch_dim(v)

    sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale

    if mask is not None:
        mask = mask.reshape(batch_size, -1)
        max_neg_value = -torch.finfo(sim.dtype).max
        mask = mask[:, None, :].repeat(h, 1, 1)
        sim.masked_fill_(~mask, max_neg_value)

    # attention, what we cannot get enough of
    attn = sim.softmax(dim=-1)
    attn = controller(attn, is_cross, place_in_unet)
    out = torch.einsum("b i j, b j d -> b i d", attn, v)
    out = self.batch_to_head_dim(out)
    return to_out(out)
```
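Note that `controller`, `place_in_unet` and `to_out` are not arguments of this forward; in `ptp_utils.py` they are captured from the enclosing scope when `register_attention_control` rebinds each attention module's forward. Roughly, the wiring looks like the sketch below (a paraphrase of the repo's structure; the actual traversal and layer counting differ by version):

```python
import torch

def register_attention_control(model, controller):
    def ca_forward(self, place_in_unet):
        to_out = self.to_out
        if isinstance(to_out, torch.nn.ModuleList):
            to_out = self.to_out[0]  # newer diffusers wraps the output projection in a ModuleList

        def forward(hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
            ...  # the body shown above; controller / place_in_unet / to_out come from this closure

        return forward

    # Walk the UNet, rebind forward on every cross-attention block, and tell the
    # controller how many attention layers it will see per diffusion step.
    cross_att_count = 0
    for name, module in model.unet.named_modules():
        if module.__class__.__name__ == "CrossAttention":
            place_in_unet = "down" if "down" in name else ("up" if "up" in name else "mid")
            module.forward = ca_forward(module, place_in_unet)
            cross_att_count += 1
    controller.num_att_layers = cross_att_count
```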
I am trying to run the Jupyter notebooks to test the code, but I am facing this error at the line where the prompts are passed to the model. The code cell and the error are the following:
```python
g_cpu = torch.Generator().manual_seed(8888)
prompts = ["A painting of a squirrel eating a burger"]
controller = AttentionStore()
image, x_t = run_and_display(prompts, controller, latent=None, run_baseline=False, generator=g_cpu)
show_cross_attention(controller, res=16, from_where=("up", "down"))
```
```
TypeError                                 Traceback (most recent call last)
Cell In[9], line 4
      2 prompts = ["A painting of a squirrel eating a burger"]
      3 controller = AttentionStore()
----> 4 image, x_t = run_and_display(prompts, controller, latent=None, run_baseline=False, generator=g_cpu)
      5 show_cross_attention(controller, res=16, from_where=("up", "down"))

Cell In[6], line 6, in run_and_display(prompts, controller, latent, run_baseline, generator)
      4 images, latent = run_and_display(prompts, EmptyControl(), latent=latent, run_baseline=False, generator=generator)
      5 print("with prompt-to-prompt")
----> 6 images, x_t = ptp_utils.text2image_ldm_stable(ldm_stable, prompts, controller, latent=latent, num_inference_steps=NUM_DIFFUSION_STEPS, guidance_scale=GUIDANCE_SCALE, generator=generator, low_resource=LOW_RESOURCE)
      7 ptp_utils.view_images(images)
      8 return images, x_t

File ~/.conda/envs/prompt/lib/python3.8/site-packages/torch/autograd/grad_mode.py:27, in _DecoratorContextManager.__call__.<locals>.decorate_context(*args, **kwargs)
     24 @functools.wraps(func)
     25 def decorate_context(*args, **kwargs):
     26     with self.clone():
---> 27         return func(*args, **kwargs)

File ~/Downloads/prompt-to-prompt/ptp_utils.py:167, in text2image_ldm_stable(model, prompt, controller, num_inference_steps, guidance_scale, generator, latent, low_resource)
    165 model.scheduler.set_timesteps(num_inference_steps)
    166 for t in tqdm(model.scheduler.timesteps):
--> 167     latents = diffusion_step(model, controller, latents, context, t, guidance_scale, low_resource)
...
-> 1110     return forward_call(*input, **kwargs)
   1111     # Do not call functions when jit is used
   1112     full_backward_hooks, non_full_backward_hooks = [], []

TypeError: forward() got an unexpected keyword argument 'encoder_hidden_states'
```
Any suggestions?