ma787639046 / bowdpr

Codebase for [Paper] Pre-training with Bag-of-Word Prediction for Dense Passage Retrieval
Apache License 2.0

CLS Representation Collapse #2

sversage opened this issue 5 months ago (status: Open)

sversage commented 5 months ago

When trying this approach, I observe that the CLS representations seem to collapse over time, to the point where semantic similarity is no better than random. Did you observe this when using MLM + BoW only (no AE or AR task)?

ma787639046 commented 5 months ago

Hi sversage,

Sorry for the late reply. The pre-training example here uses MLM + BoW only, without AE/AR.

Training with MLM + BoW is stable in my experiments. Could you please share your training settings in detail?

Note that BoW is a pre-training objective for initializing the PLM; the pre-trained model cannot produce good sentence embeddings without downstream fine-tuning. Did you use the pre-trained PLM directly to compute semantic similarity?

sversage commented 5 months ago

No problem. I was comparing the embeddings of an MLM-only model (average of token embeddings) with those of an MLM + Llara model (CLS embedding) to see whether BoW improved the quality of the CLS embedding.

My reasoning was that, while fine-tuning would definitely improve performance, MLM + BoW pre-training alone should make the CLS embedding roughly as good as the average token embedding, since the BoW objective would bring the two closer together (i.e., CLS embedding ≈ average of the sequence's token embeddings).
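The two pooling strategies being compared can be sketched as follows. This is a minimal illustration in plain Python, assuming the encoder's last-layer hidden states are available as a list of token vectors with [CLS] at position 0, plus a 0/1 attention mask; the helper names (`cls_pool`, `mean_pool`, `cosine`) are hypothetical and not taken from the bowdpr codebase.

```python
import math
from typing import List

Vector = List[float]


def cls_pool(hidden: List[Vector]) -> Vector:
    """CLS pooling (hypothetical helper): take the hidden state at position 0."""
    return list(hidden[0])


def mean_pool(hidden: List[Vector], mask: List[int]) -> Vector:
    """Mean pooling (hypothetical helper): average hidden states where mask == 1."""
    dim = len(hidden[0])
    total = [0.0] * dim
    count = 0
    for vec, m in zip(hidden, mask):
        if m:  # skip padding positions
            for i, x in enumerate(vec):
                total[i] += x
            count += 1
    return [x / count for x in total]


def cosine(a: Vector, b: Vector) -> float:
    """Cosine similarity between two non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


# Toy example: 3 real tokens (including [CLS]) and 1 padding token.
hidden = [[1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [9.0, 9.0]]
mask = [1, 1, 1, 0]
cls_vec = cls_pool(hidden)         # [1.0, 0.0]
avg_vec = mean_pool(hidden, mask)  # [1/3, 2/3]
print(cosine(cls_vec, avg_vec))    # 1/sqrt(5), about 0.447
```

Mean pooling here averages every unmasked position, including [CLS] itself; some implementations exclude special tokens, which is a design choice rather than a fixed rule.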

What do you think?

sversage commented 5 months ago

I found another repo with a similar idea that seems to run into the same issue (assuming its AE task is roughly analogous to BoW): https://github.com/FlagOpen/FlagEmbedding/issues/538

ma787639046 commented 5 months ago

Hi, here is my opinion.

BoW is a pre-training task: a prediction task trained with multi-label cross-entropy. At inference, however, we usually compare embeddings with cosine similarity or dot product. There is a mismatch between the prediction task used during pre-training and the cosine similarity/dot product used during inference. Thus, BoW serves only to initialize the PLM, which helps the encoder adapt well to downstream fine-tuning.
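A minimal sketch of a bag-of-words prediction loss of this kind, assuming the [CLS] vector has already been projected to one logit per vocabulary entry and that the multi-label cross-entropy is the mean softmax cross-entropy over the unique input token ids. The exact formulation used in bowdpr may differ, and `bow_loss` is a hypothetical name.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over the vocabulary."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - lse for z in logits]


def bow_loss(vocab_logits: Sequence[float], input_ids: Sequence[int]) -> float:
    """Assumed multi-label cross-entropy (hypothetical helper):
    mean negative log-probability of each unique token in the input."""
    bag = set(input_ids)  # bag of words: duplicates count once
    logp = log_softmax(vocab_logits)
    return -sum(logp[t] for t in bag) / len(bag)


# Toy vocabulary of 4 tokens; the input contains token ids 1 and 2.
uniform = bow_loss([0.0, 0.0, 0.0, 0.0], [1, 2, 2])  # log(4)
focused = bow_loss([0.0, 10.0, 10.0, 0.0], [1, 2])    # close to log(2)
print(uniform, focused)
```

Under this formulation the loss approaches a lower bound of log|B| (B = the set of unique input tokens) when the predicted probability is spread evenly over the bag. Nothing in it constrains the cosine similarity between two different [CLS] vectors, which is the quantity retrieval relies on at inference; that is the mismatch described above.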

I'm also not sure whether an MLM model (with pre-training only) plus average pooling could perform semantic search or retrieval directly without any fine-tuning, because the mismatch between MLM and similarity computation still exists. We still need a contrastive loss in the fine-tuning stage to pull positive pairs together and push negatives apart.
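The contrastive fine-tuning step mentioned here is commonly implemented as an InfoNCE loss with in-batch negatives. The sketch below assumes query i's positive is passage i in the batch, all other passages act as negatives, and similarity is cosine scaled by a temperature; the temperature of 0.05 and the name `info_nce` are illustrative assumptions, not settings from this repository.

```python
import math
from typing import List

Vector = List[float]


def _cos(a: Vector, b: Vector) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def info_nce(queries: List[Vector], passages: List[Vector], temperature: float = 0.05) -> float:
    """In-batch contrastive loss (hypothetical helper).
    queries[i] is paired with passages[i]; every other passage is a negative."""
    total = 0.0
    for i, q in enumerate(queries):
        scores = [_cos(q, p) / temperature for p in passages]
        m = max(scores)
        lse = m + math.log(sum(math.exp(s - m) for s in scores))
        total += lse - scores[i]  # cross-entropy with the positive as the target
    return total / len(queries)


queries = [[1.0, 0.0], [0.0, 1.0]]
aligned = info_nce(queries, [[1.0, 0.0], [0.0, 1.0]])    # positives match: near 0
swapped = info_nce(queries, [[0.0, 1.0], [1.0, 0.0]])    # positives mismatched: large
collapsed = info_nce(queries, [[1.0, 1.0], [1.0, 1.0]])  # identical passages: log(2)
print(aligned, swapped, collapsed)
```

If every passage embedding is identical, all scores tie and the loss equals log(batch size), i.e., chance level; separating positives from negatives is exactly what this objective optimizes directly.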

ma787639046 commented 5 months ago

I'm a little confused by what you mean by "collapse". Do you mean that the BoW loss (or the AE loss in Llara) converges quickly and then stops decreasing?
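One way to make "collapse" concrete, as distinct from a loss plateau, is to check whether the [CLS] embeddings of unrelated passages become nearly parallel. The sketch below computes the mean pairwise cosine similarity over a batch; values approaching 1.0 suggest representation collapse. This is a generic diagnostic assumed for illustration, not something proposed in this thread, and `mean_pairwise_cosine` is a hypothetical helper.

```python
import math
from itertools import combinations
from typing import List

Vector = List[float]


def mean_pairwise_cosine(embs: List[Vector]) -> float:
    """Average cosine similarity over all pairs (requires at least 2 embeddings).
    Hypothetical diagnostic: values near 1.0 mean the vectors are nearly parallel."""
    def cos(a: Vector, b: Vector) -> float:
        dot = sum(x * y for x, y in zip(a, b))
        return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))

    pairs = list(combinations(embs, 2))
    return sum(cos(a, b) for a, b in pairs) / len(pairs)


collapsed = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]  # all point the same way
spread = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]    # diverse directions
print(mean_pairwise_cosine(collapsed))  # about 1.0
print(mean_pairwise_cosine(spread))     # about -0.333
```

Pre-trained encoders are often anisotropic, so a high but stable value is not by itself evidence of collapse; tracking the trend across training checkpoints is more informative.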