scverse / spatialdata

An open and interoperable data framework for spatial omics data
https://spatialdata.scverse.org/
BSD 3-Clause "New" or "Revised" License

Subset spatialdata by list of cell ids #556

Open RAC1924 opened 2 months ago

RAC1924 commented 2 months ago

Hi,

Thank you for the spatialdata package. I'm wondering whether it's possible to subset our SpatialData object directly by a list of cell IDs. I couldn't find this in the documentation.

BW, Ros

aeisenbarth commented 1 month ago

The method SpatialData.subset only accepts element names.

However, you can further filter the returned object. (Note that this code still handles only a single table.)

from collections.abc import Iterable
from typing import Optional

import numpy as np


def filter_spatialdata(
    sdata: "SpatialData",
    elements: Optional[Iterable[str]] = None,
    regions: Optional[Iterable[str]] = None,
    obs_keys: Optional[Iterable[str]] = None,
    var_keys: Optional[Iterable[str]] = None,
    var_names: Optional[Iterable[str]] = None,
    layers: Optional[Iterable[str]] = None,
    region_key: str = "region",
    instance_key: str = "instance_id",
) -> "SpatialData":
    """
    Filter a SpatialData object to contain only specified elements or table entries.

    Args:
        sdata: A SpatialData object
        elements: Names of elements to include. Defaults to [].
        regions: Regions to include in the table. Defaults to regions of all selected elements.
        obs_keys: Names of obs columns to include. Defaults to [].
        var_keys: Names of var columns to include. Defaults to [].
        var_names: Names of variables (X columns) to include. Defaults to [].
        layers: Names of X layers to include. Defaults to [].
        region_key: Name of the obs column that links table rows to elements.
        instance_key: Name of the obs column holding the instance (cell) ids.

    Returns:
        A new SpatialData instance
    """
    from anndata import AnnData
    from spatialdata import SpatialData
    from spatialdata.models import TableModel

    elements = [] if elements is None else list(elements)

    sdata_subset = (
        sdata.subset(element_names=elements, filter_tables=True) if elements else SpatialData()
    )
    # Ensure the returned SpatialData is not backed to the original reference dataset,
    # so that it can be safely modified.
    assert not sdata_subset.is_backed()
    # Further filtering on the table
    if (table := sdata_subset.tables.get("table")) is not None:
        regions = elements if regions is None else list(regions)
        obs_keys = [] if obs_keys is None else list(obs_keys)
        # Always keep the columns that link table rows to elements.
        if instance_key not in obs_keys:
            obs_keys.insert(0, instance_key)
        if region_key not in obs_keys:
            obs_keys.insert(0, region_key)
        var_keys = [] if var_keys is None else list(var_keys)
        var_names = [] if var_names is None else list(var_names)
        # Default to no layers (previously None, which made "key in layers" fail).
        layers = [] if layers is None else list(layers)
        # Preserve order by checking "isin" instead of slicing. Also guarantees no duplicates.
        table_subset = table[
            table.obs[region_key].isin(regions).to_numpy(),
            table.var_names.isin(var_names),
        ]
        layers_subset = (
            {key: layer for key, layer in table_subset.layers.items() if key in layers}
            if table_subset.layers is not None and len(var_names) > 0
            else None
        )
        table_subset = TableModel.parse(
            AnnData(
                X=table_subset.X if len(var_names) > 0 else None,
                obs=table_subset.obs.loc[:, table_subset.obs.columns.isin(obs_keys)],
                var=table_subset.var.loc[:, table_subset.var.columns.isin(var_keys)],
                layers=layers_subset,
            ),
            region_key=region_key,
            instance_key=instance_key,
            region=np.unique(table_subset.obs[region_key]).tolist(),
        )
        # Replace the original table with the filtered one.
        del sdata_subset.tables["table"]
        sdata_subset.tables["table"] = table_subset
    return sdata_subset
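The comment about using "isin" instead of slicing can be illustrated with plain pandas. This is a minimal sketch on a stand-in DataFrame playing the role of AnnData.var (the gene names are made up): label slicing follows the requested order and repeats duplicates, while a boolean isin mask keeps the table's own order and returns each row at most once.

```python
import pandas as pd

# Stand-in for AnnData.var: variable names as the index (hypothetical genes).
var = pd.DataFrame({"mean": [0.1, 0.5, 0.9]}, index=["GeneA", "GeneB", "GeneC"])

# Requested names in a different order, with a duplicate.
requested = ["GeneC", "GeneA", "GeneA"]

# Label-based slicing follows the requested order and repeats duplicates.
sliced = var.loc[requested]

# A boolean "isin" mask keeps the table's own order and each row at most once.
masked = var[var.index.isin(requested)]

print(list(sliced.index))  # ['GeneC', 'GeneA', 'GeneA']
print(list(masked.index))  # ['GeneA', 'GeneC']
```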

The table is a normal AnnData object. Note that AnnData uses obs_names as a string index, which does not necessarily match the label values of a labels image or the indices of shapes/points. So you need to filter by the instance column (and, if needed, the region column):

>>> instance_key = "instance_id"
>>> cell_ids = [1, 2, 4]
>>> table = sdata.tables["table"]
>>> table[table.obs[instance_key].isin(cell_ids), :]
View of AnnData object with n_obs × n_vars = 3 × 0
    obs: 'region', 'instance_id', 'some_other_columns'
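To see why obs_names and instance ids can disagree, here is a minimal pandas sketch. The DataFrame stands in for table.obs, and the region names "cells_a" and "cells_b" are hypothetical. On a real table, the same boolean mask could be passed as table[mask.to_numpy(), :].

```python
import pandas as pd

region_key = "region"
instance_key = "instance_id"

# Stand-in for table.obs: string obs_names that do NOT equal the instance ids.
obs = pd.DataFrame(
    {
        region_key: ["cells_a", "cells_a", "cells_a", "cells_b", "cells_b"],
        instance_key: [1, 2, 4, 1, 7],
    },
    index=["0", "1", "2", "3", "4"],  # AnnData-style string index
)

cell_ids = [1, 2, 4]

# Wrong: looks up rows by the string index, not by the label/shape ids.
by_obs_names = obs.loc[[str(i) for i in cell_ids]]

# Right: match on the instance column, restricted to one region, because
# the same instance id (here 1) can occur in several regions.
mask = obs[instance_key].isin(cell_ids) & (obs[region_key] == "cells_a")
by_instance = obs[mask]

print(by_obs_names[instance_key].tolist())  # [2, 4, 7]
print(by_instance[instance_key].tolist())  # [1, 2, 4]
```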

How you filter the elements themselves depends on which type of element your cell IDs refer to.
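A rough sketch of the three common cases, using numpy and pandas stand-ins rather than actual SpatialData elements. In spatialdata, labels are xarray-backed, shapes are GeoDataFrames and points are dask DataFrames, so the real calls may need small adaptations. The sketch assumes the cell ids are the pixel values of the labels image, the index of the shapes, and a column of the points; the column name "cell_id" is hypothetical.

```python
import numpy as np
import pandas as pd

cell_ids = [1, 2, 4]

# Labels element: an integer image where each pixel holds a cell id (0 = background).
labels = np.array(
    [
        [0, 1, 1, 3],
        [2, 2, 3, 3],
        [4, 4, 5, 0],
    ]
)
# Keep the requested ids and set every other pixel to background.
filtered_labels = np.where(np.isin(labels, cell_ids), labels, 0)

# Shapes element: a (Geo)DataFrame whose index holds the instance ids.
shapes = pd.DataFrame({"radius": [5.0, 6.0, 4.5, 7.0]}, index=[1, 2, 3, 4])
filtered_shapes = shapes[shapes.index.isin(cell_ids)]

# Points element: transcripts carrying the id of the cell they belong to
# in a column (the column name "cell_id" is an assumption).
points = pd.DataFrame({"x": [0.5, 1.5, 2.5], "y": [1.0, 2.0, 3.0], "cell_id": [1, 3, 4]})
filtered_points = points[points["cell_id"].isin(cell_ids)]
```

For labels, zeroing out unselected pixels keeps the image shape and coordinate system unchanged, which is usually what downstream plotting expects.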

melonora commented 1 week ago

Hi @RAC1924, is this issue still relevant, or is the information given by @aeisenbarth sufficient?

melonora commented 20 hours ago

Closing this for now, but feel free to reopen.

LucaMarconato commented 4 minutes ago

@melonora I'll reopen this, because now that the join operations are available we could provide a more general implementation that also supports multiple tables. Multiple users have asked for this API, so I'll add this issue to my TODO list.
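As a conceptual illustration of a join-based approach, the sketch below pairs element instances with table rows using a plain pandas inner merge on the instance id, then keeps the requested cells. It does not use spatialdata's own join API; the region name "cells" and the cell_type column are made up.

```python
import pandas as pd

region_key = "region"
instance_key = "instance_id"
cell_ids = [1, 4]

# Stand-in for a shapes element: one row per instance, id in the index.
shapes = pd.DataFrame(
    {"radius": [5.0, 6.0, 4.5, 7.0]},
    index=pd.Index([1, 2, 3, 4], name=instance_key),
)

# Stand-in for table.obs annotating two regions; the same id can recur.
obs = pd.DataFrame(
    {
        region_key: ["cells", "cells", "cells", "other"],
        instance_key: [4, 1, 2, 1],
        "cell_type": ["B", "T", "T", "NK"],
    }
)

# Keep only the table rows that annotate this element (region "cells").
obs_region = obs[obs[region_key] == "cells"]

# Inner join on the instance id: instances without a table row (id 3) drop out.
joined = shapes.reset_index().merge(obs_region, on=instance_key, how="inner")

# Select the requested cells and sort by id for a deterministic order.
subset = joined[joined[instance_key].isin(cell_ids)].sort_values(
    instance_key, ignore_index=True
)

print(subset[[instance_key, "radius", "cell_type"]])
```

Filtering on the region first matters: without it, instance 1 would also match the row from region "other".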