Open FriedRonaldo opened 8 months ago
Hi, @MidoAssran. It would be great if one of the authors could reply to this issue so that future readers can understand the work better. Thanks!
If there's substantial information leakage due to this unintended mask sampling behavior, it could compromise the model's temporal learning capabilities by simplifying the learning task. Definitely interested in learning more.
Has anyone been able to train and evaluate with the vjepa code?
Are there any updates on this issue? I am trying to understand the vjepa masking strategy: i) the number of non-masked patches plus the number of masked patches does not match the total number of patches; ii) of the two types of masks (short-range and long-range), only the first seems to be used in the predictor's forward.
Any help / pointers are appreciated.
Hi, I read the JEPA paper, and it looks like an effective way to learn temporal information, better than other works such as VideoMAE and UMT.
I have a question about the mask sampling.
To be clear, I do not mean to review or criticize the paper; I want to reproduce the work exactly.
Question 01) When I instantiate a mask generator and then sample a mask, it sometimes masks only the first N frames.
For example, the source code below reproduces the situation.
It outputs:
In some cases, such as mask_enc[-1] and mask_enc[-4], the mask is applied only to the first frame. (There are 2 frames with 16 patches each, so the indices [[ 4, 5, 8, 9, 10, 11, 14, 15]] cover the first frame only, because every index below 16 belongs to the first frame.)
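To make the index arithmetic concrete, here is a minimal sketch that maps flattened patch indices to the frames they fall in. The helper name `frames_covered` is hypothetical and not part of the vjepa code, and it assumes patches are flattened frame by frame (index = frame * patches_per_frame + patch).

```python
def frames_covered(indices, patches_per_frame):
    """Return the sorted frame ids touched by flattened patch indices.

    Assumption: frame-major flattening, i.e.
    index = frame * patches_per_frame + patch_within_frame.
    """
    return sorted({idx // patches_per_frame for idx in indices})


# Indices quoted above: 2 frames, 16 patches per frame.
mask_enc_example = [4, 5, 8, 9, 10, 11, 14, 15]
print(frames_covered(mask_enc_example, patches_per_frame=16))  # [0]
```

If the result contains fewer frame ids than the clip has frames, the sampled indices do not span the whole time axis.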
In this case, for some batches, the model seems to see patches from only some of the frames (e.g. 4 of the 8 frames) and is required to reconstruct all patches from the patches of those first few frames (e.g. reconstruct 8 frames from 4 masked frames).
Is my analysis correct? If so, the behavior might not match the description in the paper, which says the mask is the same for all frames.
In that case, the masking strategy does not limit information leakage as intended.
Question 02) The numbers of visible and masked patches do not seem to add up to the total number of patches.
When I print the shape of each mask, I get the output like below:
There are 32 patches (2 frames * 16 patches per frame = 32), but the sum of the mask lengths is less than the total patch count.
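A small sketch of how this count could be checked is given here. The function name `coverage` and the index lists are made up for illustration only; they are not outputs of the vjepa mask generator.

```python
def coverage(mask_enc, mask_pred, total_patches):
    """Count visible, predicted, overlapping, and uncovered patch indices."""
    enc, pred = set(mask_enc), set(mask_pred)
    return {
        "visible": len(enc),
        "predicted": len(pred),
        "overlap": len(enc & pred),
        "uncovered": total_patches - len(enc | pred),
    }


# Illustrative indices only (2 frames * 16 patches per frame = 32 patches).
enc = [0, 1, 2, 3, 16, 17, 18, 19]
pred = [5, 6, 9, 10, 21, 22, 25, 26]
print(coverage(enc, pred, total_patches=32))
# {'visible': 8, 'predicted': 8, 'overlap': 0, 'uncovered': 16}
```

A nonzero `uncovered` count means some patches are neither given to the encoder nor asked of the predictor, which is the situation described in this question.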
Discussion
The second question might not be that problematic. The model uses a subset of the visible patches of each sample to reconstruct part of the input video, and partial reconstruction in MAE has been shown to be effective [1].
[1] CrossMAE: Rethinking Patch Dependence for Masked Autoencoders
Approach (if the analysis is correct and the behavior is not intended)
However, the first issue can affect performance, because the masking method aims to block information leakage between frames, specifically to prevent the model from copying nearby patches from other frames.
To resolve the problem, I think the mask block should be sampled for a single frame and then repeated along the time axis with an offset (the number of patches in each frame).
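The idea of repeating a single-frame mask along the time axis can be sketched as follows. This is only an illustration of the proposal under the assumption of frame-major flattening; `tile_mask_over_time` is a hypothetical name, and the block is not the vjepa implementation.

```python
def tile_mask_over_time(frame_indices, num_frames, patches_per_frame):
    """Repeat per-frame patch indices on every frame.

    Frame t receives the same spatial indices shifted by
    t * patches_per_frame (assumes frame-major flattening).
    """
    return [
        t * patches_per_frame + idx
        for t in range(num_frames)
        for idx in sorted(frame_indices)
    ]


# Spatial indices sampled for one frame (16 patches per frame, 2 frames).
single_frame = [4, 5, 8, 9, 10, 11, 14, 15]
print(tile_mask_over_time(single_frame, num_frames=2, patches_per_frame=16))
# [4, 5, 8, 9, 10, 11, 14, 15, 20, 21, 24, 25, 26, 27, 30, 31]
```

With this construction every frame hides the same spatial locations, so a patch cannot be recovered by copying the same location from a neighboring frame.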
I hope the discussion improves the clarity of the source code and the paper.
Thanks.
Update
The source code below is one possible way to fix the mask sampling method.
As a sanity check, I ran the code without "tmp = torch.nonzero(tmp.flatten()).squeeze()".
The outputs look like: