williamyang1991 / FRESCO

[CVPR 2024] FRESCO: Spatial-Temporal Correspondence for Zero-Shot Video Translation
https://www.mmlab-ntu.com/project/fresco/

Meaning of unet_chunk_size in FRESCOAttnProcessor2_0 codes? #41

Closed. Yuezhengrong closed this issue 7 months ago.

Yuezhengrong commented 7 months ago

In the source code, 'video_length' is computed as 'batch_size // unet_chunk_size'. But if 8 frames are fed in at once, isn't 'batch_size' already equal to 'video_length'?

        '''for efficient cross-frame attention'''
        if self.controller and self.controller.use_cfattn and (not crossattn):
            video_length = key.size()[0] // self.unet_chunk_size  # BC // C = B
            former_frame_index = [0] * video_length  # previous frame index
            attn_mask = None  # get attn_mask from controller
            if self.controller.attn_mask is not None:
                for m in self.controller.attn_mask:
                    if m.shape[1] == key.shape[1]:
                        attn_mask = m
            # get key and value in former_frame_index range
            # BC * HW * 8D --> B * C * HW * 8D
            key = rearrange(key, "(b f) d c -> b f d c", f=video_length)
            # B * C * HW * 8D --> B * C * HW * 8D
            if attn_mask is None:
                key = key[:, former_frame_index]
            else:
                key = repeat(key[:, attn_mask], "b d c -> b f d c", f=video_length)
            # B * C * HW * 8D --> BC * HW * 8D 
            key = rearrange(key, "b f d c -> (b f) d c").detach()
            value = rearrange(value, "(b f) d c -> b f d c", f=video_length)
            if attn_mask is None:
                value = value[:, former_frame_index]
            else:
                value = repeat(value[:, attn_mask], "b d c -> b f d c", f=video_length)              
            value = rearrange(value, "b f d c -> (b f) d c").detach()

        # BC * HW * 8D --> BC * HW * 8 * D --> BC * 8 * HW * D
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        # BC * 8 * HW2 * D
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        # BC * 8 * HW2 * D2
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
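
For readers tracing the shapes: below is a minimal, self-contained sketch (not the FRESCO source; the sizes chunk=2, frames=8, tokens=64, dim=320 are illustrative assumptions) showing how the gather above replaces every frame's key with the first frame's key while keeping the (batch, tokens, channels) layout.

    import torch
    from einops import rearrange

    # assumed toy sizes: unet_chunk_size, video_length, HW tokens, channels
    chunk, frames, tokens, dim = 2, 8, 64, 320
    key = torch.randn(chunk * frames, tokens, dim)      # what the attention layer sees: (B*F) x HW x C

    video_length = key.size(0) // chunk                 # 16 // 2 = 8
    former_frame_index = [0] * video_length             # every frame will look at frame 0

    key = rearrange(key, "(b f) d c -> b f d c", f=video_length)  # 2 x 8 x 64 x 320
    key = key[:, former_frame_index]                               # same shape, frame 0 repeated 8 times
    key = rearrange(key, "b f d c -> (b f) d c")                   # back to 16 x 64 x 320

    print(key.shape)  # torch.Size([16, 64, 320])
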
williamyang1991 commented 7 months ago

Due to the classifier-free guidance, we duplicate each batch. So when classifier-free guidance is enabled, unet_chunk_size=2; otherwise, unet_chunk_size=1.

In your case, by the time execution reaches the attention layers, batch_size=16 rather than 8.

https://github.com/williamyang1991/FRESCO/blob/871d2851a93690430896e0e79f98d2aa0070545a/src/diffusion_hacked.py#L163

https://github.com/williamyang1991/FRESCO/blob/871d2851a93690430896e0e79f98d2aa0070545a/src/pipe_FRESCO.py#L182
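
A hedged sketch of that point (the variable names and latent sizes below are illustrative assumptions, not the pipeline's exact code): with classifier-free guidance the unconditional and conditional branches are concatenated before the UNet call, so the batch dimension that reaches the attention layers is unet_chunk_size times the number of frames.

    import torch

    num_frames = 8
    latents = torch.randn(num_frames, 4, 64, 64)     # one latent per video frame (toy sizes)

    do_classifier_free_guidance = True
    unet_chunk_size = 2 if do_classifier_free_guidance else 1

    # duplicate the batch for the unconditional / conditional passes
    latent_model_input = torch.cat([latents] * unet_chunk_size)

    print(latent_model_input.shape[0])                     # 16 = batch_size seen by attention
    print(latent_model_input.shape[0] // unet_chunk_size)  # 8  = video_length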

Yuezhengrong commented 7 months ago

I understand. It's CFG! Thanks for the answer!