Closed M0hammadL closed 2 years ago
Maybe it's better to have just one function after all that returns the latent representation, cell attention scores, covariate attention scores and predicted classes all at once? Inference can take quite some time for big datasets, and a separate predict would basically repeat the same forward pass that already produces the latent representation and attention scores. What do you think, @M0hammadL?
i.e. something like:
latent, cell_attn, cov_attn, predicted_class = model.get_model_output()
adata.obsm['latent'] = latent
adata.obs['cell_attn'] = cell_attn
adata.obs['cov_attn'] = cov_attn
adata.obs['predicted_class'] = predicted_class
Why not modify adata directly to avoid these assignments:
adata.obsm['latent'] = latent
adata.obs['cell_attn'] = cell_attn
adata.obs['cov_attn'] = cov_attn
adata.obs['predicted_class'] = predicted_class
Put all of these inside model.get_model_output, with no return value, to avoid setting the variables afterwards? FYI, scvi-tools models such as scANVI have both a get-latent method and predict, but since we have all of that plus attention scores etc., it might be good to have one function.
https://docs.scvi-tools.org/en/stable/tutorials/notebooks/seed_labeling.html
Only to follow the scvi-tools API, e.g. here: https://docs.scvi-tools.org/en/stable/tutorials/notebooks/totalVI.html
But I can just add a parameter that controls whether to do this in place or return the values.
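A minimal sketch of what such a parameter could look like, not the actual implementation. The `inplace` argument, the `forward_all` method and `DummyModel` are hypothetical names used only for illustration, and a `SimpleNamespace` with plain dicts stands in for a real AnnData object so the snippet runs without extra dependencies.

```python
from types import SimpleNamespace


class DummyModel:
    """Hypothetical stand-in for the trained model; returns fixed values."""

    def forward_all(self, adata):
        n = adata.n_obs
        latent = [[0.0, 0.0] for _ in range(n)]  # one 2-d embedding per cell
        cell_attn = [1.0 / n] * n                # uniform cell attention
        cov_attn = [1.0] * n                     # placeholder covariate attention
        predicted_class = ["control"] * n        # placeholder labels
        return latent, cell_attn, cov_attn, predicted_class


def get_model_output(model, adata, inplace=True):
    """Run a single forward pass, then store the outputs on adata or return them."""
    latent, cell_attn, cov_attn, predicted_class = model.forward_all(adata)
    if inplace:
        adata.obsm["latent"] = latent
        adata.obs["cell_attn"] = cell_attn
        adata.obs["cov_attn"] = cov_attn
        adata.obs["predicted_class"] = predicted_class
        return None
    return latent, cell_attn, cov_attn, predicted_class


# Stand-in for AnnData: only the attributes used above (assumption).
adata = SimpleNamespace(n_obs=3, obs={}, obsm={})
model = DummyModel()

result = get_model_output(model, adata)                   # writes into adata
outputs = get_model_output(model, adata, inplace=False)   # returns a 4-tuple
```

Defaulting to `inplace=True` keeps the call site to one line, while `inplace=False` keeps the scvi-tools style of returning values available.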
Exactly, that would be great. I am a fan of minimal returns, since with many covariates etc. it could get messy and confusing. By the way, it might be good to have our own accuracy calculator using sklearn: just copy what you have at the moment, take the real labels in the data and the predicted ones, and report https://scikit-learn.org/stable/modules/generated/sklearn.metrics.classification_report.html
On the patient level, right? I.e. if we have 10 patients in the query, the predicted label for each patient would be the most common class among all bags for that patient, and we then compare these to the true labels, so here we'd have 10 true and 10 predicted.
Or on the bag level? Then if each patient had 5 bags, we'd have 5*10=50 predicted and 50 true.
Yeah, on the patient level. We could also have the bag level, but the patient level is more interesting, and the bag level is not defined. You could report cell level, bag level and patient level, but patient level and maybe bag level are what we would report in the end.
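The patient-level evaluation discussed above could be sketched roughly as follows: take the most common predicted class over all bags of a patient, then compare against the true label per patient. The function names and the toy data are assumptions for illustration only. Ties are broken by whichever class was seen first, which is just the default behaviour of `Counter.most_common` and not something decided in this thread.

```python
from collections import Counter, defaultdict


def patient_level_predictions(patient_ids, bag_predictions):
    """Majority vote over bag-level predictions, grouped by patient."""
    votes = defaultdict(Counter)
    for pid, pred in zip(patient_ids, bag_predictions):
        votes[pid][pred] += 1
    # most_common(1) returns [(label, count)] for the top class
    return {pid: counts.most_common(1)[0][0] for pid, counts in votes.items()}


def patient_level_accuracy(true_labels, predicted_labels):
    """Fraction of patients whose majority-vote label matches the true label."""
    correct = sum(true_labels[pid] == predicted_labels[pid] for pid in true_labels)
    return correct / len(true_labels)


# Toy example: 2 patients with 3 bags each (made-up values).
patient_ids = ["p1", "p1", "p1", "p2", "p2", "p2"]
bag_predictions = ["disease", "disease", "healthy", "healthy", "healthy", "disease"]
true_labels = {"p1": "disease", "p2": "disease"}

predicted_labels = patient_level_predictions(patient_ids, bag_predictions)
accuracy = patient_level_accuracy(true_labels, predicted_labels)

# Aligned lists, e.g. for a per-class report.
y_true = [true_labels[pid] for pid in sorted(true_labels)]
y_pred = [predicted_labels[pid] for pid in sorted(true_labels)]
```

With scikit-learn installed, `y_true` and `y_pred` can be passed to `sklearn.metrics.classification_report(y_true, y_pred)` to get the per-class report linked above. Bag-level evaluation would skip the majority vote and compare the bag predictions directly.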
Add
mil.predict(adata)
and save the labels in adata.obs["predicted_class"].