facebookresearch / ijepa

Official codebase for I-JEPA, the Image-based Joint-Embedding Predictive Architecture. First outlined in the CVPR paper, "Self-supervised learning from images with a joint-embedding predictive architecture."

Why doesn't the model collapse? #25

Open Ber666 opened 1 year ago

Ber666 commented 1 year ago

Hi, thanks for the great work. From the paper, the only objective is the distance between the predicted and encoded representations of the target patches. Why does the model not converge to a trivial solution, e.g. predicting zeros all the time? I noticed there are some cited works on this issue, but do you have an intuitive explanation? Thanks!

lanalex commented 1 year ago

I want to second that question and even expand on it. The VICReg paper (https://arxiv.org/abs/2105.04906) introduces a specific loss that prevents collapse without negative mining or large batches. The setup here seems much simpler, but the default batch size is fairly large (128), which previous papers have noted can help alleviate the collapse issue.
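
For reference, a minimal sketch of the kind of variance/covariance regularizer VICReg uses to prevent collapse. The loss weights are illustrative, and I-JEPA itself does not use this loss:

    import torch
    import torch.nn.functional as F

    def vicreg_loss(z_a, z_b, sim_w=25.0, var_w=25.0, cov_w=1.0, eps=1e-4):
        # z_a, z_b: (batch, dim) embeddings of two views of the same images.
        n, d = z_a.shape

        # Invariance term: pull the two embeddings together.
        sim_loss = F.mse_loss(z_a, z_b)

        # Variance term: hinge each dimension's std above 1 so the outputs
        # cannot collapse to a constant vector.
        std_a = torch.sqrt(z_a.var(dim=0) + eps)
        std_b = torch.sqrt(z_b.var(dim=0) + eps)
        var_loss = F.relu(1.0 - std_a).mean() + F.relu(1.0 - std_b).mean()

        # Covariance term: decorrelate dimensions by penalizing off-diagonal covariance.
        z_a_c = z_a - z_a.mean(dim=0)
        z_b_c = z_b - z_b.mean(dim=0)
        cov_a = (z_a_c.T @ z_a_c) / (n - 1)
        cov_b = (z_b_c.T @ z_b_c) / (n - 1)
        off_diag = lambda c: c.pow(2).sum() - torch.diagonal(c).pow(2).sum()
        cov_loss = off_diag(cov_a) / d + off_diag(cov_b) / d

        return sim_w * sim_loss + var_w * var_loss + cov_w * cov_loss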

yuedajiong commented 1 year ago

Just from my personal limited understanding:

A few lines of code to understand Barlow Twins & VICReg:
https://github.com/facebookresearch/ijepa/files/11795454/Barlow.Twins.VIC-Reg.txt
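
The linked file is not reproduced here; as a rough stand-in (not the linked code), a Barlow Twins-style cross-correlation loss might look like the following sketch, with the off-diagonal weight chosen for illustration:

    import torch

    def barlow_twins_loss(z_a, z_b, lambda_offdiag=5e-3, eps=1e-12):
        # z_a, z_b: (batch, dim) embeddings of two views of the same images.
        n, d = z_a.shape

        # Normalize each dimension across the batch.
        z_a_n = (z_a - z_a.mean(dim=0)) / (z_a.std(dim=0) + eps)
        z_b_n = (z_b - z_b.mean(dim=0)) / (z_b.std(dim=0) + eps)

        # Cross-correlation matrix between the two views.
        c = (z_a_n.T @ z_b_n) / n

        # Push the diagonal toward 1 (invariance) and the off-diagonal toward 0
        # (redundancy reduction), which is what prevents a collapsed solution.
        on_diag = (torch.diagonal(c) - 1).pow(2).sum()
        off_diag = c.pow(2).sum() - torch.diagonal(c).pow(2).sum()
        return on_diag + lambda_offdiag * off_diag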

From top to bottom:

  1. No other method; only control of the parameter distribution.
  2. (important #1) Some direct regularization, or an auxiliary task network. (JEPA, not I-JEPA, uses VICReg; there are descriptions there of Max-IC, Min-IC, ...)
  3. From my personal understanding, we could in fact construct an auxiliary network to mitigate this problem, e.g. subnetwork-A: X -> recon_encoder -> H -> recon_decoder -> recon_loss(), trained and then frozen; and subnetwork-B, I-JEPA-like: X -> recon_encoder -> H -> ijepa_input-side_encoder_small_version -> the rest of the I-JEPA network. That is, we can obtain a latent representation by reconstruction, pre-train and freeze it, and that latent representation is very small (a rough sketch of this idea follows the list).
  4. I cannot find VICReg in I-JEPA, just a small 'layer_norm' step after the target_encoder.
  5. (important #2) BUT, I found an interesting implementation detail in training: the 'momentum update of target encoder'. That is, the target_encoder is not trained directly (at initialization it is copied from the input-side encoder), and its parameters are updated from the input-side encoder, so the two are strongly correlated. Still from my personal understanding: this I-JEPA task is spatial masking, not temporal-frame prediction, so the input-side encoder and the target encoder naturally have a certain degree of similarity/correlation. Maybe this is connected to the collapse issue. @MidoAssran
  6. The ImageNet dataset is large enough and diverse enough, and perhaps a big batch size and a small learning rate also help. If we sampled only a few images from ImageNet-1k, did a few-shot task, or even transferred to another data style, such as 2-D anime, we might encounter other problems. The huge transformer variant has strong memorization capacity, especially for everyday natural-scene images, which may mask some challenges.
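
A hypothetical sketch of the idea in point 3; all module names, sizes, and the optimizer choice are made up for illustration:

    import torch
    import torch.nn as nn

    # Stage A: a small autoencoder learns a compact latent H by reconstruction,
    # then its encoder is frozen.
    recon_encoder = nn.Sequential(nn.Linear(768, 256), nn.GELU(), nn.Linear(256, 128))
    recon_decoder = nn.Sequential(nn.Linear(128, 256), nn.GELU(), nn.Linear(256, 768))
    opt = torch.optim.AdamW(list(recon_encoder.parameters()) + list(recon_decoder.parameters()), lr=1e-3)

    x = torch.randn(8, 768)                                  # stand-in for patch features
    recon_loss = nn.functional.mse_loss(recon_decoder(recon_encoder(x)), x)
    opt.zero_grad(); recon_loss.backward(); opt.step()       # one illustrative training step

    for p in recon_encoder.parameters():
        p.requires_grad = False                              # freeze for stage B

    # Stage B: a JEPA-style predictor would then operate on the small, fixed latent H.
    h = recon_encoder(x)
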
lanalex commented 1 year ago

> (important #2) BUT, I found an interesting implementation detail in training: the 'momentum update of target encoder'. That is, the target_encoder is not trained directly (at initialization it is copied from the input-side encoder), and its parameters are updated from the input-side encoder, so the two are strongly correlated. Still from my personal understanding: this I-JEPA task is spatial masking, not temporal-frame prediction, so the input-side encoder and the target encoder naturally have a certain degree of similarity/correlation. Maybe this is connected to the collapse issue. @MidoAssran

Looking at DINOv2, I think the momentum update of the target encoder is a similar concept to what is going on here. I think that is the basic "motif" that helps solve the same issue. @Ber666 I think that is the reason it doesn't collapse: the target encoder is "lagging" behind, in a similar way to what is done (conceptually) in DINOv2.

yuedajiong commented 1 year ago

Thanks @lanalex. A pointer to the DINOv2 technical detail: SSLMetaArch.update_teacher(), which updates the teacher from the student.

ballasnicolas commented 1 year ago

Hi all,

Thanks for your interest in our work! Different methods exist for preventing collapse in joint-embedding approaches (contrastive, non-contrastive, clustering, ...). In this work, we rely on an asymmetric architecture between the target and context encoders to prevent collapse. Specifically, we don't backpropagate the loss through the target network, and we update its weights through a moving-average update.

This collapse prevention mechanism was introduced by BYOL (https://arxiv.org/abs/2006.07733). The stop-gradient is really key to preventing collapse (see https://arxiv.org/abs/2011.10566). In our experiments, we found the moving-average update useful for stabilizing optimization with ViTs.
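
As a rough illustration of that asymmetry (not the exact I-JEPA code; the repository's actual momentum update is quoted later in this thread), the pattern might look like this, where the loss choice and momentum value are placeholders:

    import torch
    import torch.nn.functional as F

    def training_step(context_encoder, predictor, target_encoder, optimizer, x_ctx, x_tgt, m=0.996):
        # Gradients flow only through the context encoder and the predictor.
        pred = predictor(context_encoder(x_ctx))

        # Stop-gradient: the target branch is never backpropagated through.
        with torch.no_grad():
            target = target_encoder(x_tgt)

        loss = F.mse_loss(pred, target)   # placeholder distance in representation space
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Moving-average (EMA) update: the target encoder's weights lag behind
        # the context encoder's weights instead of being trained directly.
        with torch.no_grad():
            for p_ctx, p_tgt in zip(context_encoder.parameters(), target_encoder.parameters()):
                p_tgt.data.mul_(m).add_((1.0 - m) * p_ctx.data)

        return loss.item()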

To my knowledge, it is still an open question why an asymmetric architecture can prevent collapse, but some works have started exploring this question (https://arxiv.org/abs/2204.00613).

Let me know if that answers your question.

yuedajiong commented 1 year ago

Thanks @ballasnicolas, so much useful information.

I have seen a similar policy in RL, called 'soft update'.

Is this 'lagging-policy' the best way? I do NOT think so.

From my personal understanding, your primary motivation is to make predictions in latent space. (I also have my doubts about LeCun's 'explicit/unambiguous/layer-by-layer hierarchical latent' viewpoint.)

What is lagging, what is lagged, and why lag? Consider these different representation spaces: the raw image space S0 (input/target), the input-side encoded latent space S1, the input-side predicted latent space S2, and the target-side encoded latent space S3. Obviously S0_input == S0_target, and S2 ~= S3 at the loss level. The key issue is how to design/construct the S1 space. If we use 'lagging', the target-side encoder effectively contains two steps: a) encode into S1, and b) drift toward S2.
That is, the 'lagging' makes the target not only similar/correlated to S1, but also helps it transform toward S2. In fact, I think we can NOT say the prediction happens in the S1 space, because S3 != S1; that means the predictor not only predicts, but also transforms.

So, is there any other 'lag' policy? The current 'lag' is heavy at every training step, especially in distributed training: parameter copies, multiplies, adds. I call it 'weight & time-based lagging'. New 'lag' policies could be, e.g. (a rough sketch of both ideas follows this list):

  1. Update the parameters only periodically; of course, momentum can still be included. I call it 'period/time-based lagging'.
  2. Update the parameters only partially, e.g. only a tail of the target-side encoder. I call it 'space-based lagging'.

    Anyway, we just want to artificially construct a certain kind of 'must be similar, but must be different'.
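
A hypothetical sketch of those two alternatives; the function names and the `tail_prefix` value are illustrative, not from the I-JEPA codebase:

    import torch

    @torch.no_grad()
    def periodic_lag_update(context_encoder, target_encoder, step, period=100, m=0.9):
        # 'Period/time-based lagging': update the target weights only every `period` steps.
        if step % period != 0:
            return
        for p_ctx, p_tgt in zip(context_encoder.parameters(), target_encoder.parameters()):
            p_tgt.data.mul_(m).add_((1.0 - m) * p_ctx.data)

    @torch.no_grad()
    def partial_lag_update(context_encoder, target_encoder, tail_prefix="blocks.10", m=0.996):
        # 'Space-based lagging': EMA-update only a subset (e.g. the tail blocks) of the target encoder.
        ctx_params = dict(context_encoder.named_parameters())
        for name, p_tgt in target_encoder.named_parameters():
            if name.startswith(tail_prefix):
                p_tgt.data.mul_(m).add_((1.0 - m) * ctx_params[name].data)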


Is 'predict in latent space' the best principle to follow? I do NOT think so. a) The original space is necessary for reconstruction, especially high-fidelity reconstruction; here, the high-level latent-space representation serves only as a high-level constraint. The key to learning is to learn the mapping from input to target in the original space. Example: one-shot human reconstruction, with an image of LeCun's face/body as input, hierarchy level #1 being man-style and hierarchy level #2 being human-form, and prediction happening at level #2. It is then difficult to reconstruct a high-fidelity 3D LeCun. We just want to use I-JEPA/JEPA to learn the constraints: logical abstraction (man/human), spatial relationships (a human has a head at the top, a body in the middle, and feet at the bottom), ... In a word, different space abstractions have different effects, especially when information is missing.

b) I believe that the explicit organization of hierarchical abstractions introduces too much subjective human knowledge. I disagree with this part of JEPA. I agree that we need an 'implicit is primary and explicit is secondary' world model as the constraint.

JulesCollenne commented 8 months ago

For people wondering how it's done in the code, it's at line 332 of src/train.py:

    # Step 3. momentum update of target encoder
    with torch.no_grad():
        m = next(momentum_scheduler)
        for param_q, param_k in zip(encoder.parameters(), target_encoder.parameters()):
            param_k.data.mul_(m).add_((1.-m) * param_q.detach().data)

This is the implementation of BYOL's momentum mechanism. Apart from that, the target encoder never updates its weights.