RaviSoji / plda

Probabilistic Linear Discriminant Analysis & classification, written in Python.
https://ravisoji.com
Apache License 2.0
128 stars 31 forks

Return best n categories when predicting #60

Open sadrasabouri opened 2 years ago

sadrasabouri commented 2 years ago

Is your feature request related to a problem? Please describe.

In my case (using PLDA for information retrieval), it would be better to predict the best n options for a given query rather than only the single best one. As far as I can tell, the predict method does not support this, but it can be done using the calc_logp_pp_categories method.

Describe the solution you'd like

My quick solution was to use the code below:

def predict_doc_at(query, k=1):
    """
    Predict which documents best match the given query.

    :param query: input query
    :type query: str (or list of strs)
    :param k: number of documents to return
    :type k: int
    :return: the k best-matching document names, best first
    """
    # get_embeddings and PLDA_classifier are assumed to be defined elsewhere.
    query_embedding = get_embeddings(query)
    data = PLDA_classifier.model.transform(query_embedding,
                                           from_space='D',
                                           to_space='U_model')
    # False: do not normalize the log probabilities.
    logpps_k, K = PLDA_classifier.calc_logp_pp_categories(data,
                                                          False)
    # Sort along the last axis (assumed to index categories), highest
    # log probability first, and keep the top k.
    best_k_idx = logpps_k.argsort()[..., ::-1][..., :k]
    predictions = K[best_k_idx]
    return predictions

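The top-k selection step itself can be sketched with plain NumPy, independent of plda. This is only a minimal sketch: it assumes the log-probability array has one row per query and one column per category, and top_k_categories is a hypothetical helper name, not part of the plda API.

```python
import numpy as np


def top_k_categories(logpps, categories, k=1):
    """Return the k highest-scoring category labels per query, best first.

    Assumption: logpps has shape (n_queries, n_categories), or
    (n_categories,) for a single query; categories has shape (n_categories,).
    """
    logpps = np.atleast_2d(np.asarray(logpps, dtype=float))
    categories = np.asarray(categories)
    k = min(k, logpps.shape[-1])  # never ask for more labels than exist

    # Negate so that argsort (ascending) yields descending log probability.
    order = np.argsort(-logpps, axis=-1)[:, :k]

    top_labels = categories[order]
    top_scores = np.take_along_axis(logpps, order, axis=-1)
    return top_labels, top_scores


if __name__ == "__main__":
    # Made-up scores for two queries over three categories.
    scores = np.array([[-3.0, -1.0, -2.0],
                       [-0.5, -4.0, -0.1]])
    labels, best = top_k_categories(scores, ["a", "b", "c"], k=2)
    print(labels)
    print(best)
```

Returning the scores alongside the labels makes it possible to threshold or re-rank the candidates afterwards.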
RaviSoji commented 2 years ago

Thank you for writing this solution!

Would you mind making a pull request with the updated code and updating the Jupyter notebook with an example? I am happy to incorporate this.

Thanks again! Ravi B. Sojitra

sadrasabouri commented 2 years ago

Sure, you can assign this issue to me and I'll take care of it ASAP.