hosseinfani / ReQue

A Benchmark Workflow and Dataset Collection for Query Refinement
https://hosseinfani.github.io/ReQue/

glove.py get_expanded_query() is slow, potential improvement #33

Open Woracle opened 2 years ago

Woracle commented 2 years ago

So I've been playing around with the different expanders and found that the glove get_expanded_query() method was quite slow.

I realised that line 33, which loops through the distance comparisons one vocabulary word at a time, was the main culprit slowing the method:

w = sorted(Glove.glove.keys(), key=lambda word: scipy.spatial.distance.euclidean(Glove.glove[word], Glove.glove[qw]))
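For context, this is roughly the pattern that line sits in (my reconstruction of the surrounding loop, not the exact repo code; Glove.glove is the word-to-vector dictionary):

import scipy.spatial

# Slow pattern: for every query word, sort the ENTIRE vocabulary by its
# euclidean distance to that word's embedding -- O(V log V) work per query word.
def expand_slow(glove, query, topn=5):
    expanded = []
    for qw in query:
        if qw not in glove:
            continue
        w = sorted(glove.keys(),
                   key=lambda word: scipy.spatial.distance.euclidean(glove[word], glove[qw]))
        expanded.extend(w[:topn])  # nearest neighbours; the query word itself sits at distance 0
    return ' '.join(expanded)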

Reviewing scipy.spatial.distance.cdist, I saw that we could run all the distance comparisons at the same time and improve speed. Combined with np.argpartition(), I was able to get the indices of the top-n words for each row of the cdist output and retrieve the words from the model by index (the model being the GloVe dictionary).
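To make the np.argpartition() part concrete, here is a tiny standalone demo (toy numbers, not repo data) showing that it returns the indices of the n smallest values in each row, without sorting inside that top n:

import numpy as np

dists = np.array([[0.0, 3.2, 1.1, 0.4, 2.5],
                  [2.0, 0.0, 0.9, 4.1, 1.3]])
topn = 3
# argpartition only guarantees that the topn smallest land in the first
# topn slots; their order within those slots is arbitrary.
idx = np.argpartition(dists, topn)[:, :topn]
print(idx)  # e.g. [[0 3 2] [1 2 4]] -- indices of the 3 smallest distances per row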

In my testing I got the same outputs in a fraction of the time on short queries: "a nice lemon pie" went from 11 seconds to 1.4 seconds with no change in output.
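If you want to reproduce the comparison, a rough timing harness along these lines works (random vectors standing in for GloVe; the vocabulary size and dimensionality are made up), calling the vectorised get_expanded_query from the snippet below:

import time
import numpy as np

rng = np.random.default_rng(0)
model = {f"word{i}": rng.standard_normal(50) for i in range(100_000)}  # stand-in vocab
model.update({w: rng.standard_normal(50) for w in "a nice lemon pie".split()})

t0 = time.perf_counter()
out = get_expanded_query(model, "a nice lemon pie".split())
print(f"{out!r} in {time.perf_counter() - t0:.2f}s")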

Below is a reduced snippet showing the basic approach. Anyway, I found this repo really helpful, so I thought I would share the suggestion for anyone else's benefit.

import numpy as np
import scipy.spatial

def get_expanded_query(model, query, topn=5, coef=0.7):  # coef kept to match the original signature; unused in this reduced snippet
    embs = []
    terms = []
    for qw in query:
        if qw.lower() in model:
            embs.append(model[qw.lower()])  # index with the lowercased key we just checked
        else:
            terms.append([qw, 1])  # out-of-vocabulary terms (not used in this reduced snippet)
    # Stack the query embeddings into one array, which allows us to use cdist
    # to compare all words against the whole vocabulary at once, much faster.
    embs = np.vstack(embs)
    nums = np.vstack(list(model.values()))

    model_list = list(model)  # lets us index into the model keys by position

    # cdist performs all pairwise comparisons, producing an array of shape
    # [query length (post cleaning), number of words in model].
    matrix = scipy.spatial.distance.cdist(embs, nums)

    # We now have a row of pairwise distances per query term. np.argpartition gives
    # the indices of the topn smallest distances per row (unordered within the topn);
    # we then look each index up in the model's key list.
    words = [model_list[idx] for idx in np.argpartition(matrix, topn)[:, :topn].flatten()]
    return ' '.join(words)
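As a quick sanity check, the snippet can be exercised on a toy dictionary (made-up 2-d vectors, not real GloVe embeddings):

import numpy as np

toy = {
    'lemon': np.array([1.0, 0.0]),
    'lime':  np.array([0.9, 0.1]),
    'pie':   np.array([0.0, 1.0]),
    'tart':  np.array([0.1, 0.9]),
    'car':   np.array([-1.0, -1.0]),
}
# Each query word pulls in its topn nearest vocabulary words (itself included),
# e.g. "lemon lime pie tart".
print(get_expanded_query(toy, ['lemon', 'pie'], topn=2))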
hosseinfani commented 2 years ago

Hi @Woracle Thank you for the speed improvement. Please send a PR, and I will merge it.