Open adamgayoso opened 2 years ago
I'm hesitant to add another embedding method without clear benefits over existing implementations.
Could you give some detail on benefits here, ideally with direct comparisons?
See colab notebook.
With random initialization it's about 10x faster than UMAP on this system. The quadratic init (default) is as fast as UMAP, but there's an opportunity to optimize that code to use the GPU.
Hi @ivirshup; I am the creator of PyMDE. I'm happy to provide some context. I'd like to say up front though that I am not very familiar with Scanpy (which seems like a fantastic library!), so please bear with me if some of the things I mention are not relevant to Scanpy.
PyMDE (documentation here: https://pymde.org/) has a few benefits:
There are some comparisons to UMAP & openTSNE in the third part of our manuscript, which has been published in Foundations & Trends in Machine Learning and is available here: https://web.stanford.edu/~boyd/papers/pdf/min_dist_emb.pdf
On the other hand, PyMDE is young software. If you do depend on it, I would recommend including it as an optional dependency, not a required one.
Happy to chat more, to answer any questions, and to help with integration, if that is something you are ultimately interested in.
I just updated the notebook linked at the top of the PR. I have a PR at pymde to improve the initialization speed using the GPU (https://github.com/cvxgrp/pymde/pull/55). With these changes, pymde takes 20 seconds and umap takes around 200 seconds (150k cells). I believe most of pymde's runtime comes from the initial pynndescent call. So if this is implemented well here, using a precomputed neighbors graph, pymde would take no more than a few seconds for most use cases.
@akshayka thanks for contributing to the conversation! The package does indeed look interesting. A couple questions about the tool:
Can we get a weighted graph out of the fit embedding object? For context, we use the UMAP weighted connectivity graph for a number of downstream tasks. This seems related to distortions, but maybe not quite what they are.
I'm also wondering about just how early the package is. I would like to be able to take advantage of any new features, and wouldn't want an early API decision to lock us out of those.
Can we get a weighted graph out of the fit embedding object?
Yes, PyMDE can do that. The longer answer is that it depends on the type of embedding problem you set up: some are specified using weighted graphs, others are not. But most embedding problems (including all problems specified using the preserve_neighbors function, which is the most commonly used recipe) have associated weighted graphs.
I'm also wondering about just how early the package is. I would like to be able to take advantage of any new features, and wouldn't want an early API decision to lock us out of those.
Great question. The internals will very likely change over the coming months/year. But the interface to the MDE class, which is the central object in PyMDE, will likely be stable.
The internals will very likely change over the coming months/year.
Good to know, thanks! Any hints about what will change here? In particular, I'm wondering if there might be a jax implementation as I'm a bit more keen on that as a dependency.
But most embedding problems (including all problems specified using the preserve_neighbors function, which is the most commonly used recipe) have associated weighted graphs.
I'd be interested in seeing how these graphs perform compared to the ones we get from UMAP. Would this be the right way to retrieve the graphs for the object, or is distortions not the right field?
from scipy import sparse
weights = mde.distortions().cpu().numpy()
edges = mde.edges.cpu().numpy()
graph = sparse.coo_matrix((weights, (edges[:, 0], edges[:, 1])), shape=(mde.n_items, mde.n_items))
I'm wondering if there might be a jax implementation as I'm a bit more keen on that as a dependency.
Probably for another discussion -- I like jax as much as anyone, but it's not nearly as easy to install as pytorch, especially on Windows and M1 Macs.
In particular, I'm wondering if there might be a jax implementation as I'm a bit more keen on that as a dependency.
I don't have any plans to switch from PyTorch to JAX. I did evaluate JAX when I started the project, but it wasn't mature enough back then.
I'd be interested in seeing how these graphs perform compared to the ones we get from UMAP.
I'm not super clear on the semantics of the graphs obtained from UMAP. They might differ somewhat from the ones obtained from PyMDE.
Would this be the right way to retrieve the graphs for the object, or is distortions not the right field?
That's not quite right. Assuming that mde was constructed from preserve_neighbors, try this:
import pymde

# mde is assumed to be an MDE problem built with pymde.preserve_neighbors
weights = mde.distortion_function.weights.cpu().numpy()
edges = mde.edges.cpu().numpy()
n_items = mde.n_items
graph = pymde.Graph.from_edges(edges, weights, n_items).adjacency_matrix
(API docs for Graph here: https://pymde.org/api/index.html#pymde.Graph. In the Graph class, distances/weights are used interchangeably.)
I'll just mention, however, that with PyMDE the weights and edges don't fully determine the embedding. The weights are parameters to distortion functions, which convey the extent to which two items are similar or dissimilar. Roughly speaking, positive weights mean items are similar and should be close together, and negative weights mean that they're dissimilar and shouldn't be close (but need not be far). More details here: https://pymde.org/mde/index.html
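To make the sign convention concrete, here is a minimal sketch using only NumPy and SciPy on a made-up four-item edge list. The arrays `edges` and `weights` are toy stand-ins for what `mde.edges` and `mde.distortion_function.weights` would hold, and `split_by_sign` is a hypothetical helper, not part of PyMDE. It assumes each pair (i, j) is listed once.

```python
import numpy as np
from scipy import sparse


def split_by_sign(edges, weights, n_items):
    """Return symmetric sparse graphs for similar (w > 0) and dissimilar (w < 0) pairs.

    Hypothetical helper for illustration; assumes each pair appears once in `edges`.
    """
    edges = np.asarray(edges)
    weights = np.asarray(weights, dtype=float)

    def build(mask):
        rows, cols = edges[mask, 0], edges[mask, 1]
        vals = weights[mask]
        upper = sparse.coo_matrix((vals, (rows, cols)), shape=(n_items, n_items))
        # Mirror the one-sided edge list to get an undirected adjacency matrix.
        return (upper + upper.T).tocsr()

    return build(weights > 0), build(weights < 0)


# Toy data (made up): each row is one pair (i, j) with i < j.
edges = np.array([[0, 1], [1, 2], [0, 3], [2, 3]])
weights = np.array([1.0, 0.5, -1.0, -1.0])

similar, dissimilar = split_by_sign(edges, weights, n_items=4)
```

Only the `similar` graph is comparable in spirit to a connectivity graph; the negative-weight pairs say which items should not be close, which a UMAP-style connectivity matrix does not encode.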
sc.tools? sc.pl? sc.external.*?
PyMDE is a nice visualization method that to me seems to effectively serve the same purpose as UMAP in analyses (discussion about appropriateness of these methods can be in another issue :) ). It's super fast because after running pynndescent it puts the graph on the GPU optionally (using pytorch). I would love to see this in scanpy. There might be a way to use the scanpy neighbors graph from sc.pp.neighbors directly in pymde as the function below is a wrapper of some internal classes.
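One way the reuse of a precomputed neighbors graph might look: the sketch below turns a symmetric sparse connectivity matrix into an edge list with one row per pair, plus a matching weight vector. It assumes the matrix is the kind `sc.pp.neighbors` stores (typically under `adata.obsp["connectivities"]`); a small hand-written matrix stands in for real data here. `graph_to_edges` is a hypothetical helper, and whether its output can be fed to PyMDE unchanged has not been tested.

```python
import numpy as np
from scipy import sparse


def graph_to_edges(connectivities):
    """Convert a symmetric sparse graph to (edges, weights), one row per pair i < j.

    Hypothetical helper for illustration; not part of scanpy or PyMDE.
    """
    # Keep the strict upper triangle so each undirected pair appears once.
    upper = sparse.triu(connectivities, k=1).tocoo()
    edges = np.column_stack([upper.row, upper.col]).astype(np.int64)
    weights = np.asarray(upper.data, dtype=np.float32)
    return edges, weights


# Stand-in for a precomputed connectivity matrix (values made up).
dense = np.array(
    [
        [0.0, 0.8, 0.0, 0.2],
        [0.8, 0.0, 0.5, 0.0],
        [0.0, 0.5, 0.0, 0.9],
        [0.2, 0.0, 0.9, 0.0],
    ]
)
connectivities = sparse.csr_matrix(dense)

edges, weights = graph_to_edges(connectivities)
```

The resulting arrays have the same shape as the `edges` and `weights` used with `pymde.Graph.from_edges` earlier in the thread, so skipping the pynndescent call this way seems plausible, though that step is left as an assumption.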