theislab / multimil

Multimodal weakly supervised learning to identify disease-specific changes in single-cell atlases
https://multimil.rtfd.io/
BSD 3-Clause "New" or "Revised" License
add predict function #45

Closed M0hammadL closed 2 years ago

M0hammadL commented 2 years ago

Add mil.predict(adata) and save the labels in adata.obs["predicted_class"].

alitinet commented 2 years ago

Maybe it's better to have just one function after all that gets the latent representation, cell attention scores, covariate attention scores and predicted classes all at the same time? Inference can take quite some time for big datasets, and a separate predict would basically do the same forward pass as for the latent/attention scores. What do you think @M0hammadL?

i.e. something like:

latent, cell_attn, cov_attn, predicted_class = model.get_model_output()
adata.obsm['latent'] = latent
adata.obs['cell_attn'] = cell_attn
adata.obs['cov_attn'] = cov_attn
adata.obs['predicted_class'] = predicted_class

M0hammadL commented 2 years ago

Why not modify adata directly to avoid these calls:

adata.obsm['latent'] = latent
adata.obs['cell_attn'] = cell_attn
adata.obs['cov_attn'] = cov_attn
adata.obs['predicted_class'] = predicted_class

Put all of these assignments inside model.get_model_output with no return value, so nothing has to be set afterwards? FYI, scvi-tools models such as scANVI have both a get-latent method and a predict method, but since we have all of those plus the attention scores etc., it might be good to have one function: https://docs.scvi-tools.org/en/stable/tutorials/notebooks/seed_labeling.html

alitinet commented 2 years ago

Only to follow the scvi-tools API, e.g. here: https://docs.scvi-tools.org/en/stable/tutorials/notebooks/totalVI.html

But I can just add a parameter that controls whether to do this in place or return the values.
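A minimal sketch of what such a parameter could look like, not the actual multimil implementation. The `inplace` name, the `forward` callable and the dict-based stand-in for AnnData are all assumptions made for illustration; the output keys are the ones proposed above.

```python
from types import SimpleNamespace


def get_model_output(adata, forward, inplace=True):
    """Run one forward pass, then store the outputs in adata or return them.

    `forward` is a hypothetical callable standing in for the model's forward
    pass; it returns (latent, cell_attn, cov_attn, predicted_class).
    """
    latent, cell_attn, cov_attn, predicted_class = forward(adata)
    if inplace:
        # Write everything into adata and return nothing.
        adata.obsm["latent"] = latent
        adata.obs["cell_attn"] = cell_attn
        adata.obs["cov_attn"] = cov_attn
        adata.obs["predicted_class"] = predicted_class
        return None
    # Otherwise leave adata untouched and hand the values back.
    return latent, cell_attn, cov_attn, predicted_class


def fake_forward(adata):
    # Made-up outputs for three cells, only to exercise the function.
    latent = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
    cell_attn = [0.2, 0.5, 0.3]
    cov_attn = [0.9, 0.8, 0.7]
    predicted_class = ["healthy", "disease", "disease"]
    return latent, cell_attn, cov_attn, predicted_class


# Stand-in for an AnnData object: plain dicts instead of obs/obsm.
adata = SimpleNamespace(obs={}, obsm={})

result = get_model_output(adata, fake_forward, inplace=True)
outputs = get_model_output(adata, fake_forward, inplace=False)
```

With `inplace=True` the forward pass still runs only once and all four outputs land in adata, which is the "minimal returns" behaviour; `inplace=False` keeps the scvi-tools-like style of returning values.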

M0hammadL commented 2 years ago

Exactly, that would be great. I am a fan of minimal returns, since with many covariates etc. the return values could get messy and confusing. Btw, it might be good to have our own accuracy calculator using sklearn: just copy what you have at the moment, take the real labels in the data and the predicted ones, and report https://scikit-learn.org/stable/modules/generated/sklearn.metrics.classification_report.html

alitinet commented 2 years ago

On the patient level, right? I.e. if we have 10 patients in the query, the predicted label for each patient would be the most common class among all bags for that patient, and we would then compare it to the true labels, so here we'd have 10 true and 10 predicted labels.

Or on the bag level? Then if each patient had 5 bags, we'd have 5*10=50 predicted and 50 true labels.
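The two options above can be sketched roughly as follows. This is only an illustration with made-up toy labels; the function names are hypothetical and plain Python is used instead of sklearn to keep it self-contained.

```python
from collections import Counter, defaultdict


def patient_level_predictions(bag_patients, bag_predictions):
    """Majority vote: most common predicted class among each patient's bags."""
    votes = defaultdict(list)
    for patient, pred in zip(bag_patients, bag_predictions):
        votes[patient].append(pred)
    # On a tie, Counter.most_common keeps the class that was seen first.
    return {p: Counter(v).most_common(1)[0][0] for p, v in votes.items()}


def accuracy(y_true, y_pred):
    """Fraction of positions where the prediction matches the true label."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


# Toy data (made up): 2 patients with 3 bags each.
bag_patients = ["p1", "p1", "p1", "p2", "p2", "p2"]
bag_true = ["disease", "disease", "disease", "healthy", "healthy", "healthy"]
bag_pred = ["disease", "healthy", "disease", "healthy", "healthy", "disease"]
patient_true = {"p1": "disease", "p2": "healthy"}

# Bag level: one true/predicted pair per bag (6 pairs here).
bag_acc = accuracy(bag_true, bag_pred)

# Patient level: one true/predicted pair per patient (2 pairs here).
patient_pred = patient_level_predictions(bag_patients, bag_pred)
patients = sorted(patient_true)
patient_acc = accuracy(
    [patient_true[p] for p in patients],
    [patient_pred[p] for p in patients],
)
```

The same true/predicted lists could be passed to sklearn.metrics.classification_report(y_true, y_pred) to get per-class precision and recall instead of a single accuracy number.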

M0hammadL commented 2 years ago

Yeah, on the patient level. We could also have bag level, but the patient level is more interesting and the bag level is not defined. You could report cell-level, bag-level and patient-level results, but patient level, and maybe bag level, is what we would report in the end.