ControlNet / MARLIN

[CVPR] MARLIN: Masked Autoencoder for facial video Representation LearnINg
https://openaccess.thecvf.com/content/CVPR2023/html/Cai_MARLIN_Masked_Autoencoder_for_Facial_Video_Representation_LearnINg_CVPR_2023_paper

The implementation is inconsistent with the paper? #23

Closed: imxtx closed this issue 9 months ago

imxtx commented 9 months ago

Paper:

Given an input video (having a dimension 3 × 16 × 224 × 224), the cube embedding layer generates 8 × 14 × 14 3D tokens of dimension 768 to preserve spatio-temporal patterns.

In the released source code, however, 2 x 16 x 16 is the kernel size of the 3D convolution, not the size of the 3D tokens:

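import torch
from torch import nn
from torch.nn import Conv3d
from einops.layers.torch import Rearrange
from typing import Callable, Optional, Tuple, Union

# Type aliases used by the excerpt (assumed here; defined elsewhere in the repo)
Shape = Tuple[int, ...]
ModuleFactory = Callable[[], nn.Module]


class PatchEmbedding3d(nn.Module):
    def __init__(
        self,
        input_size: Shape,
        patch_size: Union[int, Shape],
        embedding: int,  # embed_dim
        strides: Optional[Union[int, Shape]] = None,
        build_normalization: Optional[ModuleFactory] = None,
    ):
        super().__init__()
        # channel, time, height, width
        c, t, h, w = input_size
        # patch_time, patch_height, patch_width
        pt, ph, pw = (
            (patch_size, patch_size, patch_size)
            if isinstance(patch_size, int)  # cube
            else patch_size
        )

        # configure the strides for conv3d
        if strides is None:
            # not specified: stride equals kernel size, so patches neither overlap nor leave gaps
            strides = (pt, ph, pw)
        elif isinstance(strides, int):
            # expand a single stride value to all three dimensions
            strides = (strides, strides, strides)

        # [Batch, 3, 16, 224, 224] -> [Batch, 768, 8, 14, 14]
        self.projection = Conv3d(c, embedding, kernel_size=(pt, ph, pw), stride=strides)
        self.has_norm = build_normalization is not None
        if self.has_norm:
            self.normalization = build_normalization()
        # [Batch, 768, 8, 14, 14] -> [Batch, 8*14*14, 768]
        self.rearrange = Rearrange("b d nt nh nw -> b (nt nh nw) d")

    # forward() is not part of the quoted excerpt; this completion is an assumption
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.projection(x)  # [Batch, 768, 8, 14, 14]
        x = self.rearrange(x)  # [Batch, 8*14*14, 768]
        if self.has_norm:
            x = self.normalization(x)
        return x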
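PatchEmbedding3d defaults the strides to the cube size, so the cubes tile the clip without overlap or gaps. The sketch below, assuming an unpadded convolution as in the excerpt, computes the resulting token grid with the standard convolution output-length formula; `token_grid` is a hypothetical helper for illustration, not part of MARLIN.

```python
import torch
from torch import nn


def token_grid(input_size, patch_size, strides=None):
    """Per-axis token counts (nt, nh, nw) for an unpadded Conv3d.

    Hypothetical helper: input_size is (c, t, h, w), patch_size is
    (pt, ph, pw); strides defaults to patch_size, as in PatchEmbedding3d.
    """
    _, *spatial = input_size
    strides = strides or patch_size
    # Conv output length per axis with no padding: floor((n - k) / s) + 1
    return tuple((n - k) // s + 1 for n, k, s in zip(spatial, patch_size, strides))


# Default (non-overlapping) cubes: one token per 2 x 16 x 16 cube
print(token_grid((3, 16, 224, 224), (2, 16, 16)))  # (8, 14, 14)

# Overlapping cubes (spatial stride 8) give a denser token grid
print(token_grid((3, 16, 224, 224), (2, 16, 16), (2, 8, 8)))  # (8, 27, 27)

# Cross-check against an actual Conv3d with the same kernel and strides
conv = nn.Conv3d(3, 768, kernel_size=(2, 16, 16), stride=(2, 8, 8))
out = conv(torch.zeros(1, 3, 16, 224, 224))
print(tuple(out.shape[2:]))  # (8, 27, 27)
```

With the default strides the grid is 8 × 14 × 14, i.e. 1568 tokens; smaller strides make the cubes overlap and lengthen the sequence.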

Code for testing the 3D convolution:

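import torch
from torch import nn

# Cube embedding: 2 x 16 x 16 kernel with an equal stride (non-overlapping cubes)
model = nn.Conv3d(3, 768, kernel_size=(2, 16, 16), stride=(2, 16, 16))
x = torch.rand(10, 3, 16, 224, 224, requires_grad=True)  # [Batch, C, T, H, W]
output = model(x)
print(output.shape)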

Output:

torch.Size([10, 768, 8, 14, 14])

ControlNet commented 9 months ago

Hi, the "8 x 14 x 14 3D tokens" mentioned in the paper means that there are 8 x 14 x 14 = 1568 tokens in the sequence, not that each token has the shape 8 x 14 x 14.

That is why, in the next line, we highlight that the cube dimension is 2 x 16 x 16. Hope this answers your question.
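To make the distinction concrete, the token count and the size of each cube can be worked out separately. The symbols T, H, W (frames, height, width), p_t, p_h, p_w (cube size), C (input channels) and D (embedding dimension) are notation introduced here for illustration, not taken from the paper.

```latex
N = \frac{T}{p_t} \cdot \frac{H}{p_h} \cdot \frac{W}{p_w}
  = \frac{16}{2} \cdot \frac{224}{16} \cdot \frac{224}{16}
  = 8 \cdot 14 \cdot 14 = 1568 \quad \text{tokens}

C \cdot p_t \cdot p_h \cdot p_w = 3 \cdot 2 \cdot 16 \cdot 16 = 1536 \quad \text{raw values per cube} \;\longrightarrow\; D = 768
```

So the layer maps each 1536-value cube to one 768-d token, and the embedding output has shape [Batch, N, D] = [Batch, 1568, 768].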

imxtx commented 9 months ago

I see. That is why I think the description "The cube embedding layer generates ..." is not correct: saying the layer "generates" something means that it outputs it, but the output of the 3D embedding layer has shape [Batch, 8*14*14, 768]:

# part of the __init__ method of PatchEmbedding3d class
# [Batch, 3, 16, 224, 224] -> [Batch, 768, 8, 14, 14]
self.projection = Conv3d(c, embedding, kernel_size=(pt, ph, pw), stride=strides)
self.has_norm = build_normalization is not None
if self.has_norm:
    self.normalization = build_normalization()
# [Batch, 768, 8, 14, 14] -> [Batch, 8*14*14, 768]
self.rearrange = Rearrange("b d nt nh nw -> b (nt nh nw) d")
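The Rearrange step only reorders axes; it does not change any values. A minimal sketch in plain PyTorch (avoiding the einops dependency), assuming the [Batch, 768, 8, 14, 14] projection output from the comment above, reproduces the same pattern and shows how a sequence index maps back to a grid cell:

```python
import torch

# Stand-in for the Conv3d projection output: [Batch, 768, 8, 14, 14]
feat = torch.randn(2, 768, 8, 14, 14)

# Same reshaping as Rearrange("b d nt nh nw -> b (nt nh nw) d"):
# merge the three grid axes (row-major: nt slowest, nw fastest),
# then move the embedding axis last.
tokens = feat.flatten(2).transpose(1, 2)
print(tokens.shape)  # torch.Size([2, 1568, 768])

# Sequence index n corresponds to grid cell (t, h, w) with n = (t * 14 + h) * 14 + w
t, h, w = 3, 5, 7
n = (t * 14 + h) * 14 + w
assert torch.equal(tokens[:, n, :], feat[:, :, t, h, w])
```

Each of the 1568 rows is one 768-d token, which is why the embedding layer's output, rather than the kernel, carries the 8 × 14 × 14 count.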

The encoder code:

# ...
self.embed_dim = embed_dim
self.patch_embedding = PatchEmbedding3d(
    input_size=(3, n_frames, img_size, img_size),  # [3,16,Height,Width]
    patch_size=(tubelet_size, patch_size, patch_size),  # [2,16,16]
    embedding=embed_dim,
)  # output size [Batch, 8*14*14, 768]
# ...
ControlNet commented 9 months ago

In ViT (whether with 2D patches or 3D cubes), we can use a convolution layer to embed the tokens into vectors (here, 768-d vectors). When the stride equals the kernel size, this is equivalent to manually reshaping the input into non-overlapping patches and then embedding each flattened patch with a linear layer.
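This equivalence can be checked numerically. A minimal sketch, assuming the settings discussed in this thread (3 × 16 × 224 × 224 input, 2 × 16 × 16 cubes, 768-d embedding) and randomly initialized weights, copies the Conv3d kernel into an nn.Linear and compares the two token sequences:

```python
import torch
from torch import nn

torch.manual_seed(0)
c, pt, ph, pw, d = 3, 2, 16, 16, 768
x = torch.randn(1, c, 16, 224, 224)  # [Batch, C, T, H, W]

# Cube embedding as a strided Conv3d (kernel == stride, so cubes don't overlap)
conv = nn.Conv3d(c, d, kernel_size=(pt, ph, pw), stride=(pt, ph, pw))
conv_tokens = conv(x).flatten(2).transpose(1, 2)  # [B, 1568, 768]

# Same embedding by hand: cut x into non-overlapping cubes, flatten each one
b, _, t, h, w = x.shape
cubes = x.reshape(b, c, t // pt, pt, h // ph, ph, w // pw, pw)
# -> [B, nt, nh, nw, C, pt, ph, pw], then flatten each cube to C*pt*ph*pw = 1536 values
cubes = cubes.permute(0, 2, 4, 6, 1, 3, 5, 7).reshape(b, -1, c * pt * ph * pw)

linear = nn.Linear(c * pt * ph * pw, d)
with torch.no_grad():
    # Conv weight [d, C, pt, ph, pw] flattens in the same (C, pt, ph, pw) order
    linear.weight.copy_(conv.weight.reshape(d, -1))
    linear.bias.copy_(conv.bias)
linear_tokens = linear(cubes)  # [B, 1568, 768]

print(conv_tokens.shape, linear_tokens.shape)
print(torch.allclose(conv_tokens, linear_tokens, atol=1e-4))
```

The key detail is flattening each cube in (C, pt, ph, pw) order so that it lines up with the layout of the Conv3d weight; the strided convolution simply performs this reshape and matrix multiply in one call.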