BayraktarLab / cell2location

Comprehensive mapping of tissue cell architecture via integrated single cell and spatial transcriptomics (cell2location model)
https://cell2location.readthedocs.io/en/latest/
Apache License 2.0

"max_epochs=30000" in cell2location mod.train() makes the "-ELBO loss" fluctuate up and down after 15000 epochs #327

Closed. sciencepeak closed this issue 11 months ago

sciencepeak commented 12 months ago

Problem

Training with "max_epochs=30000" in cell2location's mod.train() makes the "-ELBO loss" fluctuate up and down once it passes 15000 epochs. Does this suggest overfitting, or is there another explanation?

The phenomenon is visible in the official tutorial I follow (https://cell2location.readthedocs.io/en/latest/notebooks/cell2location_tutorial.html, "Training cell2location" section): [Figure: -ELBO loss history from the tutorial]

It also appears in my own 10X Visium data, where the fluctuation is more pronounced: [Figure: -ELBO loss history from my cell2location training]
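One way to tell whether late-training fluctuation is noise around a stable plateau or an upward drift is to summarise the loss curve in windows: a robust trend (the median) and a noise level. The sketch below does this with NumPy on a synthetic curve. The window size of 1000 epochs, the helper name `windowed_trend_and_noise`, and the synthetic loss are all assumptions for illustration. For a trained model, the per-epoch values would presumably come from the scvi-tools training history (e.g. `mod.history["elbo_train"]`), but that key is an assumption to check against your installed version.

```python
import numpy as np


def windowed_trend_and_noise(loss, window=1000):
    """Summarise a per-epoch loss curve in non-overlapping windows.

    Returns, per window, the median loss (a robust trend) and an estimate
    of the epoch-to-epoch noise level.
    """
    loss = np.asarray(loss, dtype=float)
    n_windows = len(loss) // window
    blocks = loss[: n_windows * window].reshape(n_windows, window)
    trend = np.median(blocks, axis=1)
    # Differencing removes a slowly varying trend; for independent noise,
    # std(diff) = sqrt(2) * std(noise), hence the division by sqrt(2).
    noise = np.diff(blocks, axis=1).std(axis=1) / np.sqrt(2)
    return trend, noise


# Hypothetical synthetic -ELBO curve: exponential decay to a plateau,
# with noise that becomes 10x larger after epoch 15000.
rng = np.random.default_rng(0)
epochs = np.arange(30000)
true_trend = 1e7 + 5e7 * np.exp(-epochs / 3000)
noise_scale = np.where(epochs < 15000, 1e4, 1e5)
loss = true_trend + rng.normal(0.0, noise_scale)

trend, noise = windowed_trend_and_noise(loss, window=1000)
print("trend change over the last 5 windows:", np.diff(trend[-5:]))
print("noise at epochs 5000-5999 vs 29000-29999:", noise[5], noise[-1])
```

In this sketch's terms, a flat or still-decreasing windowed median with a growing noise estimate looks like noise around a plateau; a rising median would be the more worrying pattern.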

Description of the data input and hyperparameters

A 10X Visium slide from the liver.

N_cells_per_location=10,
detection_alpha=20

Single cell reference data: number of cells, number of cell types, number of genes

AnnData object with n_obs × n_vars = 11157 × 9006
    obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'Sample_id', 'Patient_id', 'Mount_batch', 'Mets_primary', 'Lesion_site', 'dataset', 'platform', 'percent.mt', 'integrated_snn_res.0.5', 'integrated_snn_res.0.8', 'integrated_snn_res.1', 'integrated_snn_res.1.2', 'cell_type', '_indices', '_scvi_batch', '_scvi_labels'
    var: 'features', 'n_cells', 'nonz_mean'
    uns: '_scvi_manager_uuid', '_scvi_uuid', 'mod'
    obsm: '_scvi_extra_categorical_covs'
    varm: 'means_per_cluster_mu_fg', 'q05_per_cluster_mu_fg', 'q95_per_cluster_mu_fg', 'stds_per_cluster_mu_fg'

Single cell reference data: technology type (e.g. mix of 10X 3' and 5')

10X single cell RNA-seq

Spatial data: number of locations, technology type (e.g. Visium, ISS, Nanostring WTA)

10X Visium

AnnData object with n_obs × n_vars = 2988 × 8673
    obs: 'in_tissue', 'array_row', 'array_col', 'sample', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_50_genes', 'pct_counts_in_top_100_genes', 'pct_counts_in_top_200_genes', 'pct_counts_in_top_500_genes', 'MT_frac', '_indices', '_scvi_batch', '_scvi_labels', 'B cell and plasma cell', 'CD8 T cell', 'EPCAM+ cholangiocytes and LPC', 'NK cell', 'adipocyte or smooth muscle-like', 'hepatic stellate cell-1', 'hepatic stellate cell-2', 'hepatocyte-1', 'liver endothelial', 'macrophage-1', 'macrophage-2', 'macrophage-3', 'melanoma ISG high Liver', 'melanoma-1 Liver', 'melanoma-3 Liver', 'naïve or memory T cell', 'neutrophil', 'proliferating cell', 'leiden', 'region_cluster', 'mean_nUMI_factorsfact_0', 'mean_nUMI_factorsfact_1', 'mean_nUMI_factorsfact_2', 'mean_nUMI_factorsfact_3', 'mean_nUMI_factorsfact_4', 'mean_nUMI_factorsfact_5', 'mean_nUMI_factorsfact_6', 'mean_nUMI_factorsfact_7', 'mean_nUMI_factorsfact_8', 'mean_nUMI_factorsfact_9', 'mean_nUMI_factorsfact_10', 'mean_nUMI_factorsfact_11', 'mean_nUMI_factorsfact_12', 'mean_nUMI_factorsfact_13', 'mean_nUMI_factorsfact_14', 'mean_nUMI_factorsfact_15', 'mean_nUMI_factorsfact_16', 'mean_nUMI_factorsfact_17', 'mean_nUMI_factorsfact_18', 'mean_nUMI_factorsfact_19', 'mean_nUMI_factorsfact_20', 'mean_nUMI_factorsfact_21', 'mean_nUMI_factorsfact_22', 'mean_nUMI_factorsfact_23', 'mean_nUMI_factorsfact_24', 'mean_nUMI_factorsfact_25', 'mean_nUMI_factorsfact_26', 'mean_nUMI_factorsfact_27', 'mean_nUMI_factorsfact_28', 'mean_nUMI_factorsfact_29'
    var: 'gene_ids', 'feature_types', 'genome', 'SYMBOL', 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts', 'MT'
    uns: '_scvi_manager_uuid', '_scvi_uuid', 'leiden', 'mod', 'mod_coloc_n_fact10', 'mod_coloc_n_fact11', 'mod_coloc_n_fact12', 'mod_coloc_n_fact13', 'mod_coloc_n_fact14', 'mod_coloc_n_fact15', 'mod_coloc_n_fact16', 'mod_coloc_n_fact17', 'mod_coloc_n_fact18', 'mod_coloc_n_fact19', 'mod_coloc_n_fact20', 'mod_coloc_n_fact21', 'mod_coloc_n_fact22', 'mod_coloc_n_fact23', 'mod_coloc_n_fact24', 'mod_coloc_n_fact25', 'mod_coloc_n_fact26', 'mod_coloc_n_fact27', 'mod_coloc_n_fact28', 'mod_coloc_n_fact29', 'mod_coloc_n_fact30', 'mod_coloc_n_fact5', 'mod_coloc_n_fact6', 'mod_coloc_n_fact7', 'mod_coloc_n_fact8', 'mod_coloc_n_fact9', 'neighbors', 'region_cluster_colors', 'sample_colors', 'spatial', 'umap'
    obsm: 'MT', 'X_umap', 'means_cell_abundance_w_sf', 'q05_cell_abundance_w_sf', 'q95_cell_abundance_w_sf', 'spatial', 'stds_cell_abundance_w_sf'
    layers: 'B cell and plasma cell', 'CD8 T cell', 'EPCAM+ cholangiocytes and LPC', 'NK cell', 'adipocyte or smooth muscle-like', 'hepatic stellate cell-1', 'hepatic stellate cell-2', 'hepatocyte-1', 'liver endothelial', 'macrophage-1', 'macrophage-2', 'macrophage-3', 'melanoma ISG high Liver', 'melanoma-1 Liver', 'melanoma-3 Liver', 'naïve or memory T cell', 'neutrophil', 'proliferating cell'
    obsp: 'connectivities', 'distances'

vitkl commented 11 months ago

Hi @sciencepeak

We have seen higher ELBO stochasticity towards the end of training for some datasets and no such increase for others. Higher stochasticity suggests numerical instability in at least some of the model's parameters (possibly very few, e.g. 10 out of 100k+), with more variability when more parameters are affected. Its extent could depend on data quality, a mismatch between the reference and spatial data, an overly simple reference, or other reasons.
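The intuition that a handful of unstable parameters can make the whole ELBO noisy, and that the noise grows with the number of affected parameters, can be illustrated with a toy model. This is not how cell2location computes its loss: the sketch simply treats the total loss as a sum of independent per-parameter terms, with made-up noise levels (`stable_sd`, `unstable_sd`) and a hypothetical helper `simulate_total_loss_std`.

```python
import numpy as np


def simulate_total_loss_std(n_params=100_000, n_unstable=10, n_steps=2000,
                            stable_sd=1e-3, unstable_sd=1.0, seed=0):
    """Toy model (hypothetical): total loss = sum of per-parameter terms.

    Stable parameters jitter very little from step to step, unstable ones
    jitter a lot. Returns the std of the total loss across steps.
    """
    rng = np.random.default_rng(seed)
    n_stable = n_params - n_unstable
    # A sum of n independent N(0, s^2) terms is N(0, n * s^2), so each
    # group's summed contribution is drawn directly.
    stable = rng.normal(0.0, stable_sd * np.sqrt(n_stable), size=n_steps)
    unstable = rng.normal(0.0, unstable_sd * np.sqrt(n_unstable), size=n_steps)
    return float((stable + unstable).std())


for k in (0, 10, 100, 1000):
    std = simulate_total_loss_std(n_unstable=k)
    print(f"{k:>5} unstable parameters -> std of total loss {std:.2f}")
```

With these assumed settings, 10 unstable parameters out of 100,000 already dominate the fluctuation of the total, because each contributes far more step-to-step variance than a stable one; adding more unstable parameters widens the spread roughly with the square root of their number.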

We don't see clear issues in cell2location cell type abundance results arising from training long enough to see this stochasticity. It's possible that it only affects the hard-to-estimate parameters and not the vast majority of cell2location results. I would not recommend stopping training at 15k epochs for either plot you reference.
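Whether longer training changes the abundance results on a given dataset can be checked by comparing two runs per cell type. The sketch below assumes you kept a copy of `adata_vis.obsm["q05_cell_abundance_w_sf"]` (a key present in the spatial AnnData above; `adata_vis` is the tutorial's variable name, assumed here) from a run stopped at 15k epochs and from one trained to 30k, since the second export would overwrite the same key. Both are taken as arrays of shape (locations, cell types) with matching column order. The helper `per_cell_type_agreement` and the synthetic data are hypothetical.

```python
import numpy as np


def per_cell_type_agreement(abund_a, abund_b):
    """Pearson correlation between two abundance estimates, per cell type.

    abund_a, abund_b: arrays of shape (n_locations, n_cell_types).
    A cell type with constant abundance in either run yields nan.
    """
    a = np.asarray(abund_a, dtype=float)
    b = np.asarray(abund_b, dtype=float)
    if a.shape != b.shape:
        raise ValueError("both estimates must have the same shape")
    a_c = a - a.mean(axis=0)
    b_c = b - b.mean(axis=0)
    denom = np.sqrt((a_c ** 2).sum(axis=0) * (b_c ** 2).sum(axis=0))
    return (a_c * b_c).sum(axis=0) / denom


# Synthetic stand-ins with the shape of the data above:
# 2988 locations x 18 cell types, second run = first run with small noise.
rng = np.random.default_rng(0)
short_run = rng.gamma(shape=2.0, scale=1.0, size=(2988, 18))
long_run = short_run * np.exp(rng.normal(0.0, 0.05, size=short_run.shape))

r = per_cell_type_agreement(short_run, long_run)
print(r.round(3))
```

Correlations close to 1 for every cell type would be consistent with the fluctuation being confined to a few hard-to-estimate parameters; a cell type with a clearly lower value would be the one to inspect.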