BayraktarLab / cell2location

Comprehensive mapping of tissue cell architecture via integrated single cell and spatial transcriptomics (cell2location model)
https://cell2location.readthedocs.io/en/latest/
Apache License 2.0
304 stars, 56 forks

Cell2Location on large datasets #356

LinearParadox opened this issue 5 months ago (status: open)

LinearParadox commented 5 months ago

Is there any way to manage memory usage on large datasets? For example, when you approach ~40000 spots and ~10000 genes, memory use becomes huge. Is there a way to train separate conditions and then merge the results? I'm not sure whether this would be a simple merge of AnnData objects or a more involved process.
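A back-of-the-envelope sketch of why memory grows so quickly: one dense locations x genes array of float32 values (4 bytes each) already needs n_locations x n_genes x 4 bytes. This is only a per-array lower bound; full-data training can hold several location x gene tensors at once (for example the observed counts, the model's expected counts, and what autograd keeps for the backward pass), so peak memory is typically several times larger. The helper name `dense_matrix_gb` is hypothetical.

```python
def dense_matrix_gb(n_locations: int, n_genes: int, bytes_per_value: int = 4) -> float:
    """Size in GB (10**9 bytes) of one dense locations x genes array (float32 by default)."""
    return n_locations * n_genes * bytes_per_value / 1e9


# dataset size from the question above
print(dense_matrix_gb(40_000, 10_000))   # 1.6
# for comparison, a VisiumHD-scale dataset with 650k bins and the same gene count
print(dense_matrix_gb(650_000, 10_000))  # 26.0
```

Because the cost scales linearly with the number of locations, splitting locations into chunks reduces the per-model footprint roughly in proportion to the chunk size.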

vitkl commented 5 months ago

This is indeed a common problem. An approach we started using recently is splitting the data randomly (stratified by batch and possibly by anatomy annotation), training several models, then merging results.

LinearParadox commented 5 months ago

Is there any specific function you use to merge the data? Or is merging the anndata objects good enough? I'm not sure how I would merge several models, especially for some of the downstream analysis functions.

vitkl commented 5 months ago

You can concatenate the AnnData objects. There are two important parts to consider: the cell abundance estimates in adata.obsm, and all other model parameters (plus an essential record of the input data) in adata.uns['mod']. Merging adata.uns['mod'] correctly is going to be more complex, but you can save each chunk's record in a separate adata.uns slot.
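A minimal sketch of the obsm side of this merge, assuming each chunk's cell abundance estimates are available as a pandas DataFrame (locations x cell types) indexed by location barcode. The helper names, cell type names, and toy records are hypothetical; with real AnnData objects, anndata's own concatenation would play the role of `pd.concat` here.

```python
import numpy as np
import pandas as pd


def merge_chunk_abundances(chunk_results, full_index):
    """Combine per-chunk cell abundance tables (locations x cell types)
    and restore the location order of the full dataset."""
    merged = pd.concat(list(chunk_results.values()), axis=0)
    if merged.index.duplicated().any():
        raise ValueError("a location appears in more than one chunk")
    # check coverage explicitly so the error message is clear
    missing = full_index.difference(merged.index)
    if len(missing) > 0:
        raise ValueError(f"{len(missing)} locations have no estimate")
    return merged.loc[full_index]


def collect_model_records(chunk_records):
    """Keep each chunk's model record under its own key, e.g. 'mod_0', 'mod_1'."""
    return {f"mod_{batch}": record for batch, record in chunk_records.items()}


# toy example: 6 locations split into 2 chunks, 2 hypothetical cell types
full_index = pd.Index([f"loc{i}" for i in range(6)])
cell_types = ["T_cell", "B_cell"]
chunk_results = {
    0: pd.DataFrame(np.ones((3, 2)), index=full_index[[0, 2, 4]], columns=cell_types),
    1: pd.DataFrame(np.full((3, 2), 2.0), index=full_index[[1, 3, 5]], columns=cell_types),
}
merged = merge_chunk_abundances(chunk_results, full_index)
print(merged["T_cell"].tolist())  # [1.0, 2.0, 1.0, 2.0, 1.0, 2.0]
uns = collect_model_records({0: {"n_obs": 3}, 1: {"n_obs": 3}})
print(sorted(uns))  # ['mod_0', 'mod_1']
```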

As a technical note, merging adata.uns['mod'] would require some way to properly combine location-independent parameters, which is straightforward for technical gene-specific effects but can be less straightforward for the prior factorisation of cell abundance.
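For the straightforward case, one simple option (an assumption, not something cell2location provides) is a location-weighted average of a per-gene technical parameter across chunk models, weighting each chunk's estimate by the number of locations it was trained on. The function name and toy values are hypothetical; the sketch ignores posterior uncertainty and does not address the prior factorisation of cell abundance.

```python
import numpy as np


def combine_gene_effects(per_chunk_effects, n_locations_per_chunk):
    """Location-weighted mean of a per-gene parameter (shape: n_genes) across chunk models."""
    effects = np.stack(per_chunk_effects)                       # (n_chunks, n_genes)
    weights = np.asarray(n_locations_per_chunk, dtype=float)
    weights = weights / weights.sum()                           # normalise to sum to 1
    return weights @ effects                                    # (n_genes,)


# two chunks with 100 and 300 locations, two genes
combined = combine_gene_effects([np.array([1.0, 2.0]), np.array([3.0, 4.0])], [100, 300])
print(combined)  # [2.5 3.5]
```

For strictly positive multiplicative effects, averaging on the log scale (a weighted geometric mean) may be a more natural choice.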

Rafael-Silva-Oliveira commented 4 months ago

> This is indeed a common problem. An approach we started using recently is splitting the data randomly (stratified by batch and possibly by anatomy annotation), training several models, then merging results.

Are there any improvements coming soon that would let cell2location take in the full data? Splitting into batches discards spatial information that might otherwise be used for cell deconvolution.

With the new VisiumHD becoming more prominent, the number of spots or bins can range anywhere from 150k to 650k+, so when I test cell2location, I always get memory errors.

Training in batches also doesn't work (the original comment included a screenshot here, which is not preserved).

Are there any plans in the near future to support datasets such as VisiumHD?

Li-ZhiD commented 3 months ago

Hi, I am using cell2location to deconvolve large Stereo-seq data. (1) Could you show how to merge the trained models after splitting the spatial data? (2) How much does reducing max_epochs from 30000 (which takes a long time in CPU mode) to 300 affect the results when training on spatial data? Thank you!

vitkl commented 1 month ago

@Li-ZhiD @Dillon214

Split the object into batches

import numpy as np

# for every "sample", draw chunk labels with replacement so that
# some locations from each sample are allocated to every training batch
chunk_size = 72_000
n_chunks = int(np.ceil(adata_vis.n_obs / chunk_size))
chunks = list(range(n_chunks))

adata_vis.obs['training_batch'] = 0
for sample in adata_vis.obs['sample'].unique():
    ind = adata_vis.obs['sample'].isin([sample])
    adata_vis.obs.loc[ind, 'training_batch'] = np.random.choice(
        chunks, size=ind.sum(), replace=True
    )

# keep a copy of the full object and pre-allocate (locations x cell types) arrays
# for the summaries exported from each chunk; the statistics listed here must
# match `add_to_obsm` in the export_posterior call below (q05 and q50)
adata_vis_full = adata_vis.copy()
for k in ['q05', 'q50']:
    adata_vis_full.obsm[f"{k}_cell_abundance_w_sf"] = np.zeros((adata_vis_full.n_obs, inf_aver.shape[1]))

# check the number of locations in each training batch
adata_vis.obs['training_batch'].value_counts()
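Sampling chunk labels with replacement, as above, gives chunk sizes that vary randomly. If roughly equal shares of every sample per chunk are preferred, one alternative (a swapped-in technique, not the approach used in this thread) is to shuffle each sample's locations and deal them round-robin across chunks. The function name `allocate_balanced` is hypothetical.

```python
import numpy as np


def allocate_balanced(sample_labels, n_chunks, seed=0):
    """Assign each location to a training chunk so that every sample
    is spread as evenly as possible across all chunks."""
    rng = np.random.default_rng(seed)
    sample_labels = np.asarray(sample_labels)
    chunk_of = np.empty(len(sample_labels), dtype=int)
    for sample in np.unique(sample_labels):
        idx = np.flatnonzero(sample_labels == sample)   # positions of this sample
        rng.shuffle(idx)                                # random order within the sample
        chunk_of[idx] = np.arange(len(idx)) % n_chunks  # deal round-robin
    return chunk_of


labels = np.array(["A"] * 7 + ["B"] * 5)
chunk_of = allocate_balanced(labels, n_chunks=2)
for s in ["A", "B"]:
    print(s, np.bincount(chunk_of[labels == s], minlength=2))
# A [4 3]
# B [3 2]
```

With the objects above this could be used as `adata_vis.obs['training_batch'] = allocate_balanced(adata_vis.obs['sample'].values, len(chunks))`. To stratify by anatomy as well, pass combined labels, for example the sample name joined with an anatomy annotation column (whatever that column is called in your data).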

Train

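import numpy as np
import pyro
import scvi
import cell2location

seed = 0
scvi.settings.seed = seed
np.random.seed(seed)

# each iteration is independent, so every chunk can also be submitted as a separate job
# (`run_name_global` and `inf_aver` come from the standard cell2location workflow)
for batch in sorted(adata_vis_full.obs['training_batch'].unique()):
    scvi_run_name = f'{run_name_global}_batch{batch}_seed{seed}'
    print(scvi_run_name)

    # subset the full object to the locations allocated to this chunk
    training_batch_index = adata_vis_full.obs['training_batch'].isin([batch])
    adata_vis = adata_vis_full[training_batch_index, :].copy()

    # prepare anndata for the cell2location model
    cell2location.models.Cell2location.setup_anndata(
        adata=adata_vis, batch_key="sample"
    )

    # precaution: clear Pyro's global parameter store so nothing from
    # the previous chunk's model is carried over
    pyro.clear_param_store()

    # create and train the model as normal, e.g. with the tutorial settings
    # (N_cells_per_location and detection_alpha are tissue-dependent)
    mod = cell2location.models.Cell2location(
        adata_vis, cell_state_df=inf_aver,
        N_cells_per_location=30, detection_alpha=20,
    )
    mod.train(max_epochs=30000, batch_size=None, train_size=1)

    # In this section, we export the estimated cell abundance (summary of the posterior distribution).
    adata_vis = mod.export_posterior(
        adata_vis, sample_kwargs={
            'batch_size': int(np.ceil(adata_vis.n_obs / 4)), 'accelerator': 'gpu',
            "return_observed": False,
        },
        add_to_obsm=['q05', 'q50'],
        use_quantiles=True,
    )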

Complete the full object with cell abundance estimates from batched analysis

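    # copy cell2location results to the main object
    mask = np.asarray(training_batch_index)
    for k in adata_vis.obsm.keys():
        if not k.endswith('_cell_abundance_w_sf'):
            continue  # skip unrelated entries such as 'spatial'
        if k not in adata_vis_full.obsm:
            adata_vis_full.obsm[k] = np.zeros((adata_vis_full.n_obs, adata_vis.obsm[k].shape[1]))
        adata_vis_full.obsm[k][mask, :] = np.asarray(adata_vis.obsm[k])
    # keep each chunk's model record in its own uns slot
    adata_vis_full.uns[f'mod_{batch}'] = adata_vis.uns['mod'].copy()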
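The copy-back step relies on a chunk's rows appearing in the same order as the True positions of its boolean mask, which holds because subsetting AnnData with a boolean mask preserves the original order. A minimal numpy sketch of the same idea, with a hypothetical helper name and toy data, that also checks every location ends up with an estimate:

```python
import numpy as np


def fill_from_chunks(n_obs, n_cell_types, chunk_ids, chunk_estimates):
    """Write each chunk's (locations x cell types) estimates back into a full array."""
    full = np.full((n_obs, n_cell_types), np.nan)  # NaN marks locations not yet filled
    for batch, estimates in chunk_estimates.items():
        mask = chunk_ids == batch                  # same boolean mask used to subset
        if mask.sum() != estimates.shape[0]:
            raise ValueError(f"chunk {batch}: {mask.sum()} locations vs {estimates.shape[0]} rows")
        full[mask, :] = estimates
    if np.isnan(full).any():
        raise ValueError("some locations were not covered by any chunk")
    return full


chunk_ids = np.array([0, 1, 0, 1, 1])
estimates = {
    0: np.array([[1.0, 0.0], [2.0, 0.0]]),
    1: np.array([[0.0, 1.0], [0.0, 2.0], [0.0, 3.0]]),
}
full = fill_from_chunks(5, 2, chunk_ids, estimates)
print(full[:, 0])  # [1. 0. 2. 0. 0.]
```

Initialising with NaN rather than zeros makes uncovered locations detectable, since zero is a valid cell abundance estimate.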