scverse / mudata

Multimodal Data (.h5mu) implementation for Python
https://mudata.rtfd.io
BSD 3-Clause "New" or "Revised" License

`MuData.to_anndata()` & `AnnData.to_mudata()` #14

Closed: ivirshup closed this issue 2 months ago.

ivirshup commented 2 years ago

I'd like to propose some utility methods for converting back and forth between MuData and AnnData objects.

Implementations would look something like:

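import anndata as ad
import mudata

def to_anndata(self: mudata.MuData, axis, **kwargs) -> ad.AnnData:
    # Concatenate the per-modality AnnData objects along `axis`
    adata = ad.concat(self.mod, axis=axis, **kwargs)
    # DataFrame.update() only overwrites existing columns, so copy the
    # shared obs/var columns explicitly (pandas aligns them on the index)
    for attr in ["obs", "var"]:
        for col, values in getattr(self, attr).items():
            getattr(adata, attr)[col] = values
    # The remaining shared mappings are dict-like and can be merged directly
    for attr in ["obsm", "obsp", "varm", "varp", "uns"]:
        getattr(adata, attr).update(getattr(self, attr))
    return adata

def to_mudata(self: ad.AnnData, axis, groupby) -> mudata.MuData:
    # self.split_by(groupby, axis=axis) -> dict[str, AnnData]
    # (split_by is the method proposed in the anndata PR linked below)
    return mudata.MuData(
        self.split_by(groupby, axis=axis)
    )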

Here split_by is the method proposed in: https://github.com/theislab/anndata/pull/613

Additional needs

I think it would be good to have arguments that control which fields are moved to or from the shared mappings of the MuData object, as opposed to staying specific to the individual AnnData objects.
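One way such an argument could work is sketched below, assuming MuData's convention of prefixing shared columns with the modality name (e.g. `rna:n_genes`). The helper `pull_to_shared` and its `columns` argument are hypothetical, and the sketch works on plain pandas DataFrames rather than on a MuData object.

```python
from typing import Dict, Iterable, Optional

import pandas as pd


def pull_to_shared(
    mod_annots: Dict[str, pd.DataFrame],
    columns: Optional[Iterable[str]] = None,
) -> pd.DataFrame:
    """Hypothetical helper: copy per-modality annotation columns into one shared table.

    Each copied column is renamed to "<modality>:<column>". `columns` restricts
    which columns move to the shared table; None moves all of them.
    """
    wanted = None if columns is None else set(columns)
    pieces = []
    for name, df in mod_annots.items():
        keep = [c for c in df.columns if wanted is None or c in wanted]
        pieces.append(df[keep].add_prefix(f"{name}:"))
    # Outer-join on the index: rows missing from a modality become NaN
    return pd.concat(pieces, axis=1, join="outer")


# Usage: only "n_genes" and "n_peaks" are promoted; "qc_pass" stays modality-specific
rna = pd.DataFrame({"n_genes": [120, 95], "qc_pass": [True, False]}, index=["cell1", "cell2"])
atac = pd.DataFrame({"n_peaks": [3000]}, index=["cell2"])
shared = pull_to_shared({"rna": rna, "atac": atac}, columns=["n_genes", "n_peaks"])
print(shared)
```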

It would also be good to decide whether these methods return copies or views.

Use cases

The main use case here is easing the transition for libraries that only work with one of the types. I think it would also make it much easier to play around with which representation fits a particular use case better.

It also does a good job of advertising the connectedness of the objects, but that's very secondary.

gtca commented 12 months ago

The API should come in v0.3. As the split_by PR hasn't been merged, the to_mudata functionality might be a bit basic, but it should cover the main use case.

gtca commented 2 months ago

This now works:

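import jax
import numpy as np
from anndata import AnnData
from mudata import to_mudata

# 100 observations × 300 variables of random data
x = np.array(jax.random.normal(jax.random.PRNGKey(1), (100, 300)))
adata = AnnData(x)
# Assign each variable at random to group 'a' or 'b'
adata.var["split"] = np.array(jax.random.binomial(jax.random.PRNGKey(1), 1, .5, (adata.n_vars,)))
adata.var["split"] = adata.var["split"].map({0.0: 'a', 1.0: 'b'})

# axis=0: the modalities share observations and split the variables by var["split"]
mdata = to_mudata(adata, axis=0, by='split')
mdata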
# MuData object with n_obs × n_vars = 100 × 300
#   var:    'split'
#   2 modalities
#     a:    100 x 139
#       var:    'split'
#     b:    100 x 161
#       var:    'split'

mdata.to_anndata()
# AnnData object with n_obs × n_vars = 100 × 300
#     var: 'split'
#     obsm: 'a', 'b'
#     varm: 'a', 'b'