google / prompt-to-prompt

Apache License 2.0
2.98k stars 279 forks

TypeError: forward() got an unexpected keyword argument 'encoder_hidden_states' #44

Closed HosamGen closed 1 year ago

HosamGen commented 1 year ago

I am trying to run any of the jupyter notebooks to test the code but I am facing this error in the line where the prompts are passed to the model. The code cell and the error are the following:

    g_cpu = torch.Generator().manual_seed(8888)
    prompts = ["A painting of a squirrel eating a burger"]
    controller = AttentionStore()
    image, x_t = run_and_display(prompts, controller, latent=None, run_baseline=False, generator=g_cpu)
    show_cross_attention(controller, res=16, from_where=("up", "down"))


    TypeError                                 Traceback (most recent call last)
    Cell In[9], line 4
          2 prompts = ["A painting of a squirrel eating a burger"]
          3 controller = AttentionStore()
    ----> 4 image, x_t = run_and_display(prompts, controller, latent=None, run_baseline=False, generator=g_cpu)
          5 show_cross_attention(controller, res=16, from_where=("up", "down"))

    Cell In[6], line 6, in run_and_display(prompts, controller, latent, run_baseline, generator)
          4 images, latent = run_and_display(prompts, EmptyControl(), latent=latent, run_baseline=False, generator=generator)
          5 print("with prompt-to-prompt")
    ----> 6 images, x_t = ptp_utils.text2image_ldm_stable(ldm_stable, prompts, controller, latent=latent, num_inference_steps=NUM_DIFFUSION_STEPS, guidance_scale=GUIDANCE_SCALE, generator=generator, low_resource=LOW_RESOURCE)
          7 ptp_utils.view_images(images)
          8 return images, x_t

    File ~/.conda/envs/prompt/lib/python3.8/site-packages/torch/autograd/grad_mode.py:27, in _DecoratorContextManager.__call__.<locals>.decorate_context(*args, **kwargs)
         24 @functools.wraps(func)
         25 def decorate_context(*args, **kwargs):
         26     with self.clone():
    ---> 27         return func(*args, **kwargs)

    File ~/Downloads/prompt-to-prompt/ptp_utils.py:167, in text2image_ldm_stable(model, prompt, controller, num_inference_steps, guidance_scale, generator, latent, low_resource)
        165 model.scheduler.set_timesteps(num_inference_steps)
        166 for t in tqdm(model.scheduler.timesteps):
    --> 167     latents = diffusion_step(model, controller, latents, context, t, guidance_scale, low_resource)
    ...
    -> 1110     return forward_call(*input, **kwargs)
       1111 # Do not call functions when jit is used
       1112 full_backward_hooks, non_full_backward_hooks = [], []

    TypeError: forward() got an unexpected keyword argument 'encoder_hidden_states'

Any suggestions?

HiddenGalaxy commented 1 year ago

Did you solve the problem? I ran into the same issue but don't know how to fix it.

patrickvonplaten commented 1 year ago

Here :wave:

Maintainer of the diffusers library here - should we try to add a prompt-to-prompt pipeline to diffusers to make sure things are actively maintained?

aliasgharkhani commented 1 year ago

@HosamGen This error happens because you are using a newer version of the diffusers library. Downgrading to diffusers==0.3.0 should solve the problem. But when I downgrade and run this line:

    ldm_stable = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=MY_TOKEN, scheduler=scheduler).to(device)

it does not proceed and gets stuck there.
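
To double-check which diffusers version the notebook environment is actually using (the notebooks expect 0.3.0, per the above), a quick sanity check is:

    import diffusers

    # the prompt-to-prompt notebooks were written against diffusers 0.3.0
    print(diffusers.__version__)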

HosamGen commented 1 year ago

@aliasgharkhani For the latent diffusion notebook this fixes it, but for the Stable Diffusion notebook it does not proceed either; I run into the same issue you mentioned.

HosamGen commented 1 year ago

@patrickvonplaten Sure, that would be great.

yue-zhongqi commented 1 year ago

I managed to find a workaround using diffusers==0.14.0. Reason for the failure: in the newer diffusers library, the forward signature of the CrossAttention class has changed. In the old 0.3.0 version, the default forward function does not use the mask input (which Prompt2Prompt requires), so Prompt2Prompt replaces the forward function in ptp_utils.py via the register_attention_control function. This modified forward conflicts with the updated forward signature in the newer diffusers library (0.14.0).
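
Roughly, the signature change in CrossAttention.forward looks like this (paraphrased from the two diffusers versions, not copied verbatim):

    # diffusers 0.3.0 (roughly): the forward patched in by ptp_utils.py mirrors
    # this (x, context, mask) calling convention
    def forward(self, hidden_states, context=None, mask=None): ...

    # diffusers 0.14.0 (roughly): the UNet now calls attention with
    # encoder_hidden_states=... as a keyword, which the old patched forward
    # does not accept -- hence the TypeError above
    def forward(self, hidden_states, encoder_hidden_states=None,
                attention_mask=None, **cross_attention_kwargs): ...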

HosamGen commented 1 year ago

@yue-zhongqi This works, thank you very much.

momoshenchi commented 1 year ago

That works.

nihaomiao commented 1 year ago

The suggestion from @yue-zhongqi works. But instead of directly changing the CrossAttention class in diffusers.models.cross_attention, one can also reuse the official head_to_batch_dim and batch_to_head_dim functions from the newer diffusers version in place of reshape_heads_to_batch_dim and reshape_batch_dim_to_heads. In short, one can directly replace the original def forward(x, context=None, mask=None) function inside def register_attention_control(model, controller) of ptp_utils.py with the following code:

    def forward(hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
        # map the new argument names back to the ones the original patch used
        x = hidden_states
        context = encoder_hidden_states
        mask = attention_mask
        batch_size, sequence_length, dim = x.shape
        h = self.heads
        q = self.to_q(x)
        # cross-attention if text conditioning is passed in, self-attention otherwise
        is_cross = context is not None
        context = context if is_cross else x
        k = self.to_k(context)
        v = self.to_v(context)
        q = self.head_to_batch_dim(q)
        k = self.head_to_batch_dim(k)
        v = self.head_to_batch_dim(v)

        sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale

        if mask is not None:
            mask = mask.reshape(batch_size, -1)
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = mask[:, None, :].repeat(h, 1, 1)
            sim.masked_fill_(~mask, max_neg_value)

        # attention, what we cannot get enough of
        attn = sim.softmax(dim=-1)
        # let the Prompt-to-Prompt controller record (or edit) the attention maps
        attn = controller(attn, is_cross, place_in_unet)
        out = torch.einsum("b i j, b j d -> b i d", attn, v)
        out = self.batch_to_head_dim(out)
        # to_out is resolved in the enclosing register_attention_control scope
        return to_out(out)
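
For context, the snippet above is the inner closure that register_attention_control builds for every attention block, so self, place_in_unet, controller, and to_out all come from the enclosing scope. Roughly (paraphrased from ptp_utils.py, not a verbatim copy):

    import torch

    def register_attention_control(model, controller):
        def ca_forward(self, place_in_unet):
            # `self` is the CrossAttention module being patched; `place_in_unet`
            # is one of "down", "mid", "up" and is captured by the inner forward
            to_out = self.to_out
            if isinstance(to_out, torch.nn.ModuleList):
                # newer diffusers wraps the output projection (Linear + Dropout)
                # in a ModuleList, so unwrap the Linear layer
                to_out = self.to_out[0]

            def forward(hidden_states, encoder_hidden_states=None,
                        attention_mask=None, **cross_attention_kwargs):
                ...  # body of the replacement forward shown above
            return forward

        # the rest of the function walks the UNet, swaps each attention module's
        # forward for ca_forward(module, place_in_unet), and reports the number
        # of registered attention layers to the controller
        ...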