vitkl opened 1 year ago
To enable using total cell abundance estimates from histology images, the following changes are necessary (`use_proportion_factorisation_prior_on_w_sf = True`):

- `N_cells_per_location_alpha_prior=1000.0`, `use_n_s_cells_per_location_limit = True`
- `detection_alpha=200.0`, back to a narrow distribution
- `N_cells_per_location` of shape `(n_obs, 1)`

This branch can be installed as follows (I have not tested this particular recipe, so please let me know if it doesn't work):
```bash
export PYTHONNOUSERSITE="True"
conda create -y -n c2l_v015 python=3.10
conda activate c2l_v015
conda install -y -c anaconda hdf5 pytables git
pip install git+https://github.com/vitkl/scvi-tools.git@pyro_fixes
# extras on a git URL need the PEP 508 "name[extra] @ url" form (quoted for the shell)
pip install "cell2location[tutorials] @ git+https://github.com/BayraktarLab/cell2location.git@hires_sliding_window"
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
conda activate c2l_v015
python -m ipykernel install --user --name=c2l_v015 --display-name='Environment (c2l_v015)'
```
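After running the recipe, one quick way to confirm that the environment resolves the packages it relies on is to look them up without importing them. This is a minimal sketch using only the standard library; the package list (`cell2location`, `scvi` from scvi-tools, `pyro`, `torch`, `ipykernel`) is an assumption about what the recipe installs, and `missing_packages` is a hypothetical helper.

```python
import importlib.util

# Assumed set of top-level modules the recipe above should provide.
PACKAGES = ("cell2location", "scvi", "pyro", "torch", "ipykernel")


def missing_packages(names=PACKAGES):
    """Hypothetical helper: return the names that cannot be found on sys.path.

    Uses importlib.util.find_spec, so the packages are located but not imported.
    """
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages()
    print("missing packages:", ", ".join(missing) if missing else "none")
```

Run it with the `c2l_v015` environment active; an empty result only means the modules were found, not that GPU support works.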
Temporary usage instructions. This will become `mode='exact_total_cell_abundance'`, which switches all of these options on:
```python
import numpy as np
import pyro
import scvi
import torch

import cell2location

# `adata_vis` (spatial AnnData), `inf_aver` (reference cell-state signatures),
# `training` (bool) and `scvi_run_name` (output path) are assumed to be
# defined earlier, as in the standard cell2location tutorial.

detection_alpha = 200.0
N_cells_per_location_alpha_prior = 1000.0
use_per_cell_type_normalisation = False
# Ideally this is not a raw count of cells but the proportion of the spot
# occupied by cells * the 0.9999 quantile of N cells across the data.
N_cells_per_location = adata_vis.obs[['n_cell_occupancy']].values.astype('float32')
A_B_per_location_alpha_prior = None
A_factors_per_location = 40.0
B_groups_per_location = 5.0
use_proportion_factorisation_prior_on_w_sf = True
use_n_s_cells_per_location_limit = True

torch.set_float32_matmul_precision('high')
seed = 0
scvi.settings.seed = seed
np.random.seed(seed)

# prepare anndata for scVI model
cell2location.models.Cell2location.setup_anndata(
    adata=adata_vis, batch_key="sample"
)

if training:
    mod = cell2location.models.Cell2location(
        adata_vis, cell_state_df=inf_aver,
        amortised=False,
        N_cells_per_location=N_cells_per_location,  # np.array of shape (n_obs, 1)
        detection_alpha=detection_alpha,
        use_per_cell_type_normalisation=use_per_cell_type_normalisation,
        N_cells_per_location_alpha_prior=N_cells_per_location_alpha_prior,
        N_cells_mean_var_ratio=None,
        detection_hyp_prior={"mean_alpha": float(1.0)},
        detection_cell_type_prior_alpha=float(100.0),
        A_B_per_location_alpha_prior=A_B_per_location_alpha_prior,
        A_factors_per_location=A_factors_per_location,
        B_groups_per_location=B_groups_per_location,
        use_proportion_factorisation_prior_on_w_sf=use_proportion_factorisation_prior_on_w_sf,
        use_n_s_cells_per_location_limit=use_n_s_cells_per_location_limit,
        n_groups=50,
    )
    mod.view_anndata_setup()
    mod.train(
        max_epochs=80000,
        # train using full data (batch_size=None)
        batch_size=None,
        plan_kwargs={'optim': pyro.optim.Adam(optim_args={'lr': 0.002})},
        # use all data points in training because
        # we need to estimate cell abundance at all locations
        train_size=1,
        scale_elbo=1 / (adata_vis.n_obs * adata_vis.n_vars),
        accelerator='gpu',
    )
    # Save model
    mod.save(f"{scvi_run_name}", overwrite=True)
else:
    # can be loaded later like this:
    mod = cell2location.models.Cell2location.load(f"{scvi_run_name}", adata_vis)
```
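The planned `mode='exact_total_cell_abundance'` is not implemented in this branch. As a hypothetical sketch of what such a switch amounts to, it can be viewed as a named preset of the options listed in the usage instructions, merged with any explicit overrides. `EXACT_TOTAL_CELL_ABUNDANCE_PRESET`, `PRESETS` and `resolve_mode_kwargs` are illustrative names, not cell2location API.

```python
# Hypothetical preset mirroring the options listed in the usage instructions;
# this is NOT part of the cell2location API.
EXACT_TOTAL_CELL_ABUNDANCE_PRESET = {
    "detection_alpha": 200.0,
    "N_cells_per_location_alpha_prior": 1000.0,
    "use_per_cell_type_normalisation": False,
    "A_B_per_location_alpha_prior": None,
    "A_factors_per_location": 40.0,
    "B_groups_per_location": 5.0,
    "use_proportion_factorisation_prior_on_w_sf": True,
    "use_n_s_cells_per_location_limit": True,
}

PRESETS = {"exact_total_cell_abundance": EXACT_TOTAL_CELL_ABUNDANCE_PRESET}


def resolve_mode_kwargs(mode=None, **overrides):
    """Merge a named preset with explicit keyword arguments (explicit ones win)."""
    if mode is None:
        return dict(overrides)
    if mode not in PRESETS:
        raise ValueError(f"unknown mode: {mode!r}")
    kwargs = dict(PRESETS[mode])  # copy so the preset itself is never mutated
    kwargs.update(overrides)
    return kwargs
```

With such a helper, the model call in the usage code could pass `**resolve_mode_kwargs('exact_total_cell_abundance')` alongside `cell_state_df=inf_aver`, `amortised=False` and `N_cells_per_location=N_cells_per_location`. Letting explicit arguments override the preset keeps individual priors tunable.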
Note that this `N_cells_per_location` code doesn't support `amortised=True`.
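Because `N_cells_per_location` must be an `(n_obs, 1)` array and cannot be combined with `amortised=True`, it may help to build and check it before constructing the model. This is a minimal NumPy sketch, assuming two hypothetical per-spot inputs from histology: the proportion of spot area occupied by cells and a cell count. Following the comment in the usage code, occupancy is scaled by the 0.9999 quantile of cell counts. `n_cells_from_occupancy` and `check_n_cells_per_location` are illustrative helpers, not part of cell2location, and the finite/non-negative check is an extra assumption.

```python
import numpy as np


def n_cells_from_occupancy(occupied_fraction, cell_counts, q=0.9999):
    """Hypothetical helper: (proportion of spot occupied by cells) multiplied by
    the q-quantile of per-spot cell counts, as float32 of shape (n_obs, 1)."""
    occupied_fraction = np.asarray(occupied_fraction, dtype=np.float64).ravel()
    cell_counts = np.asarray(cell_counts, dtype=np.float64).ravel()
    if occupied_fraction.shape != cell_counts.shape:
        raise ValueError("need one occupancy value and one cell count per spot")
    if np.any((occupied_fraction < 0) | (occupied_fraction > 1)):
        raise ValueError("occupied_fraction must lie in [0, 1]")
    max_cells = np.quantile(cell_counts, q)  # robust "maximum" cells per spot
    return (occupied_fraction * max_cells).astype(np.float32).reshape(-1, 1)


def check_n_cells_per_location(N_cells_per_location, n_obs, amortised):
    """Hypothetical pre-flight check for the constraints stated in this issue."""
    if amortised:
        raise ValueError("N_cells_per_location is not supported with amortised=True")
    arr = np.asarray(N_cells_per_location)
    if arr.shape != (n_obs, 1):
        raise ValueError(f"expected shape ({n_obs}, 1), got {arr.shape}")
    # Assumed sanity check: cell numbers should be finite and non-negative.
    if not np.all(np.isfinite(arr)) or np.any(arr < 0):
        raise ValueError("N_cells_per_location must be finite and non-negative")
    return arr.astype(np.float32)


if __name__ == "__main__":
    # Toy values for three spots, purely illustrative.
    frac = np.array([0.0, 0.5, 1.0])
    counts = np.array([2, 8, 10])
    N = check_n_cells_per_location(
        n_cells_from_occupancy(frac, counts), n_obs=3, amortised=False
    )
    print(N.shape, N.dtype)
```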
Depends on https://github.com/scverse/scvi-tools/pull/2695