texttron / tevatron

Tevatron - A flexible toolkit for neural retrieval research and development.
http://tevatron.ai
Apache License 2.0

Question about training procedure of retriever #140

Open dayuyang1999 opened 1 month ago

dayuyang1999 commented 1 month ago

Dear authors,

I was trying to reproduce repllama/repmistral and to understand the logic of the code.

1. Why use cross-entropy loss in contrastive learning?

In /usa/dayu/Table_similarity/tevatron/src/tevatron/retriever/modeling/encoder.py, I have questions about the forward function.

Here the cross-entropy loss function is used. As I understand it, cross-entropy focuses on the probability assigned to the true class of each sample; it doesn't seem to directly account for the probabilities assigned to the incorrect (negative) classes in the loss calculation.

However, based on some materials I have read, training a retriever should use a contrastive loss, which pushes negatives away and pulls positives closer to the anchor.

For example, the triplet loss explicitly includes the (negated) anchor-negative distance as a penalty term, while cross-entropy seems to ignore this part.

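If I write out what that cross-entropy over the score matrix computes (my own derivation, so please correct me if I am wrong), it looks like the InfoNCE objective, where the negatives do appear in the softmax denominator ($s$ is the similarity, $\tau$ the temperature, $\mathcal{N}$ the negatives for query $q$):

$$
\mathcal{L}(q) = -\log \frac{\exp\big(s(q, p^{+})/\tau\big)}{\exp\big(s(q, p^{+})/\tau\big) + \sum_{p^{-} \in \mathcal{N}} \exp\big(s(q, p^{-})/\tau\big)}
$$

So my confusion is whether this denominator is considered enough of a "push away" on the negatives, compared to the explicit term in the triplet loss.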

2. How are "in-batch" negatives considered? What if we have >1 positive?

Based on the code I saw in the forward function:

            target = torch.arange(scores.size(0), device=scores.device, dtype=torch.long)
            target = target * (p_reps.size(0) // q_reps.size(0))

and the code under /usa/dayu/Table_similarity/tevatron/src/tevatron/retriever/dataset.py:

        query = group['query']
        group_positives = group['positive_passages']
        group_negatives = group['negative_passages']

        formated_query = format_query(query, self.data_args.query_prefix)
        formated_passages = []

        if self.data_args.positive_passage_no_shuffle:
            pos_psg = group_positives[0]
        else:
            pos_psg = group_positives[(_hashed_seed + epoch) % len(group_positives)]

        formated_passages.append(format_passage(pos_psg['text'], pos_psg['title'], self.data_args.passage_prefix))

       ....

        for neg_psg in negs:
            formated_passages.append(format_passage(neg_psg['text'], neg_psg['title'], self.data_args.passage_prefix))

        return formated_query, formated_passages

If I understand correctly, each query has one and only one positive passage, plus negative_size negatives.

Assuming negative_size = 2, the positive indices will be [0, 3, 6, 9, ...], and all other columns are negatives.
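A small standalone sketch of that indexing (made-up shapes and values, not the actual Tevatron code):

    import torch

    n_queries = 4
    negative_size = 2
    n_passages = n_queries * (1 + negative_size)     # 12 passages in the batch

    scores = torch.randn(n_queries, n_passages)      # stand-in for the (N_q, N_p) similarity matrix

    # same arithmetic as in encoder.py's forward
    target = torch.arange(scores.size(0), dtype=torch.long)
    target = target * (n_passages // n_queries)
    print(target)  # tensor([0, 3, 6, 9]) -> the column holding each query's own positive

    # Every other column in a query's row (the other queries' positives and all the
    # negatives) still appears in the softmax denominator of the cross-entropy.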

If that is right, where do the in-batch negatives come in? And what happens if we have more than one positive passage?

Thanks!

dayuyang1999 commented 1 month ago

Additionally, what is the intuition behind using temperature?

            loss = self.compute_loss(scores / self.temperature, target)

I saw some discussion under https://github.com/FlagOpen/FlagEmbedding/issues/402

The standard cosine similarity ranges from -1 to 1. A higher temperature coefficient will make the model's score range closer to the standard range.

When I use bge-en-v1.5, I found the cosine similarity scores typically range between 0.4 and 1 (so two random documents will have a similarity around 0.4, which is kind of counter-intuitive).

So why do people use temperature to "reshape" the similarity distribution? Why not keep it in the range -1 to 1, which seems more intuitive and clearer for revealing the negative/positive relationship between documents?
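For concreteness, a toy example of what dividing by the temperature does to the softmax (made-up cosine similarities; 0.02 is just an example value, not the repo's default):

    import torch

    cos_sims = torch.tensor([0.9, 0.5, 0.45, 0.4])   # positive first, then negatives

    print(torch.softmax(cos_sims, dim=0))            # ~[0.34, 0.23, 0.22, 0.21] -> almost flat
    print(torch.softmax(cos_sims / 0.02, dim=0))     # ~[1.0, 0.0, 0.0, 0.0] -> small cosine gaps become decisive

Since cosine scores live in such a narrow band, the untempered softmax is nearly uniform, which is presumably why the scores are rescaled before the loss. I am still unsure why this is preferred over keeping the raw -1 to 1 range.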

dayuyang1999 commented 1 month ago

Another question about the implementation of the InfoNCE loss.

It seems different in the way the target is constructed, compared with the original implementation in the MoCo paper.

            scores = self.compute_similarity(q_reps, p_reps)
            scores = scores.view(q_reps.size(0), -1)  # (N_q_reps, N_p_reps) = (N, C)

            target = torch.arange(scores.size(0), device=scores.device, dtype=torch.long)  # (0, 1, 2, ..., N_q_reps - 1)
            # N_samples = 1 (positive) + N_negatives per query
            target = target * (p_reps.size(0) // q_reps.size(0))  # (0, N_samples, 2*N_samples, ..., (N_q_reps - 1) * N_samples), shape (N,)

            loss = self.compute_loss(scores / self.temperature, target)
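For comparison, my understanding of the MoCo-style construction (a sketch from memory of the paper's pseudo-code, so the shapes and names here are mine): the positive logit is concatenated as column 0, so the target is all zeros, whereas here the positives sit at strided columns of the full in-batch score matrix.

    import torch

    # MoCo-style (as I understand the paper): logits = [positive | queue negatives]
    l_pos = torch.randn(4, 1)       # one positive logit per query
    l_neg = torch.randn(4, 128)     # negatives drawn from the memory queue
    logits = torch.cat([l_pos, l_neg], dim=1)
    labels = torch.zeros(logits.size(0), dtype=torch.long)   # positive is always column 0

    # Tevatron-style: full (N_q, N_p) score matrix, positives at strided columns
    scores = torch.randn(4, 12)               # 4 queries x (4 groups of 1 positive + 2 negatives)
    target = torch.arange(4) * (12 // 4)      # tensor([0, 3, 6, 9])

Is it correct that these two are equivalent up to a permutation of the columns, or is there a difference I am missing?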