facebookresearch / jepa

PyTorch code and models for V-JEPA self-supervised learning from video.

Question about the mask sampling #50

Open FriedRonaldo opened 8 months ago

FriedRonaldo commented 8 months ago

Hi, I read the V-JEPA paper, and it is an effective way to learn temporal information, better than other works such as VideoMAE and UMT.

I have a question about the mask sampling.

To be clear, I do not mean to review or criticize the paper, but I want to reproduce the work exactly.

Question 01) When I instantiate a mask generator and then sample a mask, the mask sometimes covers only the first N frames instead of all frames.

For example, the source code below demonstrates the situation.

mg = BlockMaskGenerator(
    aspect_ratio=(0.75, 1.5),
    npred=8,
    spatial_pred_mask_scale=(0.15, 0.15),
    temporal_pred_mask_scale=(1., 1.),
    max_context_frames_ratio=1.0,
    image_size=(64, 64),
    num_frames=2,
    patch_size=(16, 16),
    temporal_stride=1,
)
mask_enc, mask_pred = mg(16)
print(mask_enc)

It outputs:

tensor([[ 6,  7,  8, 11, 12, 13, 22, 23],
        [ 2,  3,  4,  5, 10, 11, 15, 18],
        [ 0,  3,  6,  7,  8, 11, 15, 16],
        [ 4,  5, 12, 13, 20, 21, 28, 29],
        [ 2,  3, 10, 11, 12, 13, 14, 15],
        [ 3,  4,  5, 12, 13, 14, 15, 19],
        [ 3,  8, 12, 13, 14, 15, 19, 24],
        [ 0,  3,  8,  9, 15, 16, 19, 24],
        [ 7,  8, 11, 12, 13, 23, 24, 27],
        [ 4,  7,  8, 12, 13, 20, 23, 24],
        [ 2,  3,  4,  8, 18, 19, 20, 24],
        [ 3,  4,  8,  9, 10, 11, 19, 20],
        [ 0,  7,  8,  9, 10, 11, 12, 15],
        [ 0,  1,  4,  7, 11, 12, 13, 16],
        [ 6,  7, 11, 12, 15, 22, 23, 27],
        [ 4,  5,  8,  9, 10, 11, 14, 15]])

In some cases, such as mask_enc[-1] and mask_enc[-4], the mask is applied only to the first frame. (There are 2 frames with 16 patches each, so a row like [4, 5, 8, 9, 10, 11, 14, 15], whose indices are all below 16, can only touch the first frame.)
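A quick way to verify which frames a mask row touches (a hypothetical check, assuming the 16-patches-per-frame layout above): integer-dividing each index by the patch count per frame gives the frame it belongs to.

patches_per_frame = 16
print((mask_enc[-1] // patches_per_frame).unique())  # tensor([0])    -> first frame only
print((mask_enc[0] // patches_per_frame).unique())   # tensor([0, 1]) -> both frames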

In such cases, for some samples in the batch, the model's context comes from only part of the frames, and it is required to predict patches in frames it never sees (e.g., predict all 8 frames using context patches drawn from only 4 of them).

Is my analysis correct? If so, it does not match the paper's description, which says the same mask is repeated across all frames:

3D Multi-Block Masking. We use a simple 3D extension of the block masking strategy employed for images (Bao et al., 2021). Given a video, we sample several (possibly overlapping) spatially continuous blocks with various aspect ratios and take their union to construct a single mask. This spatial mask is then repeated across the entire temporal dimension. Masking a large continuous block that covers the full temporal dimension limits information leakage due to the spatial and temporal redundancy of videos, and results in a harder prediction task (Tong et al., 2022).

In this case, the masking strategy does not work as intended to limit information leakage.

Question 02) The number of visible patch indices plus the number of masked patch indices does not add up to the total number of patches.

When I print the shape of each mask, I get the following output:

print(mask_enc.shape)
print(mask_pred.shape)

torch.Size([16, 8])
torch.Size([16, 16])

There are 32 patches in total (2 frames * 16 patches per frame = 32), but the sum of the mask lengths is only 8 + 16 = 24, less than the total patch count.

Discussion

The second question might not be that problematic: each sample uses a subset of the visible patches to reconstruct a subset of the input video, and partial reconstruction has been shown to be effective for MAE in [1].

[1] CrossMAE: Rethinking Patch Dependence for Masked Autoencoders

Approach (if the analysis is correct and the behavior is not intended)

However, the first issue can affect performance, because the masking method aims to block information leakage between frames, specifically to prevent the model from copying nearby patches from other frames.

To resolve the problem, I think the mask blocks should be sampled for a single frame and then repeated along the time axis with an offset (the number of patches per frame), as sketched below.
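A minimal sketch of the idea, using the 2-frame, 16-patches-per-frame layout from the example above (hypothetical names, not the repository's API):

import torch

patches_per_frame = 16
num_frames = 2
spatial_mask = torch.tensor([4, 5, 8, 9, 10, 11, 14, 15])  # kept indices within one frame

# Repeat the same spatial mask for every frame, offsetting the indices by the
# number of patches per frame so they land in the correct frame.
mask_enc = torch.cat([spatial_mask + t * patches_per_frame for t in range(num_frames)])
print(mask_enc)
# tensor([ 4,  5,  8,  9, 10, 11, 14, 15, 20, 21, 24, 25, 26, 27, 30, 31])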

I hope the discussion improves the clarity of the source code and the paper.

Thanks.

Update

The source code below is one way to fix the mask sampling method.

        collated_masks_pred, collated_masks_enc = [], []
        min_keep_enc = min_keep_pred = self.duration * self.height * self.width
        for _ in range(batch_size):

            empty_context = True
            while empty_context:

                # Sample the union of block masks over a SINGLE frame only;
                # the temporal repetition happens later in _truncate_mask.
                # (p_size is the block size sampled earlier, as in the
                # original __call__.)
                mask_e = torch.ones((1, self.height, self.width), dtype=torch.int32)
                for _ in range(self.npred):
                    mask_e *= self._sample_block_mask(p_size)
                mask_e = mask_e.flatten()

                mask_p = torch.argwhere(mask_e == 0).squeeze()
                mask_e = torch.nonzero(mask_e).squeeze()

                # Resample if the union of blocks covered the whole frame.
                empty_context = (len(mask_e) == 0)
                if not empty_context:
                    min_keep_pred = min(min_keep_pred, len(mask_p))
                    min_keep_enc = min(min_keep_enc, len(mask_e))
                    collated_masks_pred.append(mask_p)
                    collated_masks_enc.append(mask_e)

        if self.max_keep is not None:
            min_keep_enc = min(min_keep_enc, self.max_keep)

        # --
        return self._truncate_mask(collated_masks_enc, min_keep_enc), self._truncate_mask(collated_masks_pred, min_keep_pred)

    def _truncate_mask(self, masks, min_keep):
        result = []
        for cm in masks:
            # Randomly keep min_keep indices so every sample in the batch
            # has the same mask length and can be collated.
            idx = torch.randperm(len(cm))[:min_keep]
            cm = cm[idx]
            # Rebuild the single-frame spatial mask from the kept indices...
            tmp = torch.zeros((1, self.height, self.width), dtype=torch.int32)
            tmp.view(-1)[cm] = 1  # view(-1) is guaranteed to be a view, unlike flatten()
            # ...then repeat it across time: flattening (duration, height, width)
            # makes torch.nonzero return the same spatial indices shifted by
            # height * width for each frame.
            tmp = tmp.repeat(self.duration, 1, 1)
            tmp = torch.nonzero(tmp.flatten()).squeeze()
            result.append(tmp)
        return torch.utils.data.default_collate(result)
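To see why the repeat-then-flatten step in _truncate_mask produces the correct per-frame offsets, here is a standalone check (hypothetical values, assuming height = width = 4 and duration = 2):

import torch

height = width = 4
duration = 2
cm = torch.tensor([4, 5, 8, 9])  # kept per-frame indices

tmp = torch.zeros((1, height, width), dtype=torch.int32)
tmp.view(-1)[cm] = 1
tmp = tmp.repeat(duration, 1, 1)
print(torch.nonzero(tmp.flatten()).squeeze())
# tensor([ 4,  5,  8,  9, 20, 21, 24, 25]) -> same spatial pattern, offset by 16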

As a sanity check, I ran the code without the line "tmp = torch.nonzero(tmp.flatten()).squeeze()".

The outputs look like this:

[image: screenshot of the sanity-check output]
FriedRonaldo commented 8 months ago

Hi, @MidoAssran. It would be great if one of the authors could reply to this issue so that future readers can understand the work better. Thanks!

rozgo commented 8 months ago

If there's substantial information leakage due to this unintended mask sampling behavior, it could compromise the model's temporal learning capabilities by simplifying the learning task. Definitely interested in learning more.

gozy0 commented 2 weeks ago

Was anyone able to get training and evaluation working with the vjepa code?

Are there any updates on this issue? I am trying to understand the vjepa masking strategy: (i) the number of non-masked patches plus the number of masked patches does not match the total number of patches; (ii) of the two types of masks (short-range and long-range), only the first seems to be chosen in the predictor's forward.

Any help / pointers are appreciated.