scverse / pertpy

Perturbation Analysis in the scverse ecosystem.
https://pertpy.readthedocs.io/en/latest/
MIT License

Streamlining Distance API #416

Open Zethson opened 10 months ago

Zethson commented 10 months ago

Description of feature

This is a continuation of https://github.com/theislab/pertpy/issues/405 but specific for Distance.

TL;DR: Distance currently does not adhere to the API design of the rest of pertpy, and I want to harmonize it. At the moment, we pass a metric to the constructor, and __call__ then uses the appropriate distance function. This comes with two issues:

  1. It's not consistent with the rest of the API
  2. The options show up in a really long docstring list (https://pertpy.readthedocs.io/en/latest/usage/tools/pertpy.tools.Distance.html#pertpy.tools.Distance), and if we wanted to document more, it would become unreadable and hard to navigate

Currently we also have the

  1. onesided_distances
  2. pairwise
  3. precompute_distances

functions.

Moving the metric argument into these three functions wouldn't really help or solve either issue. The only option I see is having functions like:

distance.compute_wasserstein(mode: Literal['onesided', 'pairwise', 'precompute'])

for each of the metrics. These would then show up in a table of functions and could be documented more easily. This would probably also fit the current design better.
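
A minimal sketch of how per-metric functions with a mode argument might be laid out, written in plain Python so it runs on its own. Everything here is hypothetical and not pertpy's API: the euclidean (distance between group centroids) and mean_pairwise methods, the groups dict standing in for an AnnData split by groupby, and the selected_group argument. The 'precompute' mode is left out.

```python
from __future__ import annotations

import math
from itertools import product
from typing import Callable, Literal

Cells = list[tuple[float, ...]]
Groups = dict[str, Cells]  # hypothetical stand-in for adata grouped by `groupby`
Mode = Literal["onesided", "pairwise"]


def _centroid(cells: Cells) -> tuple[float, ...]:
    # Component-wise mean of one group's cells (a pseudobulk-style summary)
    return tuple(sum(col) / len(cells) for col in zip(*cells))


class Distance:
    """Hypothetical sketch: one documented method per metric, `mode` picks the layout."""

    def euclidean(
        self, groups: Groups, mode: Mode = "pairwise", selected_group: str | None = None
    ):
        """Euclidean distance between group centroids."""
        metric = lambda x, y: math.dist(_centroid(x), _centroid(y))
        return self._dispatch(metric, groups, mode, selected_group)

    def mean_pairwise(
        self, groups: Groups, mode: Mode = "pairwise", selected_group: str | None = None
    ):
        """Mean Euclidean distance over all cross-group cell pairs."""

        def metric(x: Cells, y: Cells) -> float:
            return sum(math.dist(a, b) for a, b in product(x, y)) / (len(x) * len(y))

        return self._dispatch(metric, groups, mode, selected_group)

    @staticmethod
    def _dispatch(
        metric: Callable[[Cells, Cells], float],
        groups: Groups,
        mode: Mode,
        selected_group: str | None,
    ):
        if mode == "onesided":
            # Compare every group against one selected (e.g. control) group
            if selected_group is None:
                raise ValueError("mode='onesided' requires selected_group")
            ref = groups[selected_group]
            return {name: metric(ref, cells) for name, cells in groups.items()}
        if mode == "pairwise":
            # Compare every ordered pair of groups
            return {(a, b): metric(groups[a], groups[b]) for a in groups for b in groups}
        raise ValueError(f"unknown mode: {mode!r}")


groups = {"ctrl": [(0.0, 0.0), (2.0, 0.0)], "pert": [(1.0, 3.0), (1.0, 5.0)]}
dist = Distance()
print(dist.euclidean(groups, mode="onesided", selected_group="ctrl"))  # {'ctrl': 0.0, 'pert': 4.0}
```

Each metric gets its own docstring (and, where needed, its own parameters), while the shared mode handling lives in one private helper.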

What do you think? I'm especially interested in @yugeji's, @stefanpeidli's, and @tessadgreen's opinions.

stefanpeidli commented 10 months ago

Seems reasonable. I can't think of any immediate problems this could cause. And by having a function per metric, we could add more docs, including formulas, which I agree is nice.

yugeji commented 10 months ago

A few issues come to mind:

  1. The obvious one: calling a set of distances one after another would look uglier. For example, right now I can call

    for metric in metrics:
        distance = pt.tl.Distance(metric=metric)

    but with the proposed change, I would call

    for metric in metrics:
        distance = func_dict[metric](mode='onesided')
  2. Where does from_precomputed go? Currently, with the same distance object as above, you call precompute_distances on an AnnData object, and then the distance's __call__, .onesided(X, Y), or .pairwise(X, Y) all make use of the precomputed distances:

    distance = pt.tl.Distance(metric='wasserstein')
    distance.precompute_distances(adata)
    df = distance(adata, groupby, etc.)

    In the proposed implementation, you would write:

    distance = pt.tl.Distance.compute_wasserstein(mode='precompute')
    distance(adata)
    distance = pt.tl.Distance.compute_wasserstein(mode='pairwise')  # using pairwise as an example, also where it makes the most sense
    df = distance(adata, groupby, etc.)

    In my opinion, this is considerably less readable and less intuitive. It also applies not just to precompute but to any case in which you want to calculate a summary statistic beforehand, which we definitely want to do because it's a major speedup.

  3. This only matters if you implement it this way, but it should be just wasserstein, not compute_wasserstein.

And just to clarify, you're thinking of using it like

distance = pt.tl.Distance.compute_wasserstein(mode='pairwise')
df = distance(adata, groupby, etc.)

NOT

distance = pt.tl.Distance()
df = distance.compute_wasserstein(mode='pairwise')(adata, groupby, etc.)

Right?
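
One possible way to address the looping and precompute concerns in points 1 and 2 is sketched below, under the assumption that per-metric methods read summaries cached on the object: the summaries are computed once, and the loop over metrics stays short via getattr. The class, precompute, euclidean, and cityblock here are hypothetical illustrations in plain Python, not pertpy code.

```python
from __future__ import annotations

import math


class Distance:
    """Hypothetical sketch: metric methods that reuse summaries cached by precompute()."""

    def __init__(self) -> None:
        self._centroids: dict[str, tuple[float, ...]] | None = None

    def precompute(self, groups: dict[str, list[tuple[float, ...]]]) -> None:
        # Compute each group's centroid once; every metric method below reuses it.
        self._centroids = {
            name: tuple(sum(col) / len(cells) for col in zip(*cells))
            for name, cells in groups.items()
        }

    def _cached(self) -> dict[str, tuple[float, ...]]:
        if self._centroids is None:
            raise RuntimeError("call precompute() first")
        return self._centroids

    def euclidean(self) -> dict[tuple[str, str], float]:
        """Pairwise Euclidean distance between cached centroids."""
        c = self._cached()
        return {(a, b): math.dist(c[a], c[b]) for a in c for b in c}

    def cityblock(self) -> dict[tuple[str, str], float]:
        """Pairwise Manhattan (L1) distance between cached centroids."""
        c = self._cached()
        return {(a, b): sum(abs(p - q) for p, q in zip(c[a], c[b])) for a in c for b in c}


groups = {"ctrl": [(0.0, 0.0), (2.0, 0.0)], "pert": [(4.0, 4.0), (4.0, 4.0)]}
distance = Distance()
distance.precompute(groups)  # summaries are computed once...
# ...and looping over metrics stays a one-liner via getattr
results = {metric: getattr(distance, metric)() for metric in ["euclidean", "cityblock"]}
```

Whether cached state on the object reads better than passing precomputed results explicitly is exactly the readability trade-off raised above; this only shows one option.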

Zethson commented 10 months ago

I discussed a few things with @yugeji:

  1. I need to find a way to make the looping easier
  2. Some distances, like MMD, could make use of specific parameters. Currently we can only pass them as kwargs; the new design would allow us to document them properly.
  3. There are three modes that we need to support: 1 vs 1, pairwise, and one-sided. While pairwise and one-sided would work as suggested above, we'd need to give 1 vs 1 first-class AnnData support. Currently, these implementations live in __call__ and only accept NumPy arrays.
  4. There will be a lot of docstring repetition, but we'll avoid that with a docstring decorator.
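
The docstring decorator in point 4 could be quite small. A minimal sketch, assuming Google-style "Args:" sections: shared parameter descriptions are written once and substituted into each metric's docstring, while metric-specific parameters are documented inline. The names fill_doc and _SHARED_DOCS, and the kernel parameter on mmd, are hypothetical.

```python
import inspect

# Hypothetical shared parameter descriptions, written once
_SHARED_DOCS = {
    "groups": "groups: Mapping from group name to that group's cells.",
    "mode": "mode: One of 'onesided', 'pairwise', or '1vs1'; selects which groups are compared.",
}


def fill_doc(func):
    """Substitute {placeholders} in func's docstring with the shared descriptions."""
    if func.__doc__:
        # cleandoc strips the common indentation so the substituted lines line up;
        # format_map raises KeyError if a placeholder has no shared description.
        func.__doc__ = inspect.cleandoc(func.__doc__).format_map(_SHARED_DOCS)
    return func


class Distance:
    @fill_doc
    def wasserstein(self, groups, mode="pairwise"):
        """Wasserstein distance between groups.

        Args:
            {groups}
            {mode}
        """
        raise NotImplementedError  # only the docstring assembly is sketched here

    @fill_doc
    def mmd(self, groups, mode="pairwise", kernel="linear"):
        """Maximum mean discrepancy between groups.

        Args:
            {groups}
            {mode}
            kernel: Hypothetical metric-specific parameter, documented only here.
        """
        raise NotImplementedError


print(Distance.mmd.__doc__)
```

Because the substitution happens once at class-definition time, Sphinx and help() both see the fully expanded docstrings.
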

yugeji commented 10 months ago

For future distances that do not use what is currently the standard __call__ format of (X, Y), implementing it in the new way would let you override onesided with a distance-specific one (for example, with classifier class projection or KNN distance).
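
A sketch of that override pattern, assuming a base class whose onesided is built on a per-pair __call__(X, Y): metrics that fit the (X, Y) format only implement __call__, while a metric that needs every group at once overrides onesided. All class names are hypothetical, and NearestNeighborDistance is a toy nearest-neighbour score standing in for the classifier or KNN distances mentioned, not their actual definitions.

```python
from __future__ import annotations

import math

Groups = dict[str, list[tuple[float, ...]]]


class BaseDistance:
    """Hypothetical base: onesided is built on a per-pair __call__(X, Y)."""

    def __call__(self, x, y) -> float:
        raise NotImplementedError

    def onesided(self, groups: Groups, selected_group: str) -> dict[str, float]:
        ref = groups[selected_group]
        return {name: self(ref, cells) for name, cells in groups.items()}


class CentroidDistance(BaseDistance):
    """Fits the (X, Y) format: only two groups are needed at a time."""

    def __call__(self, x, y) -> float:
        cx = [sum(col) / len(x) for col in zip(*x)]
        cy = [sum(col) / len(y) for col in zip(*y)]
        return math.dist(cx, cy)


class NearestNeighborDistance(BaseDistance):
    """Needs all groups at once, so it overrides onesided instead of __call__.

    Distance of a group = 1 - fraction of its cells whose nearest other cell
    (across all groups, ties broken by order) belongs to selected_group.
    """

    def onesided(self, groups: Groups, selected_group: str) -> dict[str, float]:
        pooled = [(name, cell) for name, cells in groups.items() for cell in cells]
        result = {}
        for name in groups:
            own = [i for i, (label, _) in enumerate(pooled) if label == name]
            hits = 0
            for i in own:
                cell = pooled[i][1]
                # Index of the nearest other cell in the pooled data
                j = min(
                    (k for k in range(len(pooled)) if k != i),
                    key=lambda k: math.dist(cell, pooled[k][1]),
                )
                hits += pooled[j][0] == selected_group
            result[name] = 1 - hits / len(own)
        return result


groups = {"ctrl": [(0.0, 0.0), (0.0, 1.0)], "pert": [(10.0, 10.0), (10.0, 11.0)]}
print(NearestNeighborDistance().onesided(groups, selected_group="ctrl"))  # {'ctrl': 0.0, 'pert': 1.0}
```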

yugeji commented 9 months ago

An important part of this refactor (which would also allow classifier_cp to be used with pairwise) would be to make pairwise call onesided instead of relying on the copy-pasted code it currently uses (which is also causing problems).
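
A minimal sketch of pairwise delegating to onesided, assuming each row of the pairwise result is simply one onesided call with that row's group selected. The Distance class, its metric callable, and centroid_distance are hypothetical illustrations, not pertpy's implementation.

```python
from __future__ import annotations

import math

Groups = dict[str, list[tuple[float, ...]]]


class Distance:
    """Hypothetical sketch: pairwise is assembled from onesided calls."""

    def __init__(self, metric) -> None:
        self.metric = metric  # callable(X, Y) -> float

    def onesided(self, groups: Groups, selected_group: str) -> dict[str, float]:
        ref = groups[selected_group]
        return {name: self.metric(ref, cells) for name, cells in groups.items()}

    def pairwise(self, groups: Groups) -> dict[str, dict[str, float]]:
        # One onesided call per row: no duplicated metric code, and a subclass
        # that overrides onesided (e.g. a classifier-based distance) gets
        # pairwise support without further changes.
        return {selected: self.onesided(groups, selected) for selected in groups}


def centroid_distance(x, y) -> float:
    # Euclidean distance between the two groups' centroids
    cx = [sum(col) / len(x) for col in zip(*x)]
    cy = [sum(col) / len(y) for col in zip(*y)]
    return math.dist(cx, cy)


groups = {"a": [(0.0, 0.0)], "b": [(3.0, 4.0)], "c": [(0.0, 1.0)]}
table = Distance(centroid_distance).pairwise(groups)
print(table["a"])
```

With this structure, any fix to onesided automatically applies to pairwise as well.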