Open nrasiwas opened 1 year ago
Similarly to #44, I've managed to make it work with diffusers==0.16.1, transformers==4.29.2 and torch==1.12.1 by modifying the code for register_attention_control(model, controller) in ptp_utils.py:
def register_attention_control(model, controller):
    """Patch the attention layers of ``model.unet`` so ``controller`` sees every attention map.

    Every sub-module reachable from a top-level child of ``model.unet`` whose
    name contains ``"down"``, ``"up"`` or ``"mid"``, and whose own attribute
    name is ``"attn1"`` or ``"attn2"``, gets its ``forward`` replaced by a
    hand-written attention implementation. That implementation calls
    ``controller(attn, is_cross, place_in_unet)`` on the softmaxed attention
    probabilities and uses the returned tensor in place of them.

    Args:
        model: Object exposing a ``unet`` attribute that provides
            ``named_children()``.
        controller: Callable taking ``(attn, is_cross, place_in_unet)`` and
            returning the attention tensor to use. If ``None``, a pass-through
            controller is used. Its ``num_att_layers`` attribute is set to the
            number of patched layers.
    """

    def ca_forward(self, place_in_unet):
        """Build the replacement ``forward`` for the attention module ``self``."""
        # ``to_out`` may be a ModuleList; in that case only its first entry is
        # applied to the attention output.
        if isinstance(self.to_out, torch.nn.ModuleList):
            to_out = self.to_out[0]
        else:
            to_out = self.to_out

        def reshape_heads_to_batch_dim(tensor):
            """(batch, seq, dim) -> (batch * heads, seq, dim // heads), batch-major."""
            batch_size, seq_len, dim = tensor.shape
            head_size = self.heads
            tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
            return tensor.permute(0, 2, 1, 3).reshape(
                batch_size * head_size, seq_len, dim // head_size
            )

        def reshape_batch_dim_to_heads(tensor):
            """(batch * heads, seq, dim) -> (batch, seq, dim * heads); inverse of the above."""
            batch_size, seq_len, dim = tensor.shape
            head_size = self.heads
            tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
            return tensor.permute(0, 2, 1, 3).reshape(
                batch_size // head_size, seq_len, dim * head_size
            )

        def forward(hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
            """Attention forward pass that routes the probabilities through ``controller``.

            Extra keyword arguments are accepted for call compatibility and
            ignored. Without ``encoder_hidden_states`` the layer attends over
            ``hidden_states`` itself and ``is_cross`` is reported as False.
            """
            x = hidden_states
            mask = attention_mask
            batch_size = x.shape[0]
            h = self.heads

            is_cross = encoder_hidden_states is not None
            context = encoder_hidden_states if is_cross else x

            q = reshape_heads_to_batch_dim(self.to_q(x))
            k = reshape_heads_to_batch_dim(self.to_k(context))
            v = reshape_heads_to_batch_dim(self.to_v(context))

            sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale

            if mask is not None:
                # ``~mask`` below means True entries are kept and False entries
                # are suppressed.
                # NOTE(review): this presumes a boolean mask; confirm what the
                # caller actually passes.
                mask = mask.reshape(batch_size, -1)
                max_neg_value = -torch.finfo(sim.dtype).max
                # q/k/v are laid out batch-major ([b0h0, b0h1, ..., b1h0, ...]),
                # so each sample's mask row must be repeated ``h`` times in a
                # row. ``repeat(h, 1, 1)`` would tile whole batches instead and
                # pair heads with another sample's mask when batch_size > 1.
                mask = mask[:, None, :].repeat_interleave(h, dim=0)
                sim.masked_fill_(~mask, max_neg_value)

            # attention, what we cannot get enough of
            attn = sim.softmax(dim=-1)
            attn = controller(attn, is_cross, place_in_unet)
            out = torch.einsum("b i j, b j d -> b i d", attn, v)
            out = reshape_batch_dim_to_heads(out)
            return to_out(out)

        return forward

    class DummyController:
        """Pass-through controller used when no controller is supplied."""

        def __init__(self):
            self.num_att_layers = 0

        def __call__(self, *args):
            return args[0]

    if controller is None:
        controller = DummyController()

    def register_recr(net_, count, place_in_unet, module_name=None):
        """Recursively patch attention modules under ``net_``; return the updated count."""
        # Attention modules are recognised by their attribute name rather than
        # by the class name 'CrossAttention', which this version replaces.
        if module_name in ("attn1", "attn2"):
            net_.forward = ca_forward(net_, place_in_unet)
            return count + 1
        if hasattr(net_, "named_children"):
            for child_name, child in net_.named_children():
                count = register_recr(child, count, place_in_unet, module_name=child_name)
        return count

    cross_att_count = 0
    for name, sub_net in model.unet.named_children():
        # First matching tag wins, in this order, as in an if/elif chain.
        for place in ("down", "up", "mid"):
            if place in name:
                cross_att_count += register_recr(sub_net, 0, place)
                break
    controller.num_att_layers = cross_att_count
Additionally, you might have to comment out the following lines in the 4th code cell of null_text_w_ptp.ipynb:
# try:
# ldm_stable.disable_xformers_memory_efficient_attention()
# except AttributeError:
# print("Attribute disable_xformers_memory_efficient_attention() is missing")
DDIM inversion... Unexpected exception formatting exception. Falling back to standard exception Traceback (most recent call last): File "/home/nras/miniconda3/envs/py3.8/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3505, in run_code exec(code_obj, self.user_global_ns, self.user_ns) File "/tmp/ipykernel_585007/266642345.py", line 3, in
(image_gt, image_enc), x_t, uncond_embeddings = null_inversion.invert(image_path, prompt, offsets=(0,0,200,0), verbose=True)
File "/tmp/ipykernel_585007/262494972.py", line 168, in invert
image_rec, ddim_latents = self.ddim_inversion(image_gt)
File "/home/nras/miniconda3/envs/py3.8/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
return func(*args, **kwargs)
File "/tmp/ipykernel_585007/262494972.py", line 125, in ddim_inversion
ddim_latents = self.ddim_loop(latent)
File "/home/nras/miniconda3/envs/py3.8/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
return func(*args, **kwargs)
File "/tmp/ipykernel_585007/262494972.py", line 112, in ddim_loop
noise_pred = self.get_noise_pred_single(latent, t, cond_embeddings)
File "/tmp/ipykernel_585007/262494972.py", line 46, in get_noise_pred_single
noise_pred = self.model.unet(latents, t, encoder_hidden_states=context)["sample"]
File "/home/nras/miniconda3/envs/py3.8/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
return forward_call(*input, **kwargs)
File "/home/nras/miniconda3/envs/py3.8/lib/python3.8/site-packages/diffusers/models/unet_2d_condition.py", line 582, in forward
File "/home/nras/miniconda3/envs/py3.8/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
return forward_call(*input, **kwargs)
File "/home/nras/miniconda3/envs/py3.8/lib/python3.8/site-packages/diffusers/models/unet_2d_blocks.py", line 837, in forward
File "/home/nras/miniconda3/envs/py3.8/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
return forward_call(*input, **kwargs)
File "/home/nras/miniconda3/envs/py3.8/lib/python3.8/site-packages/diffusers/models/transformer_2d.py", line 265, in forward
File "/home/nras/miniconda3/envs/py3.8/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
return forward_call(*input, **kwargs)
File "/home/nras/miniconda3/envs/py3.8/lib/python3.8/site-packages/diffusers/models/attention.py", line 291, in forward
class FeedForward(nn.Module):
File "/home/nras/miniconda3/envs/py3.8/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
return forward_call(*input, **kwargs)
TypeError: forward() got an unexpected keyword argument 'encoder_hidden_states'
During handling of the above exception, another exception occurred:
Traceback (most recent call last): File "/home/nras/miniconda3/envs/py3.8/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 2102, in showtraceback stb = self.InteractiveTB.structured_traceback( File "/home/nras/miniconda3/envs/py3.8/lib/python3.8/site-packages/IPython/core/ultratb.py", line 1310, in structured_traceback return FormattedTB.structured_traceback( File "/home/nras/miniconda3/envs/py3.8/lib/python3.8/site-packages/IPython/core/ultratb.py", line 1199, in structured_traceback return VerboseTB.structured_traceback( File "/home/nras/miniconda3/envs/py3.8/lib/python3.8/site-packages/IPython/core/ultratb.py", line 1052, in structured_traceback formatted_exception = self.format_exception_as_a_whole(etype, evalue, etb, number_of_lines_of_context, File "/home/nras/miniconda3/envs/py3.8/lib/python3.8/site-packages/IPython/core/ultratb.py", line 978, in format_exception_as_a_whole frames.append(self.format_record(record)) File "/home/nras/miniconda3/envs/py3.8/lib/python3.8/site-packages/IPython/core/ultratb.py", line 878, in format_record frame_info.lines, Colors, self.has_colors, lvals File "/home/nras/miniconda3/envs/py3.8/lib/python3.8/site-packages/IPython/core/ultratb.py", line 712, in lines return self._sd.lines File "/home/nras/miniconda3/envs/py3.8/lib/python3.8/site-packages/stack_data/utils.py", line 145, in cached_property_wrapper value = obj.dict[self.func.name] = self.func(obj) File "/home/nras/miniconda3/envs/py3.8/lib/python3.8/site-packages/stack_data/core.py", line 698, in lines pieces = self.included_pieces File "/home/nras/miniconda3/envs/py3.8/lib/python3.8/site-packages/stack_data/utils.py", line 145, in cached_property_wrapper value = obj.dict[self.func.name] = self.func(obj) File "/home/nras/miniconda3/envs/py3.8/lib/python3.8/site-packages/stack_data/core.py", line 649, in included_pieces pos = scope_pieces.index(self.executing_piece) File "/home/nras/miniconda3/envs/py3.8/lib/python3.8/site-packages/stack_data/utils.py", 
line 145, in cached_property_wrapper value = obj.dict[self.func.name] = self.func(obj) File "/home/nras/miniconda3/envs/py3.8/lib/python3.8/site-packages/stack_data/core.py", line 628, in executing_piece return only( File "/home/nras/miniconda3/envs/py3.8/lib/python3.8/site-packages/executing/executing.py", line 164, in only raise NotOneValueFound('Expected one value, found 0') executing.executing.NotOneValueFound: Expected one value, found 0