pavlin-policar / openTSNE

Extensible, parallel implementations of t-SNE
https://opentsne.rtfd.io

Unable to use custom callable metric #239

Closed by tctco 1 year ago

tctco commented 1 year ago
Steps to reproduce the behavior
from openTSNE import TSNE
from scipy.spatial.distance import jensenshannon

tsne = TSNE(
    perplexity=30,
    metric=jensenshannon,
    n_jobs=8,
    random_state=0,
    verbose=True,
)
embedding_train = tsne.fit(scores_list)
plot(embedding_train, ids)

The above code throws an error:

TypeError                                 Traceback (most recent call last)
Cell In[36], line 97
     88     ax.legend(handles=legend_handles, **legend_kwargs_)
     90 tsne = TSNE(
     91     perplexity=30,
     92     metric=jensenshannon,
   (...)
     95     verbose=True,
     96 )
---> 97 embedding_train = tsne.fit(scores_list)
     98 plot(embedding_train, ids)

File ~/anaconda3/envs/openmmlab/lib/python3.9/site-packages/openTSNE/tsne.py:1246, in TSNE.fit(self, X, affinities, initialization)
   1243 if self.verbose:
   1244     print("-" * 80, repr(self), "-" * 80, sep="\n")
-> 1246 embedding = self.prepare_initial(X, affinities, initialization)
   1248 try:
   1249     # Early exaggeration with lower momentum to allow points to find more
   1250     # easily move around and find their neighbors
   1251     embedding.optimize(
   1252         n_iter=self.early_exaggeration_iter,
   1253         exaggeration=self.early_exaggeration,
   (...)
   1256         propagate_exception=True,
...
--> 247 self.index = AnnoyIndex(data.shape[1], annoy_metric)
    249 random_state = check_random_state(self.random_state)
    250 self.index.set_seed(random_state.randint(np.iinfo(np.int32).max))

TypeError: argument 2 must be str, not function

dkobak commented 1 year ago

Hmm. AFAIK this is not supposed to work like this.

You can try metric="jensen-shannon", which will use pynndescent (you need to install it for this to work).
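A minimal sketch of the suggestion above, assuming openTSNE and pynndescent are both installed. The helper names `make_distributions` and `fit_with_named_metric` are hypothetical, and the synthetic data only stands in for the reporter's `scores_list`, which is not shown in the issue. Passing `neighbors="pynndescent"` explicitly is an assumption made for clarity; according to the comment above, the string metric alone should already route to pynndescent.

```python
import numpy as np


def make_distributions(n_samples=200, n_bins=10, seed=0):
    """Hypothetical stand-in for `scores_list`: each row is a probability vector."""
    rng = np.random.default_rng(seed)
    # Dirichlet samples are non-negative and sum to 1 along each row.
    return rng.dirichlet(np.ones(n_bins), size=n_samples)


def fit_with_named_metric(X):
    """Fit t-SNE using the metric's string name instead of a Python callable."""
    # Imported lazily so the data helper above works without openTSNE installed.
    from openTSNE import TSNE

    tsne = TSNE(
        perplexity=30,
        metric="jensen-shannon",   # string name, as suggested in the thread
        neighbors="pynndescent",   # assumption: select the backend explicitly
        n_jobs=8,
        random_state=0,
        verbose=True,
    )
    return tsne.fit(X)
```

Calling `fit_with_named_metric(make_distributions())` should then return a two-dimensional embedding with one row per sample. Note that the named metric may not be numerically identical to scipy's `jensenshannon`, which returns the square root of the Jensen-Shannon divergence.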

tctco commented 1 year ago

Thanks for your advice!