lmcinnes / umap

Uniform Manifold Approximation and Projection
BSD 3-Clause "New" or "Revised" License

Cannot transform sparse data with Jaccard metric #680

Open quietrainfall opened 3 years ago

quietrainfall commented 3 years ago

UMAP can fit sparse data with the Jaccard metric, but it cannot transform new sparse data. Minimal example below (transforming the training data itself is enough to trigger the bug):

from scipy.sparse import random
from umap import UMAP

# create random sparse training data (100 samples x 1000 features, 1% nonzero)
train = random(m=100, n=1000, density=0.01)

# initialize model
mapper = UMAP(metric='jaccard')

# fit UMAP with the Jaccard metric (this works)
embedding = mapper.fit_transform(train)

# transform sparse data with the fitted model (this raises the error below)
output = mapper.transform(train)

Interestingly, mapper.fit_transform works fine, but mapper.transform fails with the following error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/opt/conda/lib/python3.7/site-packages/umap/umap_.py in transform(self, X)
   2700                 dmat = pairwise_distances(
-> 2701                     X, self._raw_data, metric=_m, **self._metric_kwds
   2702                 )

/opt/conda/lib/python3.7/site-packages/sklearn/utils/validation.py in inner_f(*args, **kwargs)
     62             if extra_args <= 0:
---> 63                 return f(*args, **kwargs)
     64 

/opt/conda/lib/python3.7/site-packages/sklearn/metrics/pairwise.py in pairwise_distances(X, Y, metric, n_jobs, force_all_finite, **kwds)
   1767         if issparse(X) or issparse(Y):
-> 1768             raise TypeError("scipy distance metrics do not"
   1769                             " support sparse matrices.")

TypeError: scipy distance metrics do not support sparse matrices.

During handling of the above exception, another exception occurred:

TypingError                               Traceback (most recent call last)
<ipython-input-41-4c8b67f1bd06> in <module>
----> 1 mapper.transform(train)

/opt/conda/lib/python3.7/site-packages/umap/umap_.py in transform(self, X)
   2706                     self._raw_data,
   2707                     metric=self._input_distance_func,
-> 2708                     kwds=self._metric_kwds,
   2709                 )
   2710             indices = np.argpartition(dmat, self._n_neighbors)[:, : self._n_neighbors]

/opt/conda/lib/python3.7/site-packages/umap/distances.py in pairwise_special_metric(X, Y, metric, kwds)
   1260             return metric(_X, _Y, *kwd_vals)
   1261 
-> 1262         return pairwise_distances(X, Y, metric=_partial_metric)
   1263     else:
   1264         special_metric_func = named_distances[metric]

/opt/conda/lib/python3.7/site-packages/sklearn/utils/validation.py in inner_f(*args, **kwargs)
     61             extra_args = len(args) - len(all_args)
     62             if extra_args <= 0:
---> 63                 return f(*args, **kwargs)
     64 
     65             # extra_args > 0

/opt/conda/lib/python3.7/site-packages/sklearn/metrics/pairwise.py in pairwise_distances(X, Y, metric, n_jobs, force_all_finite, **kwds)
   1788         func = partial(distance.cdist, metric=metric, **kwds)
   1789 
-> 1790     return _parallel_pairwise(X, Y, func, n_jobs, **kwds)
   1791 
   1792 

/opt/conda/lib/python3.7/site-packages/sklearn/metrics/pairwise.py in _parallel_pairwise(X, Y, func, n_jobs, **kwds)
   1357 
   1358     if effective_n_jobs(n_jobs) == 1:
-> 1359         return func(X, Y, **kwds)
   1360 
   1361     # enforce a threading backend to prevent data communication overhead

/opt/conda/lib/python3.7/site-packages/sklearn/metrics/pairwise.py in _pairwise_callable(X, Y, metric, force_all_finite, **kwds)
   1401         iterator = itertools.product(range(X.shape[0]), range(Y.shape[0]))
   1402         for i, j in iterator:
-> 1403             out[i, j] = metric(X[i], Y[j], **kwds)
   1404 
   1405     return out

/opt/conda/lib/python3.7/site-packages/numba/core/dispatcher.py in _compile_for_args(self, *args, **kws)
    412                 e.patch_message(msg)
    413 
--> 414             error_rewrite(e, 'typing')
    415         except errors.UnsupportedError as e:
    416             # Something unsupported is present in the user code, add help info

/opt/conda/lib/python3.7/site-packages/numba/core/dispatcher.py in error_rewrite(e, issue_type)
    355                 raise e
    356             else:
--> 357                 raise e.with_traceback(None)
    358 
    359         argtypes = []

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
non-precise type pyobject
During: typing of argument at /opt/conda/lib/python3.7/site-packages/umap/distances.py (1260)

File "../../opt/conda/lib/python3.7/site-packages/umap/distances.py", line 1260:
        def _partial_metric(_X, _Y=None):
            return metric(_X, _Y, *kwd_vals)
            ^

This error may have been caused by the following argument(s):
- argument 0: Cannot determine Numba type of <class 'scipy.sparse.csr.csr_matrix'>
- argument 1: Cannot determine Numba type of <class 'scipy.sparse.csr.csr_matrix'>
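
The traceback shows what transform is attempting here: a brute-force distance matrix between the incoming rows and the stored training data (self._raw_data), after which np.argpartition keeps the n_neighbors closest training points (umap_.py line 2710). Both distance routes it tries fail on csr_matrix inputs. sklearn's scipy-metric path rejects sparse matrices outright, and the numba-compiled Jaccard function cannot type them. To illustrate what that step has to compute, here is a minimal NumPy/SciPy sketch of Jaccard distances computed directly on sparse rows. It treats any nonzero entry as set membership, which is how UMAP's own jaccard reads its inputs. The helpers _binarize, sparse_jaccard_distances and nearest_training_neighbors are hypothetical names for this sketch only. They are not part of UMAP's API and are not a patch for it.

```python
import numpy as np
import scipy.sparse as sp

# Hypothetical helpers for illustration only; not part of UMAP's API.


def _binarize(M):
    """Return a CSR copy of M with every nonzero entry replaced by 1.0."""
    M = sp.csr_matrix(M, dtype=np.float64, copy=True)
    M.sum_duplicates()
    M.data = (M.data != 0).astype(np.float64)
    return M


def sparse_jaccard_distances(X, Y):
    """Dense (X.shape[0], Y.shape[0]) matrix of Jaccard distances between rows.

    Nonzero entries count as set membership; two empty rows get distance 0.0.
    """
    Xb, Yb = _binarize(X), _binarize(Y)
    inter = (Xb @ Yb.T).toarray()              # |x & y| for every pair of rows
    nx = np.asarray(Xb.sum(axis=1)).ravel()    # |x| for each row of X
    ny = np.asarray(Yb.sum(axis=1)).ravel()    # |y| for each row of Y
    union = nx[:, None] + ny[None, :] - inter  # |x | y| by inclusion-exclusion
    with np.errstate(divide="ignore", invalid="ignore"):
        dist = 1.0 - inter / union
    dist[union == 0] = 0.0                     # both rows empty
    return dist


def nearest_training_neighbors(X_new, X_train, n_neighbors):
    """Indices and distances of the n_neighbors closest training rows (unordered)."""
    dmat = sparse_jaccard_distances(X_new, X_train)
    # Same selection idea as umap_.py line 2710 in the traceback
    indices = np.argpartition(dmat, n_neighbors)[:, :n_neighbors]
    dists = np.take_along_axis(dmat, indices, axis=1)
    return indices, dists


if __name__ == "__main__":
    train = sp.random(100, 1000, density=0.01, format="csr", random_state=0)
    idx, d = nearest_training_neighbors(train, train, n_neighbors=15)
    print(idx.shape, d.shape)  # (100, 15) (100, 15)
```

Using sparse matrix products keeps the inputs sparse; only the final n_new x n_train result is dense, and the brute-force path materializes that matrix anyway.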

lmcinnes commented 3 years ago

This is definitely a bug. It is mostly a matter of handling all the different cases, and one case has been missed here. I will try to get this fixed soon.