lmcinnes / pynndescent

A Python nearest neighbor descent for approximate nearest neighbors
BSD 2-Clause "Simplified" License

Importing pickled index gives error - 'NNDescent' object has no attribute 'shape' #197

Open regstuff opened 2 years ago

regstuff commented 2 years ago

Hi, I'm using the latest version installed via pip install, running on Google Colab with Python 3.7.13. Following the docs, I created an index with these parameters:

pynnindex = pynndescent.NNDescent(arr, metric="cosine", n_neighbors=100)

Everything works fine, and pynnindex.neighbor_graph returns results as expected.

Then I pickled the index like so:

with open('pynnindex','wb') as f:
  pickle.dump(pynnindex,f)

The trouble starts when I try to load the pickled index later on, like so:

with open('pynnindex','rb') as f:
  msgembed = pickle.load(f)

I get the following error:

AttributeError                            Traceback (most recent call last)
<ipython-input-7-44d19f2c98cc> in <module>
      3   msgembed = pickle.load(f)
      4 
----> 5 pynnindex_p = pynndescent.NNDescent(msgembed)

/usr/local/lib/python3.7/dist-packages/pynndescent/pynndescent_.py in __init__(self, data, metric, metric_kwds, n_neighbors, n_trees, leaf_size, pruning_degree_multiplier, diversify_prob, n_search_trees, tree_init, init_graph, random_state, low_memory, max_candidates, n_iters, delta, n_jobs, compressed, parallel_batch_queries, verbose)
    671 
    672         if n_trees is None:
--> 673             n_trees = 5 + int(round((data.shape[0]) ** 0.25))
    674             n_trees = min(32, n_trees)  # Only so many trees are useful
    675         if n_iters is None:

AttributeError: 'NNDescent' object has no attribute 'shape'

The error occurs even after running prepare() on the index. I also tried loading the index on another system (Python 3.9.0) and got a different error:

Traceback (most recent call last):
  File "D:\Code\markupf\pynn.py", line 6, in <module>
    msgembeds = pickle.load(f)
  File "D:\Code\gnanetra\gnanetra\lib\site-packages\numba\core\serialize.py", line 97, in _unpickle__CustomPickled
    ctor, states = loads(serialized)
TypeError: an integer is required (got type bytes)

As it stands, I have to rebuild the index every time I want to use it. Can anyone point me towards a fix?

Thanks

sky-2002 commented 1 year ago

Hi @regstuff, I think this is the problem. You are doing:

with open('pynnindex','wb') as f:
  pickle.dump(pynnindex,f)

This pickles the index object itself.

Now when you do this:

with open('pynnindex','rb') as f:
  msgembed = pickle.load(f)

You load that same index back, so msgembed is already an NNDescent object equivalent to pynnindex.

Finally, when you do this:

pynnindex_p = pynndescent.NNDescent(msgembed)

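You are trying to build a new index using an existing index as its data. The NNDescent constructor expects array-like data and reads data.shape, which an NNDescent object does not have, hence the AttributeError.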
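The failure can be reproduced without pynndescent. The sketch below uses a hypothetical `ToyIndex` class (not part of pynndescent) whose constructor copies the default-tree computation shown in the traceback. Passing it an ndarray works, while passing it another index object raises the same kind of AttributeError.

```python
import numpy as np


class ToyIndex:
    """Hypothetical stand-in for an ANN index; NOT pynndescent's NNDescent."""

    def __init__(self, data, n_trees=None):
        if n_trees is None:
            # Mirrors the line from the traceback: needs an array-like with .shape
            n_trees = 5 + int(round((data.shape[0]) ** 0.25))
            n_trees = min(32, n_trees)
        self.data = data
        self.n_trees = n_trees


data = np.zeros((1000, 8))
index = ToyIndex(data)  # fine: data is an ndarray with a .shape

try:
    ToyIndex(index)  # wrong: an index object has no .shape attribute
    error = None
except AttributeError as err:
    error = err  # keep a reference; `err` is cleared after the except block
    print(f"AttributeError: {err}")
```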

Instead, use msgembed directly as the index, e.g. msgembed.query(...).