google / prompt-to-prompt


Getting error in null_text_w_ptp #49

Open nrasiwas opened 1 year ago

nrasiwas commented 1 year ago

DDIM inversion...
Unexpected exception formatting exception. Falling back to standard exception
Traceback (most recent call last):
  File "/home/nras/miniconda3/envs/py3.8/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3505, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_585007/266642345.py", line 3, in <module>
    (image_gt, image_enc), x_t, uncond_embeddings = null_inversion.invert(image_path, prompt, offsets=(0,0,200,0), verbose=True)
  File "/tmp/ipykernel_585007/262494972.py", line 168, in invert
    image_rec, ddim_latents = self.ddim_inversion(image_gt)
  File "/home/nras/miniconda3/envs/py3.8/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/tmp/ipykernel_585007/262494972.py", line 125, in ddim_inversion
    ddim_latents = self.ddim_loop(latent)
  File "/home/nras/miniconda3/envs/py3.8/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/tmp/ipykernel_585007/262494972.py", line 112, in ddim_loop
    noise_pred = self.get_noise_pred_single(latent, t, cond_embeddings)
  File "/tmp/ipykernel_585007/262494972.py", line 46, in get_noise_pred_single
    noise_pred = self.model.unet(latents, t, encoder_hidden_states=context)["sample"]
  File "/home/nras/miniconda3/envs/py3.8/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/nras/miniconda3/envs/py3.8/lib/python3.8/site-packages/diffusers/models/unet_2d_condition.py", line 582, in forward
  File "/home/nras/miniconda3/envs/py3.8/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/nras/miniconda3/envs/py3.8/lib/python3.8/site-packages/diffusers/models/unet_2d_blocks.py", line 837, in forward
  File "/home/nras/miniconda3/envs/py3.8/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/nras/miniconda3/envs/py3.8/lib/python3.8/site-packages/diffusers/models/transformer_2d.py", line 265, in forward
  File "/home/nras/miniconda3/envs/py3.8/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/nras/miniconda3/envs/py3.8/lib/python3.8/site-packages/diffusers/models/attention.py", line 291, in forward
    class FeedForward(nn.Module):
  File "/home/nras/miniconda3/envs/py3.8/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
TypeError: forward() got an unexpected keyword argument 'encoder_hidden_states'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/nras/miniconda3/envs/py3.8/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 2102, in showtraceback
    stb = self.InteractiveTB.structured_traceback(
  File "/home/nras/miniconda3/envs/py3.8/lib/python3.8/site-packages/IPython/core/ultratb.py", line 1310, in structured_traceback
    return FormattedTB.structured_traceback(
  File "/home/nras/miniconda3/envs/py3.8/lib/python3.8/site-packages/IPython/core/ultratb.py", line 1199, in structured_traceback
    return VerboseTB.structured_traceback(
  File "/home/nras/miniconda3/envs/py3.8/lib/python3.8/site-packages/IPython/core/ultratb.py", line 1052, in structured_traceback
    formatted_exception = self.format_exception_as_a_whole(etype, evalue, etb, number_of_lines_of_context,
  File "/home/nras/miniconda3/envs/py3.8/lib/python3.8/site-packages/IPython/core/ultratb.py", line 978, in format_exception_as_a_whole
    frames.append(self.format_record(record))
  File "/home/nras/miniconda3/envs/py3.8/lib/python3.8/site-packages/IPython/core/ultratb.py", line 878, in format_record
    frame_info.lines, Colors, self.has_colors, lvals
  File "/home/nras/miniconda3/envs/py3.8/lib/python3.8/site-packages/IPython/core/ultratb.py", line 712, in lines
    return self._sd.lines
  File "/home/nras/miniconda3/envs/py3.8/lib/python3.8/site-packages/stack_data/utils.py", line 145, in cached_property_wrapper
    value = obj.__dict__[self.func.__name__] = self.func(obj)
  File "/home/nras/miniconda3/envs/py3.8/lib/python3.8/site-packages/stack_data/core.py", line 698, in lines
    pieces = self.included_pieces
  File "/home/nras/miniconda3/envs/py3.8/lib/python3.8/site-packages/stack_data/utils.py", line 145, in cached_property_wrapper
    value = obj.__dict__[self.func.__name__] = self.func(obj)
  File "/home/nras/miniconda3/envs/py3.8/lib/python3.8/site-packages/stack_data/core.py", line 649, in included_pieces
    pos = scope_pieces.index(self.executing_piece)
  File "/home/nras/miniconda3/envs/py3.8/lib/python3.8/site-packages/stack_data/utils.py", line 145, in cached_property_wrapper
    value = obj.__dict__[self.func.__name__] = self.func(obj)
  File "/home/nras/miniconda3/envs/py3.8/lib/python3.8/site-packages/stack_data/core.py", line 628, in executing_piece
    return only(
  File "/home/nras/miniconda3/envs/py3.8/lib/python3.8/site-packages/executing/executing.py", line 164, in only
    raise NotOneValueFound('Expected one value, found 0')
executing.executing.NotOneValueFound: Expected one value, found 0
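The NotOneValueFound error is just IPython failing to render the real traceback; the underlying failure is the TypeError above. The forward that register_attention_control in ptp_utils.py installs on each attention module still uses the old diffusers call signature, while newer diffusers versions call the module with encoder_hidden_states and attention_mask keyword arguments. A minimal sketch of the mismatch (the old signature is the one commented out in the fix below):

# Signature the original ptp_utils.py patch installs on each attention module:
def forward(x, context=None, mask=None):
    ...

# Call made by newer diffusers, which the old signature rejects:
# forward(hidden_states, encoder_hidden_states=..., attention_mask=..., **cross_attention_kwargs)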

jingweim commented 1 year ago

Check this out: https://github.com/google/prompt-to-prompt/issues/44#issuecomment-1499949385

guillermogotre commented 1 year ago

Similarly to #44, I managed to make it work with diffusers==0.16.1, transformers==4.29.2, and torch==1.12.1 by modifying the code for register_attention_control(model, controller) in ptp_utils.py:

def register_attention_control(model, controller):
    def ca_forward(self, place_in_unet):
        # In recent diffusers versions, to_out is a ModuleList of [Linear, Dropout];
        # only the Linear projection is needed here.
        if isinstance(self.to_out, torch.nn.modules.container.ModuleList):
            to_out = self.to_out[0]
        else:
            to_out = self.to_out

        def reshape_heads_to_batch_dim(self, tensor):
            batch_size, seq_len, dim = tensor.shape
            head_size = self.heads
            tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
            tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
            return tensor

        def reshape_batch_dim_to_heads(self, tensor):
            batch_size, seq_len, dim = tensor.shape
            head_size = self.heads
            tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
            tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
            return tensor
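        # The two helpers above mirror diffusers' Attention.head_to_batch_dim /
        # batch_to_head_dim (the names used in recent releases); they are inlined
        # here so the patch does not depend on a particular diffusers version.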

        # def forward(x, context=None, mask=None):
        def forward(hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
            x = hidden_states
            context = encoder_hidden_states
            mask = attention_mask

            batch_size, sequence_length, dim = x.shape
            h = self.heads
            q = self.to_q(x)
            is_cross = context is not None
            context = context if is_cross else x
            k = self.to_k(context)
            v = self.to_v(context)
            q = reshape_heads_to_batch_dim(self, q)
            k = reshape_heads_to_batch_dim(self, k)
            v = reshape_heads_to_batch_dim(self, v)

            sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale

            if mask is not None:
                mask = mask.reshape(batch_size, -1)
                max_neg_value = -torch.finfo(sim.dtype).max
                mask = mask[:, None, :].repeat(h, 1, 1)
                sim.masked_fill_(~mask, max_neg_value)

            # attention, what we cannot get enough of
            attn = sim.softmax(dim=-1)
            attn = controller(attn, is_cross, place_in_unet)
            out = torch.einsum("b i j, b j d -> b i d", attn, v)
            out = reshape_batch_dim_to_heads(self,out)
            return to_out(out)

        return forward

    class DummyController:

        def __call__(self, *args):
            return args[0]

        def __init__(self):
            self.num_att_layers = 0

    if controller is None:
        controller = DummyController()

    def register_recr(net_, count, place_in_unet, module_name=None):
        # if net_.__class__.__name__ == 'CrossAttention':
        #     net_.forward = ca_forward(net_, place_in_unet)
        #     return count + 1
        # Newer diffusers no longer names the class 'CrossAttention', so hook the
        # modules by their names inside BasicTransformerBlock instead
        # (attn1 = self-attention, attn2 = cross-attention).
        if module_name in ["attn1", "attn2"]:
            net_.forward = ca_forward(net_, place_in_unet)
            return count + 1
        elif hasattr(net_, 'children'):
            for k, net__ in net_.named_children():
                count = register_recr(net__, count, place_in_unet, module_name=k)
        return count

    cross_att_count = 0
    sub_nets = model.unet.named_children()
    for net in sub_nets:
        if "down" in net[0]:
            cross_att_count += register_recr(net[1], 0, "down")
        elif "up" in net[0]:
            cross_att_count += register_recr(net[1], 0, "up")
        elif "mid" in net[0]:
            cross_att_count += register_recr(net[1], 0, "mid")

    controller.num_att_layers = cross_att_count
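
For reference, a minimal usage sketch under the setup of null_text_w_ptp.ipynb; ldm_stable (the StableDiffusionPipeline) and AttentionStore (the controller class) are names from the notebook and are assumed to be defined already:

controller = AttentionStore()  # or None, which falls back to DummyController above
register_attention_control(ldm_stable, controller)
# For the SD v1.4/v1.5 UNet this should hook 32 layers (16 transformer blocks x attn1/attn2).
print(controller.num_att_layers)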

Additionally, you might have to comment out the following lines in the 4th code cell of null_text_w_ptp.ipynb:

# try:
#     ldm_stable.disable_xformers_memory_efficient_attention()
# except AttributeError:
#     print("Attribute disable_xformers_memory_efficient_attention() is missing")