LinWeizheDragon / FLMR

The huggingface implementation of Fine-grained Late-interaction Multi-modal Retriever.

How to calculate the similarity score between a query and a gallery item without indexing #15

Closed: duyc168 closed this issue 3 months ago

duyc168 commented 4 months ago

Hi, I want to compare the retrieval results with and without indexing. How do I calculate the similarity score between a query and a gallery item without indexing?

LinWeizheDragon commented 4 months ago

Hi,

You can refer to the forward function.

### Use the query function to get the query embeddings
# This excerpt is from FLMR's forward(); from your own code, call model.query / model.doc / model.score instead of self.*
query_outputs = self.query(
    input_ids=query_input_ids,
    attention_mask=query_attention_mask,
    pixel_values=query_pixel_values,
    image_features=query_image_features,
    concat_output_from_vision_encoder=query_concat_output_from_vision_encoder,
    concat_output_from_text_encoder=query_concat_output_from_text_encoder,
    output_attentions=output_attentions,
    output_hidden_states=output_hidden_states,
)
Q = query_outputs.late_interaction_output

### Use the doc function to get the context (document) embeddings
context_outputs = self.doc(
    input_ids=context_input_ids,
    attention_mask=context_attention_mask,
    pixel_values=context_pixel_values,
    image_features=context_image_features,
    concat_output_from_vision_encoder=context_concat_output_from_vision_encoder,
    concat_output_from_text_encoder=context_concat_output_from_text_encoder,
    keep_dims=True,
    return_mask=True,
    output_attentions=output_attentions,
    output_hidden_states=output_hidden_states,
)
D, D_mask = context_outputs.late_interaction_output, context_outputs.context_mask

### The code above gives you the late-interaction embeddings of all queries and all contexts:
### Q has shape (num_query, query_len, hidden_size), D has shape (num_doc, doc_len, hidden_size),
### and D_mask has shape (num_doc, doc_len, 1).

### Call the scoring function of the model
# num_docs_in_this_batch is the number of documents paired with each query; D must hold
# num_query * num_docs_in_this_batch documents, grouped query by query.
# Repeat each query encoding once for every corresponding document.
Q_duplicated = Q.repeat_interleave(num_docs_in_this_batch, dim=0).contiguous()

scores = self.score(Q_duplicated, D, D_mask)

# Reshape to one row of document scores per query: (num_query, num_docs_in_this_batch).
# In forward(), these scores then feed the contrastive-learning loss.
scores = scores.view(-1, num_docs_in_this_batch)
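
The thread does not show the body of `self.score`. Assuming it follows the ColBERT-style MaxSim late-interaction rule (each query token takes its best-matching unpadded document token, and these maxima are summed), a standalone sketch of the computation might look like this. `late_interaction_score` is a hypothetical name, not part of the FLMR API; because `torch.matmul` broadcasts batch dimensions, a single query of shape (1, query_len, dim) can also be scored against many documents without `repeat_interleave`.

```python
import torch


def late_interaction_score(Q: torch.Tensor, D: torch.Tensor, D_mask: torch.Tensor) -> torch.Tensor:
    """Sketch of ColBERT-style MaxSim scoring (assumed, not confirmed, to match FLMR's self.score).

    Q:      (n or 1, query_len, dim) query token embeddings
    D:      (n, doc_len, dim)        padded document token embeddings
    D_mask: (n, doc_len, 1)          1 for real document tokens, 0 for padding
    Returns a tensor of shape (n,) with one score per (query, document) pair.
    """
    # Token-level similarities: (n, doc_len, query_len)
    sim = D @ Q.to(D.dtype).transpose(1, 2)
    # Padded document tokens must never win the max
    sim = sim.masked_fill(~D_mask.bool(), float("-inf"))
    # Best document token per query token, then sum over query tokens
    return sim.max(dim=1).values.sum(dim=-1)
```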

Note that you may need to split the queries and documents into mini-batches and compute the scores chunk by chunk to limit memory use. Exhaustively computing item-wise similarity scores can be very computationally expensive, so do this only if your research requires it.
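Chunked exhaustive scoring can be sketched as follows, assuming you have already encoded every query into `Q` and every gallery item into `D` and `D_mask` with the shapes given in the snippet above. `maxsim` and `exhaustive_scores` are hypothetical helpers, and `maxsim` assumes FLMR's score is ColBERT-style MaxSim. Each query is broadcast against a chunk of documents, so the result is a full (num_query, num_doc) score matrix without an index.

```python
import torch


def maxsim(Q, D, D_mask):
    # Assumed ColBERT-style late interaction (may differ in detail from FLMR's self.score).
    # Q: (1 or n, q_len, dim), D: (n, d_len, dim), D_mask: (n, d_len, 1) -> scores of shape (n,)
    sim = D @ Q.to(D.dtype).transpose(1, 2)
    sim = sim.masked_fill(~D_mask.bool(), float("-inf"))
    return sim.max(dim=1).values.sum(dim=-1)


@torch.no_grad()
def exhaustive_scores(Q, D, D_mask, doc_chunk_size=1024):
    """Score every query against every document, without an index.

    Q: (num_query, query_len, dim); D: (num_doc, doc_len, dim); D_mask: (num_doc, doc_len, 1).
    Returns a (num_query, num_doc) float32 score matrix. Documents are processed in
    chunks of doc_chunk_size to bound peak memory.
    """
    all_scores = torch.empty(Q.size(0), D.size(0), dtype=torch.float32, device=D.device)
    for start in range(0, D.size(0), doc_chunk_size):
        D_chunk = D[start:start + doc_chunk_size]
        mask_chunk = D_mask[start:start + doc_chunk_size]
        end = start + D_chunk.size(0)
        for qi in range(Q.size(0)):
            # A single query (batch size 1) broadcasts against every document in the chunk
            all_scores[qi, start:end] = maxsim(Q[qi:qi + 1], D_chunk, mask_chunk)
    return all_scores
```

Looping over queries keeps peak memory at roughly one (doc_chunk_size, doc_len, query_len) similarity tensor; scoring several queries at once would be faster but needs proportionally more memory. The result matches the `repeat_interleave` approach above when `D` is tiled once per query.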

Also note that when using indexing, setting num_document_to_retrieve=500 should let you approach the best performance: https://github.com/LinWeizheDragon/FLMR/blob/3e8d14a3c33d46c8c7c295f2d0cce56a3fcccc70/flmr/searching.py#L41
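To compare the two settings, one option is to rank the gallery by the exhaustive scores with `torch.topk` and measure how many of those top-k items the indexed search also returned. This is a minimal sketch: `topk_overlap` is a hypothetical helper, and it assumes the indexed results have been converted to one list of document indices per query, in the same index space as the columns of the score matrix.

```python
import torch


def topk_overlap(exhaustive_scores, indexed_doc_ids, k=5):
    """Average fraction of the exhaustive top-k documents that the indexed search also ranked in its top k.

    exhaustive_scores: (num_query, num_doc) tensor of late-interaction scores
    indexed_doc_ids:   one list of retrieved document indices per query, best first
                       (hypothetical format; adapt it to your search output)
    """
    k = min(k, exhaustive_scores.size(1))
    exact_top = torch.topk(exhaustive_scores, k=k, dim=1).indices.tolist()
    overlaps = [
        len(set(exact) & set(indexed[:k])) / k
        for exact, indexed in zip(exact_top, indexed_doc_ids)
    ]
    return sum(overlaps) / len(overlaps)
```

An overlap below 1.0 means the indexed search missed, or ranked below position k, some documents that exhaustive scoring places in its top k.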

LinWeizheDragon commented 4 months ago

Hi, just to let you know that a finetuning script is now available at https://github.com/LinWeizheDragon/FLMR?tab=readme-ov-file#new-finetune-the-preflmr-model-on-downstream-datasets