theislab / multimil

Multimodal weakly supervised learning to identify disease-specific changes in single-cell atlases
https://multimil.rtfd.io/
BSD 3-Clause "New" or "Revised" License

Bag-level predictions do not correspond to unique samples? #70

Open patricks-lab opened 3 months ago

patricks-lab commented 3 months ago

Report

Thanks for the great work!

I'm trying to print out bag-level predictions (i.e. one per donor). I'm following the MIL classification tutorial (https://multimil.readthedocs.io/en/latest/notebooks/mil_classification.html). After training and calling mil.get_model_output(), adata looks like this:

AnnData object with n_obs × n_vars = 359595 × 30
    obs: "3'_or_5'", 'BMI', 'age_or_mean_of_age_range', 'age_range', 'anatomical_region_ccf_score', 'ancestry', 'assay', 'cause_of_death', 'cell_type', 'core_or_extension', 'dataset', 'development_stage', 'disease', 'donor_id', 'fresh_or_frozen', 'log10_total_counts', 'lung_condition', 'mixed_ancestry', 'sample', 'scanvi_label', 'sequencing_platform', 'sex', 'smoking_status', 'study', 'subject_type', 'suspension_type', 'tissue', 'tissue_coarse_unharmonized', 'tissue_detailed_unharmonized', 'tissue_dissociation_protocol', 'tissue_level_2', 'tissue_level_3', 'tissue_sampling_method', 'total_counts', 'ann_level_1_label_final', 'ann_level_2_label_final', 'ann_level_3_label_final', 'ann_level_4_label_final', 'ann_level_5_label_final', 'ref', '_scvi_batch', 'cell_attn', 'bags', 'predicted_disease'
    uns: '_scvi_uuid', '_scvi_manager_uuid', 'bag_true_disease', 'bag_full_predictions_disease'
    obsm: 'X_umap', '_scvi_extra_categorical_covs', 'full_predictions_disease'

I'm interested in getting a single prediction/label for each unique sample (namely for each unique value of adata.obs['sample']).

In the tutorial dataset, len(np.unique(adata.obs['sample'])) shows 108 unique samples.

However, len(adata.uns['bag_full_predictions_disease']) gives 2816 predictions, one per bag (len(np.unique(adata.obs['bags'])) is also 2816). Since there are only 108 unique samples, I expected only 108 bags, one per sample.
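One way to see how the 2816 bags relate to the 108 samples is to count the distinct bags per sample and check that no bag mixes cells from different samples. The sketch below assumes only the `bags` and `sample` columns shown in `adata.obs` above; the helper name `bags_per_sample` is hypothetical and not part of multimil.

```python
import pandas as pd


def bags_per_sample(obs: pd.DataFrame, bag_key: str = "bags", sample_key: str = "sample") -> pd.Series:
    """Count distinct bags per sample (hypothetical helper, not a multimil API)."""
    # How many different samples each bag touches; 1 means the bag nests inside a sample
    samples_per_bag = obs.groupby(bag_key, observed=True)[sample_key].nunique()
    if (samples_per_bag > 1).any():
        raise ValueError("some bags contain cells from more than one sample")
    # Number of distinct bags that make up each sample
    return obs.groupby(sample_key, observed=True)[bag_key].nunique()
```

Called as `bags_per_sample(adata.obs)`, it returns one count per sample; if bags nest inside samples, the counts sum to the total number of bags.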

Is this the right way to get sample-level predictions (i.e. one prediction for each of the 108 unique samples)?

Thanks in advance!

Version information

No response

alitinet commented 1 month ago

Hi @patricks-lab! Thanks for your interest in our work, and sorry for the late reply! Exactly: each sample is split into multiple bags so that training can run in mini-batches. To get a single representation per sample, I would average over all bags belonging to that sample.
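A minimal sketch of that averaging, assuming `adata.uns['bag_full_predictions_disease']` is an array of shape (n_bags, n_classes) holding class probabilities, and that row i belongs to the bag `bag_order[i]`. The row order is not stated in this thread; passing `np.unique(adata.obs['bags'])` as `bag_order` is a guess that should be checked against the multimil source. The helper `sample_level_predictions` is hypothetical, not part of the multimil API.

```python
import numpy as np
import pandas as pd


def sample_level_predictions(obs, bag_preds, bag_order, bag_key="bags", sample_key="sample"):
    """Average bag-level prediction rows into one row per sample (hypothetical helper)."""
    bag_preds = np.asarray(bag_preds, dtype=float)
    if len(bag_order) != len(bag_preds):
        raise ValueError("bag_order must have one entry per prediction row")
    # Bag -> sample lookup; assumes every bag holds cells from exactly one sample
    bag_to_sample = dict(zip(obs[bag_key].to_numpy(), obs[sample_key].to_numpy()))
    preds = pd.DataFrame(bag_preds)
    # ASSUMPTION: row i of bag_preds belongs to bag_order[i] (KeyError if a bag is not in obs)
    preds[sample_key] = [bag_to_sample[b] for b in bag_order]
    # Unweighted mean over all bags of each sample: one row of class scores per sample
    return preds.groupby(sample_key).mean()
```

For example, `probs = sample_level_predictions(adata.obs, adata.uns["bag_full_predictions_disease"], np.unique(adata.obs["bags"]))` followed by `probs.idxmax(axis=1)` gives one class index per sample; mapping those indices back to disease names depends on how multimil encodes the categories. Every bag counts equally here; weighting each bag by its number of cells would be a reasonable alternative if bag sizes differ.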