theislab / chemCPA

Code for "Predicting Cellular Responses to Novel Drug Perturbations at a Single-Cell Resolution", NeurIPS 2022.
https://arxiv.org/abs/2204.13545
MIT License

Usage of control cell gene expression in evaluation of chemCPA #111

Closed Dean-98543 closed 1 year ago

Dean-98543 commented 1 year ago

Hello, thanks for your great work.

While training the model, I ran into a question about the evaluation phase:

At line 361 of "chemCPA/experiments_run.py" we can see how the evaluation works: it seems that the gene expressions of all control cells in the test set (across all cell lines, such as C1, C2, C3, ...) are used as input to the model, together with one specific condition (for example cell line C1, drug D1, and dosage S1). The model output is then compared with the real gene expression of C1 cells treated with drug D1 at dosage S1.

I don't understand why the model needs the control gene expression of all cell lines. In my understanding, the model only needs the control gene expression of cell line C1.

Is my understanding correct?

siboehm commented 1 year ago

Here's the line you're referring to:

evaluation_stats["test"] = evaluate_r2(
    self.autoencoder,
    self.datasets["test_treated"],
    self.datasets["test_control"].genes,
)

So what we're passing is:

  1. The model we want to evaluate
  2. A SubDataset containing the gene expressions, drug embedding, covariate embedding, etc for the perturbed cells.
  3. A torch.Tensor containing just the gene expressions of the control cells.

Now we feed the control cell gene expression, the drug embedding, and the covariate embedding (the latter two from the SubDataset) to the model and get the prediction. Then we compare meanOf(model output for this specific condition) to meanOf(test_treated SubDataset.genes for this specific condition). We do this for all perturbations in the dataset to get our final score.

Does that make it clearer? The comments in that part of the code could be more helpful, I'll adjust them.

So you are correct: to make a prediction with chemCPA you need just A) the unperturbed gene expression of the cell you care about, B) the drug embedding of the relevant drug, and C) the covariate embedding for the cell line you care about. We pass the dataset including all cell lines and drugs just so that we can loop over them and evaluate each in turn:

    for cell_drug_dose_comb, category_count in zip(
        *np.unique(dataset.pert_categories, return_counts=True)
    ):
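
To make the procedure above concrete, here is a minimal sketch of the per-condition comparison. It is not the actual `evaluate_r2` implementation: `predict`, `evaluate_r2_sketch`, and the toy data are hypothetical stand-ins, and plain Python lists are used instead of tensors so the example is self-contained.

```python
from collections import defaultdict
from statistics import fmean


def r2_score(y_true, y_pred):
    """Coefficient of determination between two equal-length sequences."""
    mean_true = fmean(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot


def column_means(cells):
    """Per-gene mean over a list of cells (each cell is a list of gene values)."""
    return [fmean(column) for column in zip(*cells)]


def evaluate_r2_sketch(predict, treated_cells, pert_categories, control_cells):
    """Average R2 over conditions: mean prediction vs. mean measurement.

    `predict(control_cells, condition)` is a hypothetical stand-in for the
    model; it returns one predicted expression vector per control cell.
    """
    # Group the measured treated cells by their cell/drug/dose condition.
    by_condition = defaultdict(list)
    for cell, condition in zip(treated_cells, pert_categories):
        by_condition[condition].append(cell)

    scores = []
    for condition, cells in by_condition.items():
        mean_pred = column_means(predict(control_cells, condition))
        mean_true = column_means(cells)
        scores.append(r2_score(mean_true, mean_pred))
    return fmean(scores)


# Toy stand-in model: each drug shifts every gene by a fixed offset.
def toy_predict(control_cells, condition):
    shift = {"A549_D1_1.0": 1.0, "A549_D2_1.0": 2.0}[condition]
    return [[gene + shift for gene in cell] for cell in control_cells]


control = [[0.0, 1.0, 2.0], [0.2, 1.2, 2.2]]
treated = [[1.1, 2.1, 3.1], [2.1, 3.1, 4.1]]
categories = ["A549_D1_1.0", "A549_D2_1.0"]
score = evaluate_r2_sketch(toy_predict, treated, categories, control)
```

Note that the same `control` cells are reused for every condition; only the condition passed to the model changes.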
Dean-98543 commented 1 year ago

Thank you very much for your reply!

Sorry for the late reply. Perhaps my explanation was not clear.

I understand the input data of the model and its format. But as you said, to make a prediction we need A) the unperturbed gene expression of the cell we care about, B) the drug embedding of the relevant drug, and C) the covariate embedding for the cell line we care about. At line 194 of "chemCPA/train.py", however, the variable genes_control actually contains the control gene expression of all cell lines, such as A549, MCF7, and K562, whereas emb_covs only contains the covariate embedding of one cell line, e.g. A549.

mean_pred, var_pred = compute_prediction(
    autoencoder,
    genes_control,
    emb_drugs,
    emb_covs,
)

I only care about one cell line (e.g. A549), but the model input seems to contain other cell lines as well, since genes_control contains three of them (A549, MCF7, and K562). What confuses me is why genes_control contains three cell lines. In my understanding, we only need the one cell line we care about (as stated above: A) the unperturbed gene expression of the cell we care about).

siboehm commented 1 year ago

@MxMstrmn do you have a runnable environment to check this? I'm getting numba errors upon import of scanpy on my M1, and I really don't want to dig into that right now. If you don't, I can look into it but not for a few days.

@Dean-98543 There may be a bug in the eval handling of the covariate embedding. No matter what, you definitely don't need the covariate embedding of other cell lines to produce the prediction for one cell line.

Dean-98543 commented 1 year ago

Thank you for your quick reply. From the paper and the code, we have no doubt that the covariate embeddings of other cell lines are not needed. (We noticed that emb_covs only contains one cell line.)

Our question is whether genes_control should contain only cells from that same cell line, or the control gene expression of all cell lines.

MxMstrmn commented 1 year ago

Hi @Dean-98543,

Your observation is perfectly right. I was aware that genes_control contains control cells from all cell lines. We kept it this way because this is the more challenging setting, as the model has to disentangle well. However, it is perfectly fine to provide only one of the cell lines for the predictions. This should not be harmful and might actually improve predictions.

This might be a good illustration. Instead of providing any control cell you would only provide red, green, or blue:

[image]
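
In code, restricting the controls to a single cell line could look roughly like the sketch below. The names `select_controls` and `control_cell_lines` are hypothetical (the latter stands for a per-cell label list that you would have to take from your dataset), and plain lists are used here; with a torch.Tensor you would apply a boolean mask instead.

```python
def select_controls(genes_control, control_cell_lines, cell_line):
    """Keep only the control cells whose cell-line label matches `cell_line`."""
    selected = [
        genes
        for genes, label in zip(genes_control, control_cell_lines)
        if label == cell_line
    ]
    if not selected:
        raise ValueError(f"no control cells found for cell line {cell_line!r}")
    return selected


# Toy data: four control cells with two genes each, plus their labels.
genes_control = [[0.1, 0.5], [0.9, 0.3], [0.2, 0.4], [0.7, 0.8]]
control_cell_lines = ["A549", "MCF7", "A549", "K562"]

# Only these cells would then be passed to the model for an A549 prediction.
a549_controls = select_controls(genes_control, control_cell_lines, "A549")
```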
WHan-alter commented 1 year ago

Thank you for the nice work. I have a small question regarding the evaluation part; please correct me if I misunderstand something. If you compare the predicted "MeanOf" against a wet-lab measured "MeanOf", how can you claim to predict perturbations at the "single-cell" level? It seems to be a population-based evaluation (like bulk).

MxMstrmn commented 1 year ago

Hi @WHan-alter,

Thanks for engaging in the discussion! You are right in the sense that we are computing the mean over predictions. Yet, concerning perturbation effects, you are always interested in the distributional shift that a perturbation induces. While chemCPA predicts the distribution of genes at single-cell resolution, ultimately you care about the overall effect of a perturbation. The presented R2 metric accounts for that, while other metrics such as the Wasserstein distance would also be interesting to consider. The main difference to a bulk measurement is that chemCPA can deal with single-cell data and learn the induced perturbation distribution, which is impossible from bulk measurements alone.
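
To illustrate the difference between a mean-based comparison and a distributional one, here is a small sketch for a single gene. It is not part of chemCPA; `wasserstein_1d` is a hypothetical helper that uses the fact that, for two 1D samples of equal size, the 1-Wasserstein distance is the mean absolute difference of the sorted values.

```python
from statistics import fmean


def wasserstein_1d(sample_a, sample_b):
    """1-Wasserstein distance between two equal-sized 1D samples."""
    if len(sample_a) != len(sample_b):
        raise ValueError("this sketch assumes equal sample sizes")
    return fmean(abs(a - b) for a, b in zip(sorted(sample_a), sorted(sample_b)))


# Expression of one gene across four cells in two hypothetical populations.
narrow = [0.9, 1.0, 1.0, 1.1]
wide = [0.0, 0.5, 1.5, 2.0]

# The means agree, so a mean-based score sees no difference ...
mean_gap = abs(fmean(narrow) - fmean(wide))
# ... while the Wasserstein distance picks up the different spread.
distance = wasserstein_1d(narrow, wide)
```

Here `mean_gap` is (numerically) zero while `distance` is clearly positive, which is the kind of difference a mean-based R2 cannot capture.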

WHan-alter commented 1 year ago

@MxMstrmn Thank you for the clarification and prompt reply! Yes, the Wasserstein distance sounds better to me, as it operates on the two distributions. The autoencoder produces the mean and variance of a perturbation distribution, so considering only the mean during evaluation is what confused me. We could generate the "synthetic cell state" of individual cells using chemCPA with the mean and variance, right?
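
What I have in mind is roughly the following sketch. It assumes a Gaussian output with independent genes, which may not match the likelihood a given chemCPA model was trained with; `sample_synthetic_cells`, `mean_pred`, and `var_pred` here are illustrative per-gene lists, not the tensors returned by `compute_prediction`.

```python
import random


def sample_synthetic_cells(mean_pred, var_pred, n_cells, seed=0):
    """Draw synthetic cells from independent per-gene Gaussians.

    Assumption: each gene is modelled as Normal(mean, variance); the standard
    deviation is the square root of the predicted variance.
    """
    rng = random.Random(seed)  # seeded for reproducible samples
    return [
        [rng.gauss(mu, var ** 0.5) for mu, var in zip(mean_pred, var_pred)]
        for _ in range(n_cells)
    ]


# Toy prediction for three genes.
mean_pred = [1.0, 2.5, 0.3]
var_pred = [0.04, 0.25, 0.01]
synthetic_cells = sample_synthetic_cells(mean_pred, var_pred, n_cells=5)
```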

MxMstrmn commented 1 year ago

Hi @WHan-alter,

Yes, that is correct, and I hope your analysis works out. If there are no further questions, I would like to close this issue.