Closed dcaustin33 closed 1 year ago
+1 same here.
Oh, I figured it out. The latest version of diffusers is not compatible with the code. Downgrading from v0.19 to v0.14 works for me.
@hoveychen is right about this. As specified in requirements.txt, we need to install diffusers==0.14.
` KeyError Traceback (most recent call last) in <cell line: 10>()
8 tokenizer=concept_model.tokenizer, device=device, LOW_RESOURCE=LOW_RESOURCE)
9
---> 10 images, x_t = text2image(concept_model, prompts, controller, latent=x_t, num_inference_steps=NUM_DDIM_STEPS, guidance_scale=GUIDANCE_SCALE,
11 generator=None, uncond_embeddings=uncond_embeddings)
12
4 frames /usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py in decorate_context(*args, **kwargs) 113 def decorate_context(*args, **kwargs): 114 with ctx_factory(): --> 115 return func(*args, **kwargs) 116 117 return decorate_context
/content/utils.py in text2image(model, prompt, controller, num_inference_steps, guidance_scale, generator, latent, uncond_embeddings, start_time, return_type) 203 else: 204 context = torch.cat([uncond_embeddings_, text_embeddings]) --> 205 latents = diffusion_step(model, controller, latents, context, t, guidance_scale, low_resource=False) 206 207 if return_type == 'image':
/content/utils.py in diffusion_step(model, controller, latents, context, t, guidance_scale, low_resource) 139 noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) 140 latents = model.scheduler.step(noise_pred, t, latents)["prev_sample"] --> 141 latents = controller.step_callback(latents) 142 return latents 143
/content/swapping_class.py in step_callback(self, x_t) 90 def step_callback(self, x_t): 91 if self.local_blend is not None: ---> 92 x_t = self.local_blend(x_t, self.attention_store) 93 return x_t 94
/content/swapping_class.py in __call__(self, x_t, attention_store) 24 if self.counter > self.start_blend: 25 ---> 26 maps = attention_store["down_cross"][2:4] + attention_store["up_cross"][:3] 27 maps = [item.reshape(self.alpha_layers.shape[0], -1, 1, 16, 16, 77) for item in maps] 28 maps = torch.cat(maps, dim=1)
KeyError: 'down_cross' `