BayraktarLab / cell2location

Comprehensive mapping of tissue cell architecture via integrated single cell and spatial transcriptomics (cell2location model)
https://cell2location.readthedocs.io/en/latest/
Apache License 2.0

High resolution model & distance function prior [draft] #337

Open vitkl opened 1 year ago

vitkl commented 1 year ago

Depends on https://github.com/scverse/scvi-tools/pull/2695

vitkl commented 1 month ago

To use total cell abundance estimates derived from histology images, the following changes are necessary (use_proportion_factorisation_prior_on_w_sf = True):

  1. Changing the parameterisation of the factorisation prior so that it produces the % of total cell abundance.
  2. Forcing the model to match the provided total cell abundance estimates by using them as a prior with a very narrow distribution around the provided values (N_cells_per_location_alpha_prior=1000.0, use_n_s_cells_per_location_limit = True).
  3. Changing detection_alpha back to a narrow distribution (detection_alpha=200.0).
  4. Changing other priors.
  5. Code modifications to support N_cells_per_location of shape=(n_obs, 1).
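As an illustration of the per-location input that point 5 refers to, here is a minimal sketch of how an array of shape (n_obs, 1) could be built from an occupancy fraction scaled by a high quantile of cell counts. The function name `n_cells_from_occupancy` and its arguments are hypothetical and not part of cell2location; the toy inputs are made up for the example.

```python
import numpy as np


def n_cells_from_occupancy(occupancy_fraction, n_cells_counts, quantile=0.9999):
    """Hypothetical helper: scale per-location occupancy (0..1) by a high
    quantile of the observed cell counts and return shape (n_obs, 1)."""
    occupancy = np.asarray(occupancy_fraction, dtype="float32").reshape(-1, 1)
    if np.any((occupancy < 0) | (occupancy > 1)):
        raise ValueError("occupancy_fraction must lie in [0, 1]")
    # upper bound on the number of cells per location across the data
    upper = np.quantile(np.asarray(n_cells_counts, dtype="float64"), quantile)
    return (occupancy * upper).astype("float32")


# Toy data (assumed values, for illustration only)
occupancy = np.array([0.0, 0.25, 0.5, 1.0])
counts = np.array([0, 2, 4, 8])
# quantile=1.0 is used here only so the toy result is easy to check by hand
N_cells_per_location = n_cells_from_occupancy(occupancy, counts, quantile=1.0)
print(N_cells_per_location.shape)
```

With real data the two inputs would presumably come from columns of adata_vis.obs produced by image segmentation; which columns exist depends on your own preprocessing.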

This branch can be installed as follows (I have not tested this particular recipe so please let me know if it doesn't work):

export PYTHONNOUSERSITE="True"
conda create -y -n c2l_v015 python=3.10
conda activate c2l_v015
conda install -y -c anaconda hdf5 pytables git
pip install git+https://github.com/vitkl/scvi-tools.git@pyro_fixes
# quoted so that shells such as zsh do not treat [tutorials] as a glob pattern
pip install "git+https://github.com/BayraktarLab/cell2location.git@hires_sliding_window[tutorials]"
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
conda activate c2l_v015
python -m ipykernel install --user --name=c2l_v015 --display-name='Environment (c2l_v015)'

Temporary usage instructions. These will eventually be replaced by mode='exact_total_cell_abundance', which switches all of these options on:

detection_alpha = 200.0
N_cells_per_location_alpha_prior = 1000.0
use_per_cell_type_normalisation = False
# ideally this is not the count of cells
# but (% of spot occupied by cells) * (0.9999 quantile of N cells across the data)
N_cells_per_location = adata_vis.obs[['n_cell_occupancy']].values.astype('float32')

A_B_per_location_alpha_prior = None
A_factors_per_location = 40.0
B_groups_per_location = 5.0

use_proportion_factorisation_prior_on_w_sf = True
use_n_s_cells_per_location_limit = True

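import numpy as np
import scvi
import torch
import cell2location

# allow lower-precision (e.g. TF32) float32 matmuls where supported
torch.set_float32_matmul_precision('high')

# fix random seeds for reproducibility
seed = 0
scvi.settings.seed = seed
np.random.seed(seed)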

# `adata_vis`, `inf_aver`, `training` (bool) and `scvi_run_name` (path for the
# saved model) are assumed to be defined earlier

# prepare anndata for the cell2location model
cell2location.models.Cell2location.setup_anndata(
    adata=adata_vis, batch_key="sample"
)

if training:
    import pyro

    mod = cell2location.models.Cell2location(
        adata_vis, cell_state_df=inf_aver,
        amortised=False,
        N_cells_per_location=N_cells_per_location,  # np.array shape (n_obs, 1)
        detection_alpha=detection_alpha,
        use_per_cell_type_normalisation=use_per_cell_type_normalisation,
        N_cells_per_location_alpha_prior=N_cells_per_location_alpha_prior,
        N_cells_mean_var_ratio=None,
        detection_hyp_prior={"mean_alpha": float(1.0)},
        detection_cell_type_prior_alpha=float(100.0),
        A_B_per_location_alpha_prior=A_B_per_location_alpha_prior,
        A_factors_per_location=A_factors_per_location,
        B_groups_per_location=B_groups_per_location,
        use_proportion_factorisation_prior_on_w_sf=use_proportion_factorisation_prior_on_w_sf,
        use_n_s_cells_per_location_limit=use_n_s_cells_per_location_limit,
        n_groups=50,
    )

    mod.view_anndata_setup()

    mod.train(max_epochs=80000,
              # train using full data (batch_size=None)
              batch_size=None,
              plan_kwargs={'optim': pyro.optim.Adam(optim_args={'lr': 0.002})},
              # use all data points in training because
              # we need to estimate cell abundance at all locations
              train_size=1,
              scale_elbo=1 / (adata_vis.n_obs * adata_vis.n_vars),
              accelerator='gpu')

    # Save model
    mod.save(f"{scvi_run_name}", overwrite=True)
else:
    # can be loaded later like this:
    mod = cell2location.models.Cell2location.load(f"{scvi_run_name}", adata_vis)

Note that this N_cells_per_location code doesn't support amortised=True.
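Until the planned mode option exists, the settings listed above could be collected in one place. The sketch below is a hypothetical helper, not part of cell2location; the name `exact_total_cell_abundance_kwargs` is made up. It only bundles the values quoted in this comment and checks the two constraints mentioned: the (n_obs, 1) shape and amortised=False.

```python
import numpy as np


def exact_total_cell_abundance_kwargs(n_cells_per_location, n_obs, amortised=False):
    """Hypothetical helper: bundle the options quoted in this thread
    into a single dict of keyword arguments."""
    if amortised:
        raise ValueError("per-location N_cells_per_location does not support amortised=True")
    n_cells = np.asarray(n_cells_per_location, dtype="float32")
    if n_cells.shape != (n_obs, 1):
        raise ValueError(f"expected shape {(n_obs, 1)}, got {n_cells.shape}")
    return {
        "amortised": False,
        "N_cells_per_location": n_cells,
        "detection_alpha": 200.0,
        "N_cells_per_location_alpha_prior": 1000.0,
        "use_per_cell_type_normalisation": False,
        "N_cells_mean_var_ratio": None,
        "A_B_per_location_alpha_prior": None,
        "A_factors_per_location": 40.0,
        "B_groups_per_location": 5.0,
        "use_proportion_factorisation_prior_on_w_sf": True,
        "use_n_s_cells_per_location_limit": True,
    }


# Toy input with 3 locations (assumed values, for illustration only)
kwargs = exact_total_cell_abundance_kwargs(np.ones((3, 1)), n_obs=3)
print(sorted(kwargs))
```

On the branch described here, the dict would presumably be passed as cell2location.models.Cell2location(adata_vis, cell_state_df=inf_aver, **kwargs); these keyword arguments are taken from this comment and may change before the branch is merged.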