theislab / ehrapy

Electronic Health Record Analysis with Python.
https://ehrapy.readthedocs.io/
Apache License 2.0

Master Issue: Function Compatibility with the Different Datatypes of AnnData Fields #829

Open eroell opened 21 hours ago

eroell commented 21 hours ago

Description of feature

Right now, AnnData fields such as .X can hold different datatypes, e.g. dense NumPy arrays, sparse arrays, or Dask arrays.

Some ehrapy functions are compatible with several of these types, while others are not.

We can roll this out in a 2 step process:

  1. Introduce a consistent response for functions that fail for a given datatype. One way to implement this could be via consistent use of single-dispatch. This should also include dedicated tests (a rough sketch of such a test follows below).
  2. Address functions, or batches of functions, step by step and extend the datatypes they support. The order could be determined by estimated usage volume, required effort, and typical data load (e.g. operations happening on the full feature set vs. a lower-dimensional space).
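For the dedicated tests in step 1, something along these lines could work (pure sketch: some_ehrapy_function is only a stand-in for whichever ehrapy function is being made datatype-aware, and the exact error type is up for discussion):

import numpy as np
import pytest
from scipy import sparse


def some_ehrapy_function(X):
    # Stand-in for an ehrapy function that supports dense and sparse layers only.
    if isinstance(X, (np.ndarray, sparse.spmatrix)):
        return X
    raise TypeError(f"Unsupported data type: {type(X)}")


@pytest.mark.parametrize(
    "X",
    [
        np.array([[1.0, 2.0], [3.0, 4.0]]),
        sparse.csr_matrix(np.array([[1.0, 0.0], [0.0, 4.0]])),
    ],
)
def test_supported_datatypes(X):
    # Supported datatypes should just work.
    some_ehrapy_function(X)


def test_unsupported_datatype_raises():
    # Anything not explicitly supported should fail with one consistent error type.
    with pytest.raises(TypeError):
        some_ehrapy_function(object())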
nicolassidoux commented 17 hours ago

In addition to what @eroell mentioned, here are my initial—perhaps naive—thoughts on a potential strategy we could adopt:

  1. Assess and Implement Full Support for All Datatypes

    • We could evaluate each function that manipulates layer data and ensure it supports every possible datatype.
  2. Adopt a Standard Datatype with Conversion Extensions

    • Alternatively, we could standardize all functions to operate on a single datatype (for example, NumPy arrays :wink:).
    • Extensions could then handle the conversion of the real underlying datatype to NumPy for processing, and back to the original datatype afterward.

From a design perspective, strategy 2 is quite appealing:

However, there are downsides to consider:

Zethson commented 17 hours ago

Strategy 2 is more or less what we've been doing so far, because in almost all cases we densify. Speaking from experience, we are leaving a loooot of performance behind if we don't support sparse matrices. Moreover, enabling all algorithms with Dask arrays is also important so that we can really scale and enable out-of-core computation. I'd not worry about CuPy arrays and GPU for now.

I would:

  1. Ensure that every function converts to a NumPy array as a fallback if necessary.
  2. Try to implement all functions with sparse arrays (without densifying!) where possible; see the sketch below. This alone will be substantial work, but with substantial payoffs.
  3. Now do it again with Dask. This is again a massive, massive pile of work, but it will be super important to scale long term.
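To illustrate what point 2 buys us (just a toy sketch, not an ehrapy function): a column mean can be computed directly on the sparse container instead of densifying first.

import numpy as np
from scipy import sparse

X = sparse.random(10_000, 500, density=0.01, format="csr")

# Densifying first: allocates a full 10000 x 500 dense array just to compute means.
dense_means = X.toarray().mean(axis=0)

# Sparse-aware: scipy computes the same means without materializing the zeros.
sparse_means = np.asarray(X.mean(axis=0)).ravel()

assert np.allclose(dense_means, sparse_means)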

We can break it up into pieces and also @aGuyLearning can potentially help here.

WDYT?

nicolassidoux commented 16 hours ago

I like the style of this hybrid approach.

I could see a classic dispatch function for the function under consideration (knn_impute, for example) and two single-dispatch functions for the conversions (to_numpy, from_numpy).

Let's try a quick example. knn_impute could look like this (just an example, the syntax may be wrong):

from anndata import AnnData
from scipy import sparse


def knn_impute(
    adata: AnnData,
    ...
) -> AnnData:
    if isinstance(adata.X, sparse.csr_matrix):  # This datatype has a special algorithm available
        return _knn_impute_on_csr(adata, ...)
    if isinstance(adata.X, sparse.csc_matrix):  # This datatype too
        return _knn_impute_on_csc(adata, ...)
    # Note: Dask is not supported here

    # Fallback to numpy
    return from_numpy(_knn_impute_on_numpy(to_numpy(adata.X)))

from_numpy will somehow have to know what the original datatype was.
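One option (just a sketch, the exact signature is up for discussion): pass the original layer along and single-dispatch from_numpy on it, so the original container type decides how to convert back:

from functools import singledispatch
from typing import Any

import numpy as np
from scipy import sparse


@singledispatch
def from_numpy(original: Any, result: np.ndarray):
    raise ValueError("Unsupported type")

@from_numpy.register
def _(original: np.ndarray, result: np.ndarray) -> np.ndarray:
    return result  # already a numpy array, nothing to convert back

@from_numpy.register
def _(original: sparse.csr_matrix, result: np.ndarray) -> sparse.csr_matrix:
    return sparse.csr_matrix(result)  # re-sparsify the dense result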

Then the to_numpy conversion function could be:

from functools import singledispatch
from typing import Any

import numpy as np
from dask.array import Array as DaskArray
from scipy import sparse


@singledispatch
def to_numpy(layer: Any) -> np.ndarray:
    raise ValueError("Unsupported type")

@to_numpy.register
def _(layer: np.ndarray) -> np.ndarray:
    return layer  # The layer is already a numpy array, so return it as-is

@to_numpy.register
def _(layer: sparse.csr_matrix) -> np.ndarray:
    return layer.toarray()  # Will densify if I'm not wrong

@to_numpy.register
def _(layer: DaskArray) -> np.ndarray:
    return layer.compute()  # e.g. materialize the Dask array in memory; whatever needs to be done there

So to sum up, if the user calls knn_impute with a layer datatype that:

eroell commented 16 hours ago

You've spotted this right, @nicolassidoux. However, the optimization from sparse structures when the data is sparse is huge; it can easily reach orders of magnitude depending on data sparsity. Also, out-of-core operations will not be possible at all without following Strategy 1 :)

Agree on all points with @Zethson.

We can break it up into pieces

Absolutely, the high-level steps 1 and 2 are both huge. They should rather guide a rollout of this behavior, not "a PR" ;) Step 2 is in fact open-ended.

I'll try to draft a first set of tackleable portions for step 1 as sub-issues soon.

eroell commented 16 hours ago

Rather than having singledispatched to/from numpy functions and many names like _knn_impute_on_csr, _knn_impute_on_csc, I currently have this pattern in mind:

(Imagine sum_entries to be _knn_impute, i.e. a function where datatype-specific operations are performed.)

from functools import singledispatch

import numpy as np
from scipy import sparse


@singledispatch
def sum_entries(data):
    """
    Sum entries.

    Arguments
    ---------
    data
        Data to sum entries of.

    Returns
    -------
    result
        Sum of entries.
    """
    raise TypeError(f"Unsupported data type: {type(data)}")

@sum_entries.register
def _(data: np.ndarray):
    print("Summing entries of a NumPy array.")
    return np.sum(data)

@sum_entries.register
def _(data: sparse.spmatrix):
    print("Summing entries of a SciPy sparse matrix.")
    return data.sum()  # not so exciting an example, it just illustrates this could look different for this data type
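A Dask-backed implementation could later be registered in exactly the same way, e.g. (sketch only, assuming the result should be materialized eagerly):

import dask.array as da

@sum_entries.register
def _(data: da.Array):
    print("Summing entries of a Dask array.")
    return data.sum().compute()  # lazy reduction, only materialized at the very end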

Implementing this step by step for (batches of) our functions, together with tests, would form the implementation of high-level step 1.

eroell commented 16 hours ago

Raising a TypeError for everything that is not explicitly supported comes very naturally here.

nicolassidoux commented 16 hours ago

Looks good to me too, but beware of the limitation of singledispatch: it overloads the function based on the type of the first argument only.
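A tiny illustration of that limitation (toy code, nothing ehrapy-specific):

from functools import singledispatch

@singledispatch
def f(a, b):
    return "fallback"

@f.register
def _(a: int, b):
    return "int"

print(f(1, "anything"))  # "int"      -- dispatched on the type of the first argument
print(f("x", 1))         # "fallback" -- the int in the second position is ignored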

I have a suggestion for a first step, if I may: we need to list every function in ehrapy that can manipulate data so we have a clear idea of what needs to be done.