omerbt / TokenFlow

Official Pytorch Implementation for "TokenFlow: Consistent Diffusion Features for Consistent Video Editing" presenting "TokenFlow" (ICLR 2024)
MIT License
1.52k stars 134 forks source link

Code snippet to reduce VRAM usage when too many frames to process. #25

Open hoveychen opened 10 months ago

hoveychen commented 10 months ago

Based on the existing implementation, I've modified the code to reduce VRAM usage during processing.


Replace the existing register_extended_attention_pnp() function with the code snippet below.

def register_extended_attention_pnp(model, injection_schedule, window_size=3):
    """Patch the UNet self-attention layers with a low-memory extended attention.

    Every ``attn1`` of a ``BasicTransformerBlock`` in ``model.unet`` gets its
    ``forward`` replaced. Attention is evaluated one frame and one head at a
    time, so only a ``(seq_len, window * seq_len)`` similarity matrix is alive
    at any moment instead of one spanning all frames and heads.

    The incoming batch is treated as three equally sized groups stacked along
    dim 0: source frames, unconditional frames and conditional frames. Source
    frames attend only to themselves; the other two groups attend to a sliding
    window of neighbouring frames within their own group.

    Args:
        model: Object exposing ``unet`` with ``named_modules()`` and
            ``up_blocks[res].attentions[block].transformer_blocks[0].attn1``.
        injection_schedule: Timesteps at which the source queries/keys are
            copied into the unconditional and conditional groups. Only
            attached to the decoder blocks listed in ``res_dict``; every other
            block gets an empty schedule.
        window_size: Width of the sliding window used for the extended
            attention. Frame ``i`` attends to frames
            ``i - window_size // 2 .. i + window_size // 2`` (clipped to the
            valid range), so an even value spans ``window_size + 1`` frames.

    Raises:
        ValueError: If ``window_size`` is smaller than 1.
    """
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")

    def sa_forward(self):
        # `to_out` may be a ModuleList; in that case only its first entry is
        # applied to the attention output.
        if isinstance(self.to_out, torch.nn.ModuleList):
            to_out = self.to_out[0]
        else:
            to_out = self.to_out

        def windowed_attention(q, k, v, window):
            """Attend each frame to its `window` neighbourhood, head by head."""
            n_frames, seq_len, dim = q.shape
            h = self.heads
            head_dim = dim // h
            half = window // 2

            q = self.head_to_batch_dim(q).reshape(n_frames, h, seq_len, head_dim)
            k = self.head_to_batch_dim(k).reshape(n_frames, h, seq_len, head_dim)
            v = self.head_to_batch_dim(v).reshape(n_frames, h, seq_len, head_dim)

            out_all = []
            for frame in range(n_frames):
                # Sliding window of key frames, clipped at the sequence ends.
                start = max(0, frame - half)
                stop = min(n_frames, frame + half + 1)
                n_keys = (stop - start) * seq_len

                out = []
                for j in range(h):
                    # One (seq_len, seq_len) product per key frame, joined
                    # along the key axis so the softmax normalises over the
                    # whole window at once.
                    sim = torch.cat(
                        [
                            torch.matmul(q[frame, j], k[kframe, j].transpose(-1, -2)) * self.scale
                            for kframe in range(start, stop)
                        ],
                        dim=-1,
                    )  # (seq_len, window * seq_len)
                    values = v[start:stop, j].reshape(n_keys, head_dim)
                    out.append(torch.matmul(sim.softmax(dim=-1), values))  # (seq_len, head_dim)

                out_all.append(torch.stack(out, dim=0))  # (h, seq_len, head_dim)

            out = torch.cat(out_all, dim=0)  # (n_frames * h, seq_len, head_dim)
            return self.batch_to_head_dim(out)  # (n_frames, seq_len, h * head_dim)

        def forward(x, encoder_hidden_states=None, attention_mask=None):
            batch_size, sequence_length, dim = x.shape
            # The batch holds three groups: source, unconditional, conditional.
            n_frames = batch_size // 3

            is_cross = encoder_hidden_states is not None
            encoder_hidden_states = encoder_hidden_states if is_cross else x
            q = self.to_q(x)
            k = self.to_k(encoder_hidden_states)
            v = self.to_v(encoder_hidden_states)

            # `self.t` is not set in this function; it is expected to be
            # assigned on the module elsewhere before the forward pass.
            if self.injection_schedule is not None and (self.t in self.injection_schedule or self.t == 1000):
                # inject unconditional
                q[n_frames:2 * n_frames] = q[:n_frames]
                k[n_frames:2 * n_frames] = k[:n_frames]
                # inject conditional
                q[2 * n_frames:] = q[:n_frames]
                k[2 * n_frames:] = k[:n_frames]

            # Source frames use plain per-frame attention (a window of one frame).
            out_source = windowed_attention(q[:n_frames], k[:n_frames], v[:n_frames], 1)
            out_uncond = windowed_attention(
                q[n_frames:2 * n_frames], k[n_frames:2 * n_frames], v[n_frames:2 * n_frames], window_size
            )
            out_cond = windowed_attention(q[2 * n_frames:], k[2 * n_frames:], v[2 * n_frames:], window_size)

            out = torch.cat([out_source, out_uncond, out_cond], dim=0)  # (3 * n_frames, seq_len, dim)

            return to_out(out)

        return forward

    # Default: every self-attention layer gets the patched forward with an
    # empty injection schedule.
    for _, module in model.unet.named_modules():
        if isinstance_str(module, "BasicTransformerBlock"):
            module.attn1.forward = sa_forward(module.attn1)
            setattr(module.attn1, 'injection_schedule', [])

    res_dict = {1: [1, 2], 2: [0, 1, 2], 3: [0, 1, 2]}
    # we are injecting attention in blocks 4 - 11 of the decoder, so not in the first block of the lowest resolution
    for res, blocks in res_dict.items():
        for block in blocks:
            module = model.unet.up_blocks[res].attentions[block].transformer_blocks[0].attn1
            module.forward = sa_forward(module)
            setattr(module, 'injection_schedule', injection_schedule)

> [!NOTE]
> This code slightly modifies the extended-attention method from the paper: self-attention is extended across only 3 consecutive keyframes (a sliding window) instead of across all keyframes.