Closed duyc168 closed 5 months ago
Hi,
You can refer to the `forward` function.
```python
### Use the query function to get query embeddings
query_outputs = self.query(
    input_ids=query_input_ids,
    attention_mask=query_attention_mask,
    pixel_values=query_pixel_values,
    image_features=query_image_features,
    concat_output_from_vision_encoder=query_concat_output_from_vision_encoder,
    concat_output_from_text_encoder=query_concat_output_from_text_encoder,
    output_attentions=output_attentions,
    output_hidden_states=output_hidden_states,
)
Q = query_outputs.late_interaction_output

### Use the doc function to get context embeddings
context_outputs = self.doc(
    input_ids=context_input_ids,
    attention_mask=context_attention_mask,
    pixel_values=context_pixel_values,
    image_features=context_image_features,
    concat_output_from_vision_encoder=context_concat_output_from_vision_encoder,
    concat_output_from_text_encoder=context_concat_output_from_text_encoder,
    keep_dims=True,
    return_mask=True,
    output_attentions=output_attentions,
    output_hidden_states=output_hidden_states,
)
D, D_mask = context_outputs.late_interaction_output, context_outputs.context_mask

### From the above code you obtain all query and all context late-interaction
### embeddings: Q has shape (num_query, query_len, hidden_size), D has shape
### (num_doc, doc_len, hidden_size), and D_mask has shape (num_doc, doc_len, 1).

### Call the scoring function of the model
# Repeat each query encoding once for every document paired with it.
Q_duplicated = Q.repeat_interleave(num_docs_in_this_batch, dim=0).contiguous()
scores = self.score(Q_duplicated, D, D_mask)

# Use contrastive learning
batch_size = query_input_ids.shape[0]
scores = scores.view(-1, num_docs_in_this_batch)
```
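To illustrate the contrastive-learning step, here is a minimal NumPy sketch (not the repository's actual training code) of the in-batch cross-entropy loss that the reshaped `scores` would typically feed into, assuming the positive document for each query sits at column 0 of its score row:

```python
import numpy as np

def contrastive_loss(scores, positive_idx=0):
    """In-batch contrastive (cross-entropy) loss over per-query document scores.

    scores: (num_queries, num_docs_per_query) similarity scores, where the
    positive document for each query is at column `positive_idx` (an assumption
    for this sketch; match it to how your batch is constructed).
    """
    # Log-softmax over the document axis, shifted for numerical stability.
    shifted = scores - scores.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Negative log-likelihood of the positive document, averaged over queries.
    return -log_probs[:, positive_idx].mean()
```

In PyTorch this corresponds to `torch.nn.functional.cross_entropy(scores, labels)` with `labels` all zeros when the positive is the first document of each group.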
Note that you may need to split all queries and all documents into mini-batches and compute the scores chunk by chunk to reduce memory use: computing pairwise similarity scores exhaustively, if that is necessary for your research, can be very computationally expensive.
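The chunking idea can be sketched as follows. This is a self-contained NumPy illustration of ColBERT-style MaxSim scoring computed in chunks, not the repository's `self.score` implementation; shapes follow the Q/D/D_mask conventions noted above:

```python
import numpy as np

def maxsim_score(Q, D, D_mask):
    """Late-interaction score: for each query token, take the max similarity
    over unmasked document tokens, then sum over query tokens.

    Q: (num_pairs, query_len, hidden), D: (num_pairs, doc_len, hidden),
    D_mask: (num_pairs, doc_len, 1), 1 for real tokens and 0 for padding.
    """
    sim = np.einsum('bqh,bdh->bqd', Q, D)               # token-level similarities
    sim = np.where(D_mask.transpose(0, 2, 1) > 0, sim, -np.inf)  # mask padding
    return sim.max(axis=2).sum(axis=1)                  # MaxSim, then sum

def score_in_chunks(Q, D, D_mask, chunk_size=128):
    """Score query/document pairs chunk by chunk to bound peak memory."""
    out = []
    for start in range(0, Q.shape[0], chunk_size):
        end = start + chunk_size
        out.append(maxsim_score(Q[start:end], D[start:end], D_mask[start:end]))
    return np.concatenate(out)
```

Because each pair is scored independently, the chunked result is identical to scoring everything at once; only the peak memory of the `(chunk, query_len, doc_len)` similarity tensor changes.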
Another note is that by setting `num_document_to_retrieve=500`, you will be able to approach the best performance when using indexing.
https://github.com/LinWeizheDragon/FLMR/blob/3e8d14a3c33d46c8c7c295f2d0cce56a3fcccc70/flmr/searching.py#L41
Hi, just to let you know that a finetuning script is now available at https://github.com/LinWeizheDragon/FLMR?tab=readme-ov-file#new-finetune-the-preflmr-model-on-downstream-datasets
Hi, I want to compare the retrieval results with and without indexing. How do I calculate the similarity score between a query and a gallery item without indexing?