maarten-devries closed this issue 1 year ago
Hi @maarten-devries, thanks for the interest and for reporting this. I'll try to fix it asap.
@maarten-devries I think it should be fixed now, can you check?
Yes, that fixed the error. Thank you very much!
Hi @cdedonno,
Towards the end of the tutorial notebook (under Sample Embeddings), there is still one more related issue that arises when using a model loaded from disk.
scpoli_query_loaded.get_conditional_embeddings() yields ValueError: Length of passed value for obs_names is 17, but this AnnData has shape: (83, 3).
To illustrate what's going wrong:
embeddings = scpoli_query.model.embedding.weight.cpu().detach().numpy()
embeddings.shape[0] is 83
len(scpoli_query.conditions_) is 83
However, scpoli_query.obs_metadata_.shape[0] is 17 (should be 83)
For the reference model: len(scpoli_model.conditions_) is 68 (namely 83-17)
I see that in load_query_data(), there are some lines that update obs_metadata:
obs_metadata = attr_dict["obs_metadata_"]
new_obs_metadata = adata.obs.groupby(condition_key).first()
obs_metadata = pd.concat([obs_metadata, new_obs_metadata])
...
new_model.obs_metadata_ = obs_metadata
However, this new obs_metadata is not retained when the model is saved to disk.
Really appreciate your help!
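The mismatch described above can be checked with a small consistency helper. This is only a sketch on toy data, not scPoli code: check_embedding_metadata is a hypothetical name, and the arrays below stand in for the embedding weights, conditions_ and obs_metadata_ of a loaded model, assuming the metadata table is indexed by condition.

```python
import numpy as np
import pandas as pd


def check_embedding_metadata(embeddings, conditions, obs_metadata):
    """Return a list of problems; an empty list means the three objects agree.

    Hypothetical helper: the arguments stand in for the embedding weight
    matrix, the list of conditions and the per-condition metadata table.
    """
    problems = []
    if embeddings.shape[0] != len(conditions):
        problems.append(
            f"{embeddings.shape[0]} embedding rows vs {len(conditions)} conditions"
        )
    if obs_metadata.shape[0] != embeddings.shape[0]:
        problems.append(
            f"{obs_metadata.shape[0]} metadata rows vs {embeddings.shape[0]} embedding rows"
        )
    # Conditions that have an embedding but no metadata row.
    missing = [c for c in conditions if c not in obs_metadata.index]
    if missing:
        problems.append(f"{len(missing)} conditions without metadata")
    return problems


# Toy data mimicking the reported shapes: 83 conditions, 17 metadata rows.
conditions = [f"donor_{i}" for i in range(83)]
embeddings = np.zeros((83, 3))
full_metadata = pd.DataFrame({"tissue": "blood"}, index=conditions)
query_only_metadata = full_metadata.iloc[-17:]

broken = check_embedding_metadata(embeddings, conditions, query_only_metadata)
healthy = check_embedding_metadata(embeddings, conditions, full_metadata)
```

Here broken contains two entries (the row-count mismatch and the conditions without metadata), while healthy is empty.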
Thanks for reporting this. Sorry for the inconvenience, I will try to fix it asap.
Should be fixed now, can you check?
It works now, thank you very much for the quick fixes!
Just a note on a specific edge case I ran into (this may not require a fix, but I just wanted to bring the behavior to your attention): the current implementation doesn't support the case where adata_ref and adata_query contain cells from the same batch (in my case, donor_id).
In my example, I have 2 donor_ids that occur both in adata_ref and in adata_query.
Therefore, len(scpoli_query.conditions_) is 83, but scpoli_query.obs_metadata_.shape[0] is 85, and as a result, scpoli_query.get_conditional_embeddings() errors out.
Again, this seems like a minor edge case, but good to be aware of. It could be good to add a simple check that there is no overlap in the condition_key of adata_query and adata_ref.
But I'll close this issue for now :)
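A check along the lines suggested above might look like the following sketch. It uses plain pandas DataFrames in place of adata_ref.obs and adata_query.obs. The function name overlapping_conditions and the toy donor and tissue values are made up for illustration and are not part of the package.

```python
import pandas as pd


def overlapping_conditions(ref_obs, query_obs, condition_key):
    """Return the sorted condition values present in both obs tables."""
    ref_values = set(ref_obs[condition_key].dropna().unique())
    query_values = set(query_obs[condition_key].dropna().unique())
    return sorted(ref_values & query_values)


# Toy stand-ins for adata_ref.obs and adata_query.obs.
ref_obs = pd.DataFrame(
    {"donor_id": ["d1", "d1", "d2", "d3"], "tissue": ["a", "a", "b", "c"]}
)
query_obs = pd.DataFrame(
    {"donor_id": ["d3", "d4", "d4"], "tissue": ["c", "d", "d"]}
)

overlap = overlapping_conditions(ref_obs, query_obs, "donor_id")

# Concatenating per-condition metadata, as in the quoted load_query_data()
# lines, keeps one row per table for a shared donor, so the combined table
# has more rows than there are unique conditions.
ref_meta = ref_obs.groupby("donor_id").first()
query_meta = query_obs.groupby("donor_id").first()
combined = pd.concat([ref_meta, query_meta])

n_rows = combined.shape[0]            # 5 rows: d1, d2, d3, d3, d4
n_unique = combined.index.nunique()   # 4 unique donors
```

With one shared donor the combined table has one extra row, which mirrors the 85 versus 83 mismatch reported for two shared donors.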
Thanks for reporting the edge case, I will fix that as well!
Hi @cdedonno, Thank you for the great package.
I am following the scPoli tutorial and running into an issue with the .classify() function, but only if the query model has been saved to disk and then loaded. The call runs fine if scpoli_query has just been trained and never stored to disk. However, when I save the model and load it from disk first, I get the AssertionError from the following line:
assert self.prototype_training_ is True, f"Model was trained without prototypes"
This must be because self.prototype_training_ is not saved with the model. Would appreciate your help with this!
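One generic way to catch this kind of problem is a save/load round-trip check that compares attributes before and after. The sketch below uses a toy class, not the scPoli model: ToyModel, SAVED_ATTRS and lost_attributes are hypothetical names, and the pickle-based save/load only mimics a model that forgets to persist prototype_training_.

```python
import pickle
import tempfile
from pathlib import Path


class ToyModel:
    # Attributes written on save; prototype_training_ is deliberately
    # left out to mimic the behaviour reported above.
    SAVED_ATTRS = ("conditions_",)

    def __init__(self, conditions, prototype_training=None):
        self.conditions_ = list(conditions)
        self.prototype_training_ = prototype_training

    def save(self, path):
        state = {name: getattr(self, name) for name in self.SAVED_ATTRS}
        Path(path).write_bytes(pickle.dumps(state))

    @classmethod
    def load(cls, path):
        state = pickle.loads(Path(path).read_bytes())
        model = cls(state["conditions_"])
        for name, value in state.items():
            setattr(model, name, value)
        return model


def lost_attributes(before, after):
    """Names of attributes whose value changed across the round trip."""
    return sorted(
        name
        for name, value in vars(before).items()
        if vars(after).get(name) != value
    )


with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "model.pkl"
    trained = ToyModel(["d1", "d2"], prototype_training=True)
    trained.save(path)
    loaded = ToyModel.load(path)
    lost = lost_attributes(trained, loaded)

print(lost)  # ['prototype_training_']
```

In this toy setup, adding "prototype_training_" to SAVED_ATTRS makes lost come back empty.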