princeton-nlp / SimCSE

[EMNLP 2021] SimCSE: Simple Contrastive Learning of Sentence Embeddings https://arxiv.org/abs/2104.08821
MIT License

computing alignment and uniformity #120

Closed: lephong closed this issue 2 years ago

lephong commented 2 years ago

I'm following Wang and Isola to compute alignment and uniformity (using the code given in their Fig 5, http://proceedings.mlr.press/v119/wang20k/wang20k.pdf) in order to reproduce Fig 2 of your paper, but I have not been able to. I observe that alignment decreases during training whereas uniformity stays almost unchanged, which is completely different from Fig 2. Details are below.
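
For reference, these are the two losses as defined by Wang and Isola (2020), where f is the L2-normalized encoder, p_pos is the distribution of positive pairs and p_data is the data distribution. Their Fig 5 code uses alpha = 2 and t = 2, which matches the `_lalign`/`_lunif` defaults in the code below. Lower is better for both.

```latex
\begin{aligned}
\ell_{\mathrm{align}} &= \mathbb{E}_{(x, y) \sim p_{\mathrm{pos}}} \left[ \lVert f(x) - f(y) \rVert_2^{\alpha} \right], \\
\ell_{\mathrm{uniform}} &= \log \mathbb{E}_{x, y \,\overset{\mathrm{i.i.d.}}{\sim}\, p_{\mathrm{data}}} \left[ e^{-t \lVert f(x) - f(y) \rVert_2^{2}} \right].
\end{aligned}
```

In the code below, p_pos is approximated by STS-B pairs with gold score > 4 and p_data by all sentences, following footnote 3 of the SimCSE paper.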

To compute alignment and uniformity, I modified lines 66-79 of SimCSE/blob/main/SentEval/senteval/sts.py by adding the code from Wang and Isola:

            # note: the code below also needs `import torch` at the top of sts.py
            ...
            input1, input2, gs_scores = self.data[dataset]
            all_enc1 = []
            all_enc2 = []
            for ii in range(0, len(gs_scores), params.batch_size):
                batch1 = input1[ii:ii + params.batch_size]
                batch2 = input2[ii:ii + params.batch_size]

                # we assume get_batch already throws out the faulty ones
                if len(batch1) == len(batch2) and len(batch1) > 0:
                    enc1 = batcher(params, batch1)
                    enc2 = batcher(params, batch2)

                    all_enc1.append(enc1.detach())
                    all_enc2.append(enc2.detach())
                    ...

            def _norm(x, eps=1e-8):
                xnorm = torch.linalg.norm(x, dim=-1)
                xnorm = torch.max(xnorm, torch.ones_like(xnorm) * eps)
                return x / xnorm.unsqueeze(dim=-1)

            # from Wang and Isola (with a bit of modification)
            # only consider pairs with gs > 4 (from footnote 3)
            def _lalign(x, y, ok, alpha=2):
                return ((_norm(x) - _norm(y)).norm(dim=1).pow(alpha) * ok).sum() / ok.sum()

            def _lunif(x, t=2):
                sq_pdist = torch.pdist(_norm(x), p=2).pow(2)
                return sq_pdist.mul(-t).exp().mean().log()

            ok = (torch.Tensor(gs_scores) > 4).int()
            align = _lalign(
                torch.cat(all_enc1), 
                torch.cat(all_enc2), 
                ok).item()

            # consider all sentences (from footnote 3)
            unif = _lunif(torch.cat(all_enc1 + all_enc2)).item()
            logging.info(f'align {align}\t\t uniform {unif}')
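
To sanity-check the torch code above on tiny inputs, here is a minimal dependency-free sketch of the same two metrics using plain Python lists instead of tensors. The names `normalize`, `sq_dist`, `lalign` and `lunif` are hypothetical and not from SentEval or Wang and Isola's release. It mirrors the snippet's choices: inputs are L2-normalized, alignment is averaged only over pairs whose mask is 1, and uniformity is averaged over all unordered pairs, as `torch.pdist` does.

```python
import math
from itertools import combinations

# Hypothetical plain-Python mirror of _norm/_lalign/_lunif above,
# intended only for cross-checking the torch version on small inputs.

def normalize(v, eps=1e-8):
    """Scale v to unit L2 norm, clamping the norm at eps like _norm."""
    n = max(math.sqrt(sum(a * a for a in v)), eps)
    return [a / n for a in v]

def sq_dist(u, v):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((a - b) ** 2 for a, b in zip(u, v))

def lalign(xs, ys, ok, alpha=2):
    """Masked alignment: mean of ||x - y||^alpha over normalized pairs with ok == 1."""
    total = sum(math.sqrt(sq_dist(normalize(x), normalize(y))) ** alpha * k
                for x, y, k in zip(xs, ys, ok))
    return total / sum(ok)

def lunif(xs, t=2):
    """Uniformity: log mean exp(-t * ||x_i - x_j||^2) over pairs i < j (like torch.pdist).

    Needs at least two vectors.
    """
    zs = [normalize(x) for x in xs]
    vals = [math.exp(-t * sq_dist(a, b)) for a, b in combinations(zs, 2)]
    return math.log(sum(vals) / len(vals))
```

For example, a pair of orthogonal unit vectors contributes ||x - y||^2 = 2 to alignment, identical points give uniformity log(1) = 0, and two antipodal points give uniformity log(e^(-2*4)) = -8.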

The output (which also shows the Spearman correlation on the STS-B dev set) is:

align 0.2672557830810547 uniform -2.5320491790771484 'eval_stsb_spearman': 0.6410360622426501, 'epoch': 0.01
align 0.2519586384296417 uniform -2.629746913909912 'eval_stsb_spearman': 0.6859433315879646, 'epoch': 0.02
align 0.2449202835559845 uniform -2.5870673656463623 'eval_stsb_spearman': 0.7198291431689111, 'epoch': 0.02
align 0.22248655557632446 uniform -2.557053565979004 'eval_stsb_spearman': 0.7538674335025006, 'epoch': 0.03
align 0.22624073922634125 uniform -2.6622540950775146 'eval_stsb_spearman': 0.7739112284380941, 'epoch': 0.04
align 0.22583454847335815 uniform -2.5768041610717773 'eval_stsb_spearman': 0.7459814500897265, 'epoch': 0.05
align 0.22845414280891418 uniform -2.5601420402526855 'eval_stsb_spearman': 0.7683573046863201, 'epoch': 0.06
align 0.22689573466777802 uniform -2.560364007949829 'eval_stsb_spearman': 0.7766837072148098, 'epoch': 0.06
align 0.22807720303535461 uniform -2.5539987087249756 'eval_stsb_spearman': 0.7692866256106997, 'epoch': 0.07
align 0.20026598870754242 uniform -2.50628399848938 'eval_stsb_spearman': 0.7939010002048291, 'epoch': 0.08
align 0.20466476678848267 uniform -2.535121440887451 'eval_stsb_spearman': 0.8011027122797894, 'epoch': 0.09
align 0.2030458152294159 uniform -2.5547776222229004 'eval_stsb_spearman': 0.8044623693996088, 'epoch': 0.1
align 0.20119303464889526 uniform -2.5325350761413574 'eval_stsb_spearman': 0.8070404405714893, 'epoch': 0.1
align 0.19329915940761566 uniform -2.488903522491455 'eval_stsb_spearman': 0.8220311448535872, 'epoch': 0.11
align 0.19556573033332825 uniform -2.5273373126983643 'eval_stsb_spearman': 0.8183500898254208, 'epoch': 0.12
align 0.19112755358219147 uniform -2.4959402084350586 'eval_stsb_spearman': 0.8146496522216178, 'epoch': 0.13
align 0.18491695821285248 uniform -2.4762508869171143 'eval_stsb_spearman': 0.8088527080054781, 'epoch': 0.14
align 0.19815796613693237 uniform -2.5905373096466064 'eval_stsb_spearman': 0.8333401056438776, 'epoch': 0.14
align 0.1950838416814804 uniform -2.4894299507141113 'eval_stsb_spearman': 0.8293951990138778, 'epoch': 0.15
align 0.19777807593345642 uniform -2.5985066890716553 'eval_stsb_spearman': 0.8268435050866446, 'epoch': 0.16
align 0.2016373723745346 uniform -2.616013765335083 'eval_stsb_spearman': 0.8199602019842832, 'epoch': 0.17
align 0.19906719028949738 uniform -2.57528018951416 'eval_stsb_spearman': 0.8094202934650283, 'epoch': 0.18
align 0.18731220066547394 uniform -2.517271041870117 'eval_stsb_spearman': 0.8231122818777513, 'epoch': 0.18
align 0.18802008032798767 uniform -2.508246421813965 'eval_stsb_spearman': 0.8248523275594679, 'epoch': 0.19
align 0.20015984773635864 uniform -2.4563515186309814 'eval_stsb_spearman': 0.8061084765791668, 'epoch': 0.2
align 0.2015877515077591 uniform -2.5121841430664062 'eval_stsb_spearman': 0.8113328705761889, 'epoch': 0.21
align 0.20187602937221527 uniform -2.5167288780212402 'eval_stsb_spearman': 0.8124173161634701, 'epoch': 0.22
align 0.20096932351589203 uniform -2.5201926231384277 'eval_stsb_spearman': 0.8127754107163266, 'epoch': 0.22
align 0.19966433942317963 uniform -2.5182201862335205 'eval_stsb_spearman': 0.8152261579570365, 'epoch': 0.23
align 0.19897222518920898 uniform -2.557129383087158 'eval_stsb_spearman': 0.8169452712415308, 'epoch': 0.24
...

We can see that alignment drops from about 0.27 to below 0.20, whereas uniformity stays around -2.55. This suggests that reducing alignment, not uniformity, is the key factor, which is completely different from the trend in Fig 2.
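
To quantify the trend instead of eyeballing it, a small sketch that parses log lines in the format above and reports the first-to-last change of each metric. Only the first and last log lines are copied into `LOG` here, and the names `parse` and `first_last_change` are illustrative. Note that the two metrics are on different scales (uniformity is the log of a mean), so their raw changes are not directly comparable.

```python
import re

# Paste the full log here; only the first and last lines are copied below.
LOG = """\
align 0.2672557830810547 uniform -2.5320491790771484 'eval_stsb_spearman': 0.6410360622426501, 'epoch': 0.01
align 0.19897222518920898 uniform -2.557129383087158 'eval_stsb_spearman': 0.8169452712415308, 'epoch': 0.24
"""

# One signed decimal number (optionally in scientific notation) per capture group.
NUM = r"(-?\d+(?:\.\d+)?(?:[eE]-?\d+)?)"
ROW = re.compile(
    rf"align\s+{NUM}\s+uniform\s+{NUM}\s+'eval_stsb_spearman':\s+{NUM},\s+'epoch':\s+{NUM}"
)

def parse(log):
    """Return (align, uniform, spearman, epoch) float tuples, in log order."""
    return [tuple(float(g) for g in m.groups()) for m in ROW.finditer(log)]

def first_last_change(rows):
    """Change in alignment and uniformity between the first and last evaluations."""
    (a0, u0, *_), (a1, u1, *_) = rows[0], rows[-1]
    return a1 - a0, u1 - u0

rows = parse(LOG)
d_align, d_unif = first_last_change(rows)
print(f"alignment change {d_align:+.3f}, uniformity change {d_unif:+.3f}")
```

With these two lines, alignment falls by about 0.068 and uniformity by about 0.025.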

Did you also use the code from Wang and Isola like I did? If possible, could you please provide the code for reproducing alignment and uniformity?

gaotianyu1350 commented 2 years ago

Uniformity drops very fast right from the beginning of training. Can you specify your initialization and the stride (how often, in steps) at which you calculate uniformity?

lephong commented 2 years ago

I didn't change anything except adding the lines above to calculate alignment and uniformity. Specifically, I used the command from run_unsup_example.sh:

python train.py \
    --model_name_or_path bert-base-uncased \
    --train_file data/wiki1m_for_simcse.txt \
    --output_dir result/my-unsup-simcse-bert-base-uncased \
    --num_train_epochs 1 \
    --per_device_train_batch_size 64 \
    --learning_rate 3e-5 \
    --max_seq_length 32 \
    --evaluation_strategy steps \
    --metric_for_best_model stsb_spearman \
    --load_best_model_at_end \
    --eval_steps 125 \
    --pooler_type cls \
    --mlp_only_train \
    --overwrite_output_dir \
    --temp 0.05 \
    --do_train \
    --do_eval \
    --fp16
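
As a consistency check on the evaluation stride, a short sketch of the step arithmetic for the command above. Under the assumptions listed in the comments (they are mine, not stated in the thread), evaluating every `--eval_steps 125` steps reproduces the `'epoch'` column of the log (0.01, 0.02, 0.02, 0.03, ...), so consecutive log rows are 125 optimizer steps apart.

```python
import math

# Assumptions (not stated in the thread): data/wiki1m_for_simcse.txt holds
# 1,000,000 sentences, training runs on one GPU, and there is no gradient
# accumulation, so one optimizer step consumes one batch of 64 sentences.
NUM_SENTENCES = 1_000_000
BATCH_SIZE = 64   # --per_device_train_batch_size
EVAL_STEPS = 125  # --eval_steps

steps_per_epoch = math.ceil(NUM_SENTENCES / BATCH_SIZE)  # 15625

def epoch_at(step):
    """Fractional epoch after `step` optimizer steps, rounded to 2 decimals like the log."""
    return round(step / steps_per_epoch, 2)

# Epoch value at each of the first eight 125-step evaluations.
eval_epochs = [epoch_at(k * EVAL_STEPS) for k in range(1, 9)]
print(steps_per_epoch, eval_epochs)
```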

For initialisation, I didn't change the random seed, so I guess it is 42, the Hugging Face default (not sure, I may be wrong).

gaotianyu1350 commented 2 years ago

If I understand correctly, you calculate alignment and uniformity every 125 steps (the same interval as validation). In the paper, we calculate them every 10 steps because, as I mentioned, uniformity drops very fast at the beginning of training.

lephong commented 2 years ago

Ah, so you mean every 10 update steps (i.e., batches)? I thought it was every 10 * 125 batches.

But even if that's the case, I'm not sure Figure 2 is a good explanation here. After 125 steps (i.e., 12 little red stars in Figure 2), the STS-B dev Spearman correlation is only about 64 (0.641 in the log above), much lower than the 82.5 reported in the paper. So I think Figure 2 can explain what happens in the very first phase of training, but the remaining gap of 82.5 - 64 = 18.5 points is not explained.
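
The 125-steps-versus-12-points arithmetic can be made explicit. A minimal sketch, assuming 1,000,000 training sentences, batch size 64 and a single GPU (assumptions, not confirmed in the thread): at a 10-step stride, a logger fires 12 times (steps 10 through 120) before the first 125-step evaluation, and that whole window is 125 / 15625 = 0.008 of an epoch.

```python
# Assumptions (not stated in the thread): 1,000,000 training sentences,
# batch size 64, one GPU, so one epoch is 15,625 optimizer steps.
STEPS_PER_EPOCH = 1_000_000 // 64  # 15625
FIG2_STRIDE = 10   # alignment/uniformity logged every 10 steps, per the maintainer's comment
FIRST_EVAL = 125   # first STS-B dev evaluation with --eval_steps 125

# Steps at which a stride-10 logger fires strictly before the first evaluation.
fig2_steps = list(range(FIG2_STRIDE, FIRST_EVAL, FIG2_STRIDE))
window = FIRST_EVAL / STEPS_PER_EPOCH  # fraction of an epoch covered by that window

print(len(fig2_steps), fig2_steps[0], fig2_steps[-1], window)
```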

gaotianyu1350 commented 2 years ago

You can use Figure 3 as a reference (although it's not a rigorous comparison, because we didn't include the [CLS] BERT representation, which is SimCSE's initialization, in that figure); there, it is uniformity that makes the huge difference.
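
A toy illustration of why uniformity can make such a difference (2-D unit vectors standing in for sentence embeddings; this is not SimCSE data, and the helpers `lunif` and `on_circle` are hypothetical): points squeezed into a narrow arc get a uniformity loss close to 0, while the same number of points spread evenly around the circle get a much lower, i.e. better, value. Embeddings that occupy a narrow region of the space, often discussed as anisotropy of pre-trained representations, show up as poor uniformity in exactly this way.

```python
import math
from itertools import combinations

def lunif(points, t=2):
    """Wang and Isola's uniformity loss: log mean exp(-t * ||u - v||^2) over pairs.

    Points are assumed to be unit vectors already.
    """
    vals = [math.exp(-t * sum((a - b) ** 2 for a, b in zip(u, v)))
            for u, v in combinations(points, 2)]
    return math.log(sum(vals) / len(vals))

def on_circle(n, max_angle):
    """n unit vectors in 2-D with angles evenly spaced in [0, max_angle)."""
    return [(math.cos(max_angle * i / n), math.sin(max_angle * i / n)) for i in range(n)]

spread = on_circle(8, 2 * math.pi)  # evenly covering the whole circle
narrow = on_circle(8, 0.3)          # all within a 0.3-radian (about 17 degree) arc
print(lunif(spread), lunif(narrow))
```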

lephong commented 2 years ago

That makes sense. Thanks!