CamaraLab / CAJAL

A Python package using Gromov-Wasserstein distance to compare cell shapes
MIT License

Looking for clarification on implementation/efficiency #12

Closed: tomouellette closed this issue 1 year ago

tomouellette commented 1 year ago

Hi there @patrick-nicodemus! Nice paper. I'm opening this issue to get some clarification on the implementation and computational complexity of the model. I've been putting together a general toolkit of models/methods for morphometrics for a while now (2D only, for the moment), and the idea of using GW distance as a morphometric descriptor seems pretty cool.

On to my question/comment:

I just wanted to confirm that, if I strip away all the boilerplate related to neuron data I/O, this is your implementation (for two dimensions): the inputs are (K points × K points) pairwise distance matrices of the shape outlines, and the output is an (N shapes × N shapes) matrix of pairwise GW distances.

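```python
import itertools

import numpy as np
import ot  # POT: Python Optimal Transport
from scipy.spatial.distance import cdist


# The distance function as defined in your code
def gw(fst_mat: np.ndarray, snd_mat: np.ndarray) -> float:
    # Uniform weights on the points of each shape. log["gw_dist"] holds the
    # value of the square-loss GW objective at the coupling POT finds.
    _, log = ot.gromov.gromov_wasserstein(
        fst_mat,
        snd_mat,
        ot.unif(fst_mat.shape[0]),
        ot.unif(snd_mat.shape[0]),
        "square_loss",
        log=True,
    )
    return log["gw_dist"]


# The graph/matrix output as described in your paper
def gw_matrix(outlines):
    # Each outline is assumed to be a (2, K) array of x/y coordinates.
    n = len(outlines)
    gw_distances = np.zeros((n, n))
    for i, j in itertools.combinations(range(n), 2):
        # (K, K) intra-shape distance matrices; these are recomputed for
        # every pair and could be cached once per outline.
        a = cdist(outlines[i].T, outlines[i].T)
        b = cdist(outlines[j].T, outlines[j].T)
        gw_distances[i, j] = gw_distances[j, i] = gw(a, b)
    return gw_distances
```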
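Two small refinements can be sketched under the same assumption that outlines are (2, K) coordinate arrays; the helper names are hypothetical. `pdist` computes each distinct pair of points once, and `squareform` expands the condensed result into the same (K, K) matrix that `cdist(X, X)` returns. The second helper converts POT's square-loss value into a distance by taking the square root and applying the factor 1/2 from Mémoli's definition of the Gromov-Wasserstein distance; drop the factor if you prefer that convention.

```python
import numpy as np
from scipy.spatial.distance import pdist, squareform


def intracell_distances(outline: np.ndarray) -> np.ndarray:
    """Return the (K, K) Euclidean distance matrix of a (2, K) outline."""
    # pdist returns the K*(K-1)/2 distinct pairwise distances as a condensed
    # vector; squareform expands it into the full symmetric matrix.
    return squareform(pdist(outline.T))


def gw_from_pot_loss(loss: float) -> float:
    """Hypothetical helper: square-loss objective -> GW distance.

    Assumes `loss` is log["gw_dist"] from ot.gromov.gromov_wasserstein with
    "square_loss" (a squared quantity). The factor 1/2 follows Mémoli's
    definition of the Gromov-Wasserstein distance.
    """
    # Clamp tiny negative values caused by floating-point round-off.
    return 0.5 * np.sqrt(max(loss, 0.0))
```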

Comment on improving efficiency

So it looks like there are two sources of quadratic complexity (ignoring the cost of each GW computation itself): (1) computing each "intra-shape" pairwise distance matrix, which is quadratic in the number of points K per shape, and (2) computing GW distances between all pairs of those matrices, which is quadratic in the number of shapes N. If I have this right, the major speed bottleneck is the pairwise GW computation across all cells, which is a problem for scaling to a very large number of cells.

To improve speed and enable scaling to larger datasets, have you thought about first building a nearest-neighbour graph from a set of cheaper morphometric descriptors that are correlated with GW at a coarser level of granularity? For example, the major axis length of the best-fitting ellipse might cut the number of pairwise GW computations (neighbours) for each cell in half.
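The descriptor-based prefilter suggested above can be sketched as a candidate-pair filter. This is a minimal sketch under stated assumptions: outlines are (2, K) coordinate arrays as in the code above, and instead of fitting an ellipse it uses the standard deviation of the points along their first principal axis as a stand-in (the hypothetical `major_axis_proxy`), which is proportional to the major axis length for an ellipse sampled at evenly spaced angles. Only the pairs returned by the hypothetical `candidate_pairs` would then be passed to `gw`.

```python
import numpy as np


def major_axis_proxy(outline: np.ndarray) -> float:
    """Stand-in for an ellipse fit: spread along the first principal axis.

    Hypothetical descriptor; only its ranking across cells matters here,
    so no constant factor is applied.
    """
    cov = np.cov(outline)  # rows of `outline` are the x and y coordinates
    largest = np.linalg.eigvalsh(cov)[-1]  # eigenvalues come back ascending
    return float(np.sqrt(largest))


def candidate_pairs(outlines, m: int) -> set:
    """Pairs (i, j), i < j, where one cell is among the other's m nearest
    cells in descriptor space. GW would be computed only for these."""
    desc = np.array([major_axis_proxy(o) for o in outlines])
    gaps = np.abs(desc[:, None] - desc[None, :])
    np.fill_diagonal(gaps, np.inf)  # never pair a cell with itself
    pairs = set()
    for i in range(len(outlines)):
        for j in np.argsort(gaps[i])[:m]:
            j = int(j)
            pairs.add((min(i, j), max(i, j)))
    return pairs
```

With `m` neighbours per cell this keeps at most `N * m` pairs instead of `N(N-1)/2`, at the risk of missing GW neighbours that the descriptor ranks poorly.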

Thanks, Tom

patrick-nicodemus commented 1 year ago

Hi Tom, these are great questions.

  1. Yes. Mathematically, this is the core of what's happening. Our software is in large part a tool for converting various data formats for cell morphology into intracell distance matrices. For 2D image segmentation data, your code captures essentially what's happening here. Note that you could use `pdist` rather than `cdist` here. A warning: the OT library reports the squared GW distance rather than the GW distance. The squared distance is probably not a metric, so for clear communication it would be best to reserve the term "GW distance" for the one in Mémoli's paper, which is a metric. I have pushed a fix for this.
  2. Your efficiency observations are very much correct. Computing `pdist`/`cdist` is extremely fast, and you can ignore it for performance purposes relative to the GW computation. On my machine (a single Python process with OpenBLAS multithreading disabled), I get these runtimes for GW between two cells. The first column is the number of points taken from each cell.
    | Points per cell | Time (ms) |
    |----------------:|----------:|
    |              10 |      0.47 |
    |              20 |      0.91 |
    |              30 |      1.53 |
    |              40 |      2.25 |
    |              50 |      3.09 |
    |              60 |      4.20 |
    |              70 |      5.68 |
    |              80 |      7.52 |
    |              90 |      9.32 |
    |             100 |     11.83 |

    I fit a power law $y = Ax^n$ to these timings, where $x$ is the number of points per cell and $y$ is the time per GW computation, and got $n=1.416$ and $A=0.0141$ ms. This gives a rough idea of how runtime scales with cell size in this range. When $x$ is small, the runtime is dominated by the optimal transport step. When $x$ is large, it starts to be dominated by matrix multiplication, so the scaling approaches $\mathcal{O}(x^3)$ or so.
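Since $y = Ax^n$ is equivalent to $\log y = \log A + n \log x$, an ordinary least-squares line through the log-transformed timings is one way to obtain such a fit. This sketch assumes that is how the fit was done (the thread does not say); on the table above it lands close to the reported $n$ and $A$. The local slopes between consecutive rows also show the effective exponent growing with $x$, consistent with the remark about larger cells.

```python
import numpy as np

# Timings from the table above: points per cell vs. milliseconds per GW call.
points = np.array([10, 20, 30, 40, 50, 60, 70, 80, 90, 100], dtype=float)
time_ms = np.array([0.47, 0.91, 1.53, 2.25, 3.09, 4.20, 5.68, 7.52, 9.32, 11.83])

# y = A * x**n  <=>  log y = log A + n * log x, so fit a line in log-log space.
# np.polyfit returns coefficients from the highest power down: [slope, intercept].
n, log_A = np.polyfit(np.log(points), np.log(time_ms), deg=1)
A = np.exp(log_A)
print(f"n = {n:.3f}, A = {A:.4f} ms")

# Exponent between consecutive rows: d(log y) / d(log x).
local_exponents = np.diff(np.log(time_ms)) / np.diff(np.log(points))
print(np.round(local_exponents, 2))
```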

As you say, the quadratic growth in the number of cells is indeed a serious chokepoint. We are working on this problem in a development branch, "optimize", which should be released later this week. The documentation for that branch on our Read the Docs site (still under construction) may give some insight.
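To put the quadratic growth in numbers: with $N$ cells there are $N(N-1)/2$ unordered pairs. The sketch below turns that into a single-process wall-clock estimate, assuming for illustration a constant 11.83 ms per pair (the 100-point timing from the table) and no parallelism; the cell counts and the function name are hypothetical.

```python
def pairwise_gw_cost_hours(n_cells: int, ms_per_pair: float) -> float:
    """Single-process wall-clock estimate (hours) for all-pairs GW."""
    n_pairs = n_cells * (n_cells - 1) // 2  # unordered pairs (i < j)
    return n_pairs * ms_per_pair / 1000 / 3600


for n_cells in (1_000, 10_000, 100_000):
    hours = pairwise_gw_cost_hours(n_cells, ms_per_pair=11.83)
    print(f"{n_cells:>7} cells: {hours:,.1f} h")
```

This gives roughly 1.6, 164, and 16,400 hours respectively: a tenfold increase in cells costs about a hundredfold more time.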

In this new branch we have done three things:

  1. Rewrote the gradient descent algorithm for computing GW in Cython and made small changes to improve performance, for perhaps a 20% speedup.
  2. Closest to your suggestion, we implemented a lower bound for GW from Mémoli's original paper, called SLB, and we can use it to filter out cell pairs that are far enough apart that they cannot be in each other's nearest-neighbor lists. This technique is general: it does not require a notion of major/minor axis, and it is compatible with geodesic distance. It is multiple orders of magnitude faster than GW. If the GW distance only needs to be known precisely for the nearest neighbors, this reduces the number of cell pairs to compute by about 70%.
  3. Implemented an approximation to GW, called quantized GW, which replaces each cell by a subset of its points and uses this as an approximation to the original space. The approximation quality is acceptable, and it substantially reduces computation time; as the table above shows, approximating a 100-point space by a 30-point space cuts computation time by 87% (1.53 ms vs. 11.83 ms).
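For reference, Mémoli's SLB (second lower bound) with $p = 2$ is half the 2-Wasserstein distance between the distributions of intracell distances of the two cells. That is a one-dimensional transport problem, solved by sorting, which is why it is so much cheaper than GW. Below is a minimal sketch, assuming uniform weights and the same number of points per cell (as in the timing table); `slb` is a hypothetical name, and this is not CAJAL's implementation, whose normalization may differ.

```python
import numpy as np


def slb(dist_x: np.ndarray, dist_y: np.ndarray) -> float:
    """Sketch of Mémoli's second lower bound (SLB) for GW, p = 2.

    Assumes both intracell distance matrices are K x K with the same K and
    uniform weights, so each distribution of pairwise distances has K*K
    equally weighted atoms.
    """
    if dist_x.shape != dist_y.shape:
        raise ValueError("this sketch assumes equal numbers of points per cell")
    # In one dimension, the optimal coupling between two uniform distributions
    # with the same number of atoms matches sorted values to each other.
    a = np.sort(dist_x, axis=None)  # axis=None flattens before sorting
    b = np.sort(dist_y, axis=None)
    return 0.5 * float(np.sqrt(np.mean((a - b) ** 2)))
```

Sorting $K^2$ values costs $\mathcal{O}(K^2 \log K)$ per pair, with no iterative optimization, so SLB can be computed for every pair of cells cheaply.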

Lastly, combining all of these, we have written an algorithm that computes the SLB for all pairs of cells and uses it to select the nearby pairs for which the quantized GW distance is computed, up to a user-specified tolerance for error in the reported nearest-neighbor table. We hope these improvements will be enough to realistically apply CAJAL to very large datasets such as the MICrONS dataset.
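The pruning idea can be sketched as an exact k-nearest-neighbor search driven by any lower bound, such as SLB for GW. This is a minimal sketch with hypothetical names (`knn_with_lower_bound`, `dist_fn`): the expensive (possibly quantized) GW call is a user-supplied function, and the user-specified error tolerance described above is not modeled. Candidates for each cell are visited in order of increasing lower bound, and the search stops once the bound reaches the current k-th best distance.

```python
import numpy as np


def knn_with_lower_bound(lower, dist_fn, k):
    """Exact k nearest neighbours of every cell, skipping provably far pairs.

    `lower[i, j]` must never exceed the true distance `dist_fn(i, j)`.
    Distances are cached, so each unordered pair is computed at most once.
    Returns (neighbour lists, number of distinct pairs computed).
    """
    n = lower.shape[0]
    cache = {}

    def dist(i, j):
        key = (min(i, j), max(i, j))
        if key not in cache:
            cache[key] = dist_fn(*key)
        return cache[key]

    neighbours = []
    for i in range(n):
        order = [int(j) for j in np.argsort(lower[i]) if j != i]
        found = []  # (distance, j) pairs, sorted, at most k entries
        for j in order:
            # Every remaining candidate has a bound >= lower[i, j], and true
            # distances are >= their bounds, so none can beat the k-th best.
            if len(found) == k and lower[i, j] >= found[-1][0]:
                break
            found.append((dist(i, j), j))
            found.sort()
            found = found[:k]
        neighbours.append([j for _, j in found])
    return neighbours, len(cache)
```

The result matches a brute-force search up to ties, and the tighter the bound, the earlier each cell's search stops.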

Thanks for your interest and please feel free to comment on the code or documentation!

tomouellette commented 1 year ago

Hi @patrick-nicodemus, thanks for taking the time to write such a detailed response! It sounds like big speed improvements are coming (which is awesome), and CAJAL seems like a great choice for analyzing highly projected morphologies. Looking forward to seeing the updates. I'll close the issue.