theislab / scarches

Reference mapping for single-cell genomics
https://docs.scarches.org/en/latest/
BSD 3-Clause "New" or "Revised" License

"Model was trained without prototypes" when using a loaded scPoli model #176

Closed: maarten-devries closed this issue 1 year ago

maarten-devries commented 1 year ago

Hi @cdedonno, thank you for the great package.

I am following the scPoli tutorial and running into an issue with the .classify() function, but only if the query model has been saved to disk and then loaded.

The code below runs fine if scpoli_query has just been trained and never stored to disk.

results_dict = scpoli_query.classify(
    adata_query.X,
    adata_query.obs[condition_key]
)

However, when I first save the model and load it back:

scpoli_query.save(query_model_path)
scpoli_query_loaded = scPoli.load(query_model_path, adata=adata_query)

results_dict = scpoli_query_loaded.classify(
    adata_query.X,
    adata_query.obs[condition_key],
)

then I get an AssertionError from the following line: assert self.prototype_training_ is True, f"Model was trained without prototypes"

I suspect this is because self.prototype_training_ is not saved with the model. I would appreciate your help with this!
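The suspected failure mode can be illustrated with a minimal, hypothetical sketch. This is not scArches' actual save/load code; `ToyModel`, `BrokenModel`, `SAVED_ATTRS` and the pickled attribute dict are all assumptions made for illustration. If `save()` writes only an allow-list of attributes, any attribute missing from that list (here `prototype_training_`) silently falls back to its constructor default after `load()`, so an `assert self.prototype_training_ is True` check would fail only on the reloaded model.

```python
import os
import pickle
import tempfile


class ToyModel:
    """Hypothetical stand-in for a model that saves its state as an attribute dict."""

    # Only the attributes listed here are written to disk by save().
    SAVED_ATTRS = ["conditions_", "prototype_training_"]

    def __init__(self, conditions):
        self.conditions_ = list(conditions)
        self.prototype_training_ = False  # default for a freshly constructed model

    def save(self, path):
        attr_dict = {name: getattr(self, name) for name in self.SAVED_ATTRS}
        with open(path, "wb") as f:
            pickle.dump(attr_dict, f)

    @classmethod
    def load(cls, path):
        with open(path, "rb") as f:
            attr_dict = pickle.load(f)
        model = cls(attr_dict["conditions_"])
        # Restore saved attributes; anything missing keeps its constructor default.
        for name, value in attr_dict.items():
            setattr(model, name, value)
        return model


class BrokenModel(ToyModel):
    """Same model, but the allow-list forgets prototype_training_."""

    SAVED_ATTRS = ["conditions_"]


results = {}
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "attr.pkl")
    for model_cls in (BrokenModel, ToyModel):
        model = model_cls(["donor_a", "donor_b"])
        model.prototype_training_ = True  # set during training with prototypes
        model.save(path)
        results[model_cls.__name__] = model_cls.load(path).prototype_training_

print(results)  # {'BrokenModel': False, 'ToyModel': True}
```

The fix in this toy setting is simply to include the flag in whatever the model serializes, so the trained state survives the round trip.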

cdedonno commented 1 year ago

Hi @maarten-devries, thanks for the interest and for reporting this. I'll try to fix it asap.

cdedonno commented 1 year ago

@maarten-devries I think it should be fixed now, can you check?

maarten-devries commented 1 year ago

Yes, that fixed the error. Thank you very much!

maarten-devries commented 1 year ago

Hi @cdedonno,

Towards the end of the tutorial notebook (under Sample Embeddings), there is still one more related issue that arises when using a model loaded from disk.

scpoli_query_loaded.get_conditional_embeddings() raises ValueError: Length of passed value for obs_names is 17, but this AnnData has shape: (83, 3).

To illustrate what's going wrong:

embeddings = scpoli_query.model.embedding.weight.cpu().detach().numpy()

embeddings.shape[0] is 83.
len(scpoli_query.conditions_) is 83.
However, scpoli_query.obs_metadata_.shape[0] is 17 (it should be 83).

For the reference model, len(scpoli_model.conditions_) is 68, so the query adds 83 - 68 = 15 new conditions.

I see that in load_query_data(), there are some lines that update obs_metadata:

obs_metadata = attr_dict["obs_metadata_"]
new_obs_metadata = adata.obs.groupby(condition_key).first()
obs_metadata = pd.concat([obs_metadata, new_obs_metadata])
...
new_model.obs_metadata_ = obs_metadata

However, this new obs_metadata is not retained when the model is saved to disk.
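A minimal sketch of the intended behavior, using hypothetical toy data (`donor_1` to `donor_4`, a `tissue` column) and a pickled attribute dict as an assumed storage format, not scArches' actual one: the table produced by the `load_query_data()` lines quoted above is what has to be written to disk, so that after a reload it still has one row per condition.

```python
import os
import pickle
import tempfile

import pandas as pd

# Hypothetical reference metadata: one row per reference condition (donor).
ref_obs_metadata = pd.DataFrame({"tissue": ["lung", "blood"]}, index=["donor_1", "donor_2"])

# Hypothetical query cells; groupby(...).first() keeps one row per query condition.
query_obs = pd.DataFrame(
    {"donor_id": ["donor_3", "donor_3", "donor_4"], "tissue": ["liver", "liver", "skin"]}
)
new_obs_metadata = query_obs.groupby("donor_id").first()

# Mirrors the concatenation done in load_query_data().
obs_metadata = pd.concat([ref_obs_metadata, new_obs_metadata])

# Persist the *updated* table together with the conditions it describes.
attr_dict = {"obs_metadata_": obs_metadata, "conditions_": list(obs_metadata.index)}
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "attr.pkl")
    with open(path, "wb") as f:
        pickle.dump(attr_dict, f)
    with open(path, "rb") as f:
        restored = pickle.load(f)

# After reloading, metadata rows and conditions agree (4 and 4 here).
print(restored["obs_metadata_"].shape[0], len(restored["conditions_"]))
```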

Really appreciate your help!

cdedonno commented 1 year ago

Thanks for reporting this. Sorry for the inconvenience; I will try to fix it asap.

cdedonno commented 1 year ago

Should be fixed now, can you check?

maarten-devries commented 1 year ago

It works now, thank you very much for the quick fixes!

Just a note on a specific edge case I ran into (this may not require a fix, but I wanted to bring the behavior to your attention): the current implementation doesn't support the case where adata_ref and adata_query contain cells from the same batch (in my case, donor_id). In my example, 2 donor_ids occur in both adata_ref and adata_query.

Therefore, len(scpoli_query.conditions_) is 83, but scpoli_query.obs_metadata_.shape[0] is 85 (68 reference rows plus 17 query rows, with the 2 shared donor_ids counted twice), and as a result scpoli_query.get_conditional_embeddings() fails.
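The row mismatch can be reproduced with pandas alone. This sketch uses hypothetical donors `d1` to `d4` and assumes that `conditions_` is built as the reference conditions followed by only the query conditions not already present, which matches the counts reported here (68 + 17 - 2 = 83) but has not been checked against the scArches source. Dropping duplicated index labels is one possible fix, not necessarily the one adopted upstream.

```python
import pandas as pd

# Hypothetical reference metadata with three donors.
ref_obs_metadata = pd.DataFrame({"tissue": ["lung", "lung", "blood"]}, index=["d1", "d2", "d3"])

# Hypothetical query cells: d2 and d3 also occur in the reference.
query_obs = pd.DataFrame(
    {"donor_id": ["d2", "d3", "d4", "d4"], "tissue": ["lung", "blood", "skin", "skin"]}
)
new_obs_metadata = query_obs.groupby("donor_id").first()

# Plain concatenation keeps one row per (table, donor), so shared donors appear twice.
obs_metadata = pd.concat([ref_obs_metadata, new_obs_metadata])

# Assumed construction: reference conditions plus only the unseen query conditions.
conditions = list(ref_obs_metadata.index) + [
    c for c in new_obs_metadata.index if c not in ref_obs_metadata.index
]
print(obs_metadata.shape[0], len(conditions))  # 6 rows vs 4 conditions

# One possible fix: keep the first (reference) row per condition, ordered like conditions.
deduped = obs_metadata[~obs_metadata.index.duplicated(keep="first")].loc[conditions]
print(deduped.shape[0])  # 4
```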

Again, this seems like a minor edge case, but it is good to be aware of. It could be worth adding a simple check that the condition_key values of adata_query and adata_ref do not overlap. But I'll close this issue for now :)
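The suggested check could look something like the sketch below. `check_no_condition_overlap` is a hypothetical helper, not part of scArches. In practice it might be called as `check_no_condition_overlap(adata_ref.obs[condition_key].unique(), adata_query.obs[condition_key].unique())` before building the query model.

```python
def check_no_condition_overlap(ref_conditions, query_conditions):
    """Hypothetical helper: raise if any condition label occurs in both datasets."""
    shared = sorted(set(ref_conditions) & set(query_conditions))
    if shared:
        raise ValueError(
            f"{len(shared)} condition(s) occur in both reference and query: {shared}"
        )


# Disjoint conditions pass silently.
check_no_condition_overlap(["d1", "d2"], ["d3"])

# Overlapping conditions are reported by name.
try:
    check_no_condition_overlap(["d1", "d2"], ["d2", "d3"])
except ValueError as err:
    message = str(err)

print(message)  # 1 condition(s) occur in both reference and query: ['d2']
```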

cdedonno commented 1 year ago

Thanks for reporting the edge case, I will fix that as well!