facebookresearch / StarSpace

Learning embeddings for classification, retrieval and ranking.
MIT License

Return nearestNeighbours #274

Closed TSPereira closed 4 years ago

TSPereira commented 5 years ago

Hi,

Is it possible to return the nearest neighbours in a variable instead of only printing them? This would be very useful, since one might want to do something with these neighbours (my situation), and getting this information from sys.stdout is kind of messy. I am using the Python wrappers, but the issue is in the starspace.cpp file.

Thanks

TSPereira commented 5 years ago

I made the necessary changes in my local files.

In starspace.cpp, the nearestNeighbor function now looks like this:

string StarSpace::nearestNeighbor(const string& line, int k) {
  // embed the query line and find the k closest entities
  auto vec = getDocVector(line, " ");
  auto preds = model_->findLHSLike(vec, k);
  // build one "<symbol> <score>\n" entry per neighbour instead of printing it
  string output_preds = "";
  for (auto n: preds) {
    output_preds += dict_->getSymbol(n.first) + ' ' + to_string(n.second) + "\n";
    // previous behaviour: cout << dict_->getSymbol(n.first) << ' ' << n.second << endl;
  }
  return output_preds;
}

I then also changed the function's return type in the starspace.h header to std::string:

std::string nearestNeighbor(const std::string& line, int k);

There are probably neater ways to build that output (I had never programmed in C++ before), but this is good enough for me, since I do the subsequent parsing in Python.

pmirla commented 5 years ago

Can you please share the updated Python bindings? I need to save the nearest neighbours or the classification results into a variable.

TSPereira commented 5 years ago

If you change the files I mentioned above accordingly and rebuild StarSpace and the Python wrapper, no changes to the bindings are needed. The output is then a string of the form "<label_tag><name of label> <score>\n<label_tag><name of label> <score>\n...", one neighbour per line.

You can then simply call x = model.nearestNeighbor(args).
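If you would rather not depend on pandas, the string can be turned into (symbol, score) pairs with plain Python. This is a minimal sketch, assuming the "<symbol> <score>\n" layout produced by the modified C++ function; parse_neighbours is a hypothetical helper name, and the sample string is made up for illustration.

```python
from typing import List, Tuple


def parse_neighbours(raw: str) -> List[Tuple[str, float]]:
    """Turn a "<symbol> <score>\\n..." string into (symbol, score) pairs.

    Hypothetical helper: assumes one neighbour per line, with the score
    as the last space-separated token.
    """
    pairs = []
    for line in raw.splitlines():
        line = line.strip()
        if not line:
            # the C++ side ends every entry with "\n", so skip empty lines
            continue
        # split on the last space only: the score is always the final token
        symbol, score = line.rsplit(" ", 1)
        pairs.append((symbol, float(score)))
    return pairs


# made-up output in the format described above
raw = "__label__sports 0.912345\n__label__news 0.501234\n"
print(parse_neighbours(raw))
```

Splitting on the last space keeps the parser independent of what the symbol looks like, and float() converts the six-decimal scores that to_string produces.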

I then parse the resulting string with split and pandas to get it into an organized DataFrame. Example:

import pandas as pd
res = model.nearestNeighbor(query, n)
# strip() drops the trailing "\n" so no empty row is produced
res = pd.Series(res.strip().split('\n')).str.split(' ', expand=True)

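Given (symbol, score) pairs, however they were parsed, you may also want to separate label entries from other entities and drop the <label_tag> prefix. This is a sketch, assuming StarSpace's default label prefix "__label__"; if you trained with a different prefix, pass it in. split_labels is a hypothetical helper, not part of the StarSpace bindings.

```python
from typing import List, Tuple

Pair = Tuple[str, float]


def split_labels(
    pairs: List[Pair], prefix: str = "__label__"
) -> Tuple[List[Pair], List[Pair]]:
    """Separate label entries (prefix stripped) from all other entries.

    Hypothetical helper: assumes labels are marked by `prefix`, which is
    assumed to be StarSpace's default "__label__".
    """
    labels: List[Pair] = []
    others: List[Pair] = []
    for symbol, score in pairs:
        if symbol.startswith(prefix):
            # keep only the label name, without the prefix
            labels.append((symbol[len(prefix):], score))
        else:
            others.append((symbol, score))
    return labels, others


labels, words = split_labels([("__label__sports", 0.91), ("football", 0.88)])
print(labels)  # [('sports', 0.91)]
print(words)   # [('football', 0.88)]
```

The input order (highest score first, as returned) is preserved in both output lists.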
TSPereira commented 4 years ago

This has become obsolete with the addition of predictTags to the Python bindings. Closing.