omerbt / TokenFlow

Official PyTorch implementation of "TokenFlow: Consistent Diffusion Features for Consistent Video Editing" presenting "TokenFlow" (ICLR 2024)
https://diffusion-tokenflow.github.io
MIT License
1.52k stars 134 forks source link

Code snippet to reduce VRAM usage when there are too many frames to process. #25

Open hoveychen opened 10 months ago

hoveychen commented 10 months ago

Based on https://github.com/omerbt/TokenFlow/issues/20, I've modified the code to reduce VRAM usage during processing.

Usage:

Replace the register_extended_attention_pnp() function in tokenflow_utils.py with the code snippet below.


def register_extended_attention_pnp(model, injection_schedule, window_size=3):
    """Patch ``model.unet`` self-attention with a low-VRAM sliding-window extended attention.

    Every ``attn1`` of a ``BasicTransformerBlock`` in ``model.unet`` gets a new
    ``forward`` with an empty injection schedule. The ``attn1`` of the first
    transformer block in each decoder block listed in ``res_dict`` gets
    ``injection_schedule`` instead.

    The patched ``forward`` treats its batch as three equal chunks of frames:
    source, unconditional and conditional. Source frames attend only to
    themselves. Each unconditional/conditional frame attends to the keys of up
    to ``window_size`` consecutive frames of the same chunk, instead of all key
    frames as in the paper. Attention is computed one frame and one head at a
    time, so scores are never materialised for all frames and heads at once.

    Args:
        model: object exposing ``model.unet`` (presumably a diffusers UNet -- TODO confirm).
        injection_schedule: timesteps at which the source queries/keys are copied
            into the unconditional and conditional chunks.
        window_size: number of consecutive frames each edited frame attends to
            (default 3, as in the original snippet). When the size is even, the
            extra frame comes after the current one. Windows are truncated at
            the clip boundaries.

    Raises:
        ValueError: if ``window_size`` is smaller than 1.
    """
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")

    # Number of frames taken before the current one. For window_size=3 the
    # window is frame-1 .. frame+1.
    frames_before = (window_size - 1) // 2

    def own_frame(frame, n_frames):
        """Half-open frame span for a frame that attends only to itself."""
        return frame, frame + 1

    def neighbour_window(frame, n_frames):
        """Half-open span of up to ``window_size`` frames around ``frame``, clamped to the clip."""
        start = frame - frames_before
        return max(0, start), min(n_frames, start + window_size)

    def sa_forward(self):
        """Build a replacement ``forward`` bound to the attention module ``self``."""
        to_out = self.to_out
        # When to_out is a ModuleList, only its first entry is applied,
        # as in the original snippet.
        if isinstance(to_out, torch.nn.ModuleList):
            to_out = to_out[0]

        def split_heads(t):
            """Reshape (n_frames, seq_len, h * head_dim) into (n_frames, h, seq_len, head_dim)."""
            n_frames, seq_len, dim = t.shape
            return self.head_to_batch_dim(t).reshape(n_frames, self.heads, seq_len, dim // self.heads)

        def attend(q, k, v, span_for):
            """Scaled dot-product attention computed frame by frame and head by head.

            Frame ``f`` attends to the keys/values of frames ``start .. end - 1``,
            where ``(start, end) = span_for(f, n_frames)``. Returns a tensor of
            shape (n_frames, seq_len, h * head_dim).
            """
            n_frames = q.shape[0]
            q, k, v = split_heads(q), split_heads(k), split_heads(v)
            head_dim = q.shape[-1]

            out_all = []
            for frame in range(n_frames):
                start, end = span_for(frame, n_frames)
                heads = []
                for j in range(self.heads):
                    # Flatten the window's frames so one matmul scores every key in it.
                    keys = k[start:end, j].reshape(-1, head_dim)  # (window * seq_len, head_dim)
                    values = v[start:end, j].reshape(-1, head_dim)  # (window * seq_len, head_dim)
                    sim = torch.matmul(q[frame, j], keys.transpose(-1, -2)) * self.scale  # (seq_len, window * seq_len)
                    heads.append(torch.matmul(sim.softmax(dim=-1), values))  # (seq_len, head_dim)
                out_all.append(torch.stack(heads))  # (h, seq_len, head_dim)

            out = torch.cat(out_all, dim=0)  # (n_frames * h, seq_len, head_dim)
            return self.batch_to_head_dim(out)  # (n_frames, seq_len, h * head_dim)

        def forward(x, encoder_hidden_states=None, attention_mask=None):
            """Attention over a batch laid out as [source | uncond | cond] frame chunks.

            ``attention_mask`` is accepted for signature compatibility but ignored.
            Reads ``self.t`` and ``self.injection_schedule``; ``self.t`` is set
            outside this function.
            """
            n_frames = x.shape[0] // 3

            is_cross = encoder_hidden_states is not None
            encoder_hidden_states = encoder_hidden_states if is_cross else x
            q = self.to_q(x)
            k = self.to_k(encoder_hidden_states)
            v = self.to_v(encoder_hidden_states)

            if self.injection_schedule is not None and (self.t in self.injection_schedule or self.t == 1000):
                # inject source queries/keys into the unconditional chunk
                q[n_frames:2 * n_frames] = q[:n_frames]
                k[n_frames:2 * n_frames] = k[:n_frames]
                # ... and into the conditional chunk
                q[2 * n_frames:] = q[:n_frames]
                k[2 * n_frames:] = k[:n_frames]

            source = slice(0, n_frames)
            uncond = slice(n_frames, 2 * n_frames)
            cond = slice(2 * n_frames, None)

            out_source = attend(q[source], k[source], v[source], own_frame)
            out_uncond = attend(q[uncond], k[uncond], v[uncond], neighbour_window)
            out_cond = attend(q[cond], k[cond], v[cond], neighbour_window)

            out = torch.cat([out_source, out_uncond, out_cond], dim=0)  # (3 * n_frames, seq_len, dim)
            return to_out(out)

        return forward

    for _, module in model.unet.named_modules():
        if isinstance_str(module, "BasicTransformerBlock"):
            module.attn1.forward = sa_forward(module.attn1)
            setattr(module.attn1, 'injection_schedule', [])

    res_dict = {1: [1, 2], 2: [0, 1, 2], 3: [0, 1, 2]}
    # we are injecting attention in blocks 4 - 11 of the decoder, so not in the first block of the lowest resolution
    for res in res_dict:
        for block in res_dict[res]:
            module = model.unet.up_blocks[res].attentions[block].transformer_blocks[0].attn1
            module.forward = sa_forward(module)
            setattr(module, 'injection_schedule', injection_schedule)

> [!NOTE]
> This code slightly modifies the paper's extended-attention method: each frame's self-attention is extended only across a window of 3 consecutive key frames instead of across all key frames.