scverse / scanpy

Single-cell analysis in Python. Scales to >1M cells.
https://scanpy.readthedocs.io
BSD 3-Clause "New" or "Revised" License

Subsample by observations grouping #987

Open chsher opened 4 years ago

chsher commented 4 years ago

Related to scanpy.pp.subsample, it would be useful to have a subsampling tool that subsamples based on the key of an observations grouping. E.g., if I have an observation key 'MyGroup' with possible values ['A', 'B'], and there are 10,000 cells of type 'A' and 2,000 cells of type 'B', and I want at most 5,000 cells of each type, then this function would subsample 5,000 cells of type 'A' but retain all 2,000 cells of type 'B'.

LuckyMD commented 4 years ago

Something like this should work. Note that it is not tested.

import scanpy as sc

target_cells = 5000

# split into one AnnData per category; note isin() expects a list, not a bare string
adatas = [adata[adata.obs[cluster_key].isin([clust])] for clust in adata.obs[cluster_key].cat.categories]

# downsample each group that exceeds the cap, in place
for dat in adatas:
    if dat.n_obs > target_cells:
        sc.pp.subsample(dat, n_obs=target_cells)

# stitch the groups back together
adata_downsampled = adatas[0].concatenate(*adatas[1:])

Hope that helps.

chsher commented 3 years ago

Thank you @LuckyMD, it worked!

giovp commented 2 years ago

I'll reopen this because I think it's still quite relevant, and it could be very straightforward to implement with sklearn's resample.

Also, there is an entire package of subsampling strategies that is probably quite relevant: https://github.com/scikit-learn-contrib/imbalanced-learn

Line here for reference: https://github.com/theislab/scanpy/blob/48cc7b38f1f31a78902a892041902cc810ddfcd3/scanpy/preprocessing/_simple.py#L857
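
For reference, a rough, untested sketch of how the per-group capping could look with sklearn.utils.resample (the function name resample_by_obs and its arguments are placeholders, not existing scanpy API):

from sklearn.utils import resample

def resample_by_obs(adata, cluster_key, target_cells, random_state=0):
    # collect positional indices per group, capping each group at target_cells
    keep = []
    for group, idx in adata.obs.groupby(cluster_key).indices.items():
        if len(idx) > target_cells:
            idx = resample(idx, replace=False, n_samples=target_cells, random_state=random_state)
        keep.extend(idx)
    # sort so the original cell order is preserved
    return adata[sorted(keep)].copy()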

giovp commented 2 years ago

Back here reminding myself that this would be a very useful feature to have...

ivirshup commented 2 years ago

@bio-la also expressed some interest here on MM

@giovp, did you have a particular strategy in mind for resampling?

giovp commented 2 years ago

So, assuming we are only interested in downsampling, I'd say NearMiss and related methods are straightforward and scalable (just need to compute a k-means, which is really fast).
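
For example, a rough, untested sketch with imbalanced-learn (it assumes a PCA embedding already exists in adata.obsm['X_pca'] and that cluster_key points to the class labels):

import numpy as np
from imblearn.under_sampling import NearMiss

# undersample majority classes based on nearest-neighbour distances in PCA space
X = adata.obsm['X_pca']
y = adata.obs[cluster_key].to_numpy()

nm = NearMiss(version=1)
nm.fit_resample(X, y)

# sample_indices_ holds the positions of the kept observations
adata_downsampled = adata[np.sort(nm.sample_indices_)].copy()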

giovp commented 2 years ago

Also, the fact that reshuffling is performed is not in the docs and should be documented. @bio-la do you plan to work on this?

ivirshup commented 2 years ago

> I'd say NearMiss and related methods are straightforward and scalable (just need to compute a k-means, which is really fast)

For sampling from datasets, I would want to go with either something extremely straightforward or something that has been shown to work. Maybe we could start with using provided labels to downsample by?

> reshuffling is performed

Reshuffling meaning that the order is changed?

ivirshup commented 2 years ago

Linking some previous discussion:

chansigit commented 2 years ago
> clust

In scanpy 1.8, this works:

target_cells = 3000

adatas = [adata_train[adata_train.obs[cluster_key].isin([clust])] for clust in adata_train.obs[cluster_key].cat.categories]

for dat in adatas:
    if dat.n_obs > target_cells:
        sc.pp.subsample(dat, n_obs=target_cells, random_state=0)

adata_train_downsampled1 = adatas[0].concatenate(*adatas[1:])

stefanpeidli commented 1 year ago

This function at least subsamples all classes in an obs column to the same number of cells. It would be straightforward to modify it into what you probably have in mind.

import numpy as np

def obs_key_wise_subsampling(adata, obs_key, N):
    '''
    Subsample each class to the same number of cells (N). Classes are given by obs_key pointing to a categorical in adata.obs.
    '''
    counts = adata.obs[obs_key].value_counts()
    # subsample indices per group defined by obs_key
    indices = [np.random.choice(adata.obs_names[adata.obs[obs_key] == group], size=N, replace=False) for group in counts.index]
    selection = np.hstack(indices)
    return adata[selection].copy()

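A call could then look like this ('cell_type' and 500 are just example values):

adata_balanced = obs_key_wise_subsampling(adata, 'cell_type', 500)
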
royfrancis commented 1 year ago

@stefanpeidli's code gives this error:

ValueError: Cannot take a larger sample than population when 'replace=False'

If a group has fewer observations than the requested number, it shouldn't be subsampled:

target_cells = 1000
cluster_key = "cell_type"

# downsample each cluster independently; groups at or below the cap are kept whole
grouped = adata.obs.groupby(cluster_key)
downsampled_indices = []

for _, group in grouped:
    if len(group) > target_cells:
        downsampled_indices.extend(group.sample(target_cells).index)
    else:
        downsampled_indices.extend(group.index)

adata_downsampled = adata[downsampled_indices]