bowang-lab / scGPT

https://scgpt.readthedocs.io/en/latest/
MIT License

Perturbation: keeping within-batch cell pairing #57

Open krejciadam opened 1 year ago

krejciadam commented 1 year ago

Hi! I was wondering about fine-tuning the model for perturbation prediction on larger datasets containing multiple batches. Different batches can be, e.g., different cell lines. Each perturbation can exist in multiple batches, so for training we need to ensure that control and perturbed cells are only paired within the same batch. Is there a way to achieve this?
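
To make the requirement concrete, here is a minimal sketch of within-batch pairing. It assumes per-cell condition and batch labels are available as plain lists and that control cells are labeled "ctrl" (the GEARS naming convention). The function name pair_within_batch and the example labels are hypothetical and not part of scGPT or GEARS.

```python
import numpy as np


def pair_within_batch(conditions, batches, ctrl_label="ctrl", seed=0):
    """Pair each perturbed cell with a random control cell from the same batch.

    Hypothetical helper. conditions: per-cell condition labels (controls carry
    ctrl_label); batches: per-cell batch labels (e.g. cell line).
    Returns a list of (perturbed_index, control_index) tuples.
    """
    rng = np.random.default_rng(seed)
    conditions = np.asarray(conditions)
    batches = np.asarray(batches)
    pairs = []
    for batch in np.unique(batches):
        in_batch = batches == batch
        ctrl_idx = np.flatnonzero(in_batch & (conditions == ctrl_label))
        pert_idx = np.flatnonzero(in_batch & (conditions != ctrl_label))
        if pert_idx.size == 0:
            continue  # control-only batch: nothing to pair
        if ctrl_idx.size == 0:
            raise ValueError(f"batch {batch!r} has perturbed cells but no controls")
        # Draw controls with replacement so batches with few controls still work;
        # candidates are restricted to the current batch only.
        drawn = rng.choice(ctrl_idx, size=pert_idx.size, replace=True)
        pairs.extend(zip(pert_idx.tolist(), drawn.tolist()))
    return pairs


conditions = ["ctrl", "ctrl", "KLF1+ctrl", "ctrl", "KLF1+ctrl", "MYC+ctrl"]
batches = ["A", "A", "A", "B", "B", "B"]
print(pair_within_batch(conditions, batches))
```

Sampling with replacement mirrors how a perturbed cell is usually matched to a random control; the only change is that the candidate pool is the control cells of the same batch.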

From the GEARS code (https://github.com/snap-stanford/GEARS/blob/master/gears/pertdata.py), it seems to me that this is not implemented, but I might be wrong.

If I'm right, I can think of an emergency solution: split my dataset by batch (or even further, into many sub-batches) and then manually train for one epoch on each of these splits in turn. This feels a little desperate, though. Do you perhaps have any suggestions for a better solution? Thank you!

subercui commented 1 year ago

Hi @krejciadam, thank you, and that sounds like an interesting new application!

I think you are right about pairing the cells within batches/cell lines. I don't think there is an implementation for that yet. Your plan sounds doable. That said, I would suggest another strategy that may be easier to implement:

  1. Separate the data into one data file per batch.
  2. Preprocess each file separately using the current PertData class. The current processing implementation will eventually put them into PyTorch Geometric DataLoaders.
  3. Write a wrapper loader that randomly picks data from each DataLoader. This is the step that differs from your initial strategy: I don't think you need to put different cell lines into different training epochs. You can train on them mixed together, though it should also be fine to separate them if you have other reasons to.
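
Step 3 could look roughly like this minimal sketch. MixedBatchLoader is a hypothetical name, not part of scGPT or GEARS. It assumes each per-batch loader supports len() and yields exactly that many mini-batches, which holds for a standard PyTorch or PyTorch Geometric DataLoader over a map-style dataset. In practice the values would be the training loaders of each per-batch PertData object (I believe this is pert_data.dataloader["train_loader"] after get_dataloader, but please check against your GEARS version).

```python
import random
from typing import Any, Dict, Iterator, Tuple


class MixedBatchLoader:
    """Interleave mini-batches from several per-batch loaders at random.

    Hypothetical wrapper. Each mini-batch still comes from exactly one
    underlying loader, so control/perturbed pairing done inside each
    per-batch dataset is preserved. One pass yields every mini-batch of
    every loader exactly once.
    """

    def __init__(self, loaders: Dict[str, Any], seed: int = 0):
        # e.g. {"A": loader_A, "B": loader_B}; every loader must support len()
        self.loaders = loaders
        self.rng = random.Random(seed)

    def __len__(self) -> int:
        return sum(len(loader) for loader in self.loaders.values())

    def __iter__(self) -> Iterator[Tuple[str, Any]]:
        # One slot per mini-batch of each loader, reshuffled every epoch, so
        # loaders are drawn in proportion to how many mini-batches they hold.
        schedule = [name for name, loader in self.loaders.items() for _ in range(len(loader))]
        self.rng.shuffle(schedule)
        iterators = {name: iter(loader) for name, loader in self.loaders.items()}
        for name in schedule:
            # Yield the source name too, e.g. for logging per-cell-line loss.
            yield name, next(iterators[name])


# Plain lists stand in for real DataLoaders here.
mixed = MixedBatchLoader({"A": ["a0", "a1", "a2"], "B": ["b0"]}, seed=1)
for name, batch in mixed:
    print(name, batch)
```

Because the training loop receives (name, batch) tuples, it needs to unpack them before passing the batch to the model. Shuffling a schedule of loader names, rather than picking loaders uniformly, keeps small cell lines from being exhausted long before large ones.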

I hope this helps with your implementation.

BTW, if you'd like to share it with others, I do think this could be a valuable extension! Feel free to open a PR for it.

krejciadam commented 1 year ago

Hi!
Thank you for the suggestion. This is indeed faster than changing all the GEARS routines to get this behavior, although I will need to do that anyway, because the GEARS train/validation/test split logic is insufficient for this and other scenarios.

FYI, one of my motivations for this was to use scGPT to "transfer" perturbations to new cell types, and so far, unfortunately, it seems to fail utterly at this task. Say I have a dataset consisting of the same set of perturbations in three different cell lines A, B and C. I train the model on A and B, using the correct within-batch pairing as mentioned. Now I want to input wild-type cells of type C and predict their perturbed phenotype. During training, the model has seen all of the perturbations I want to predict in cell types A and B, but it has seen no cells of type C.

It turns out that the output is driven only by the chosen perturbation and disregards the input transcriptome pretty much completely. What I mean is this: perturbed transcriptomes of a cell line are typically still quite similar to the wild-type transcriptomes of that cell line, because perturbations do not change cell identity completely. However, for any perturbation, the transcriptome predicted from wild-type cells of cell line C is similar to the wild-type transcriptomes of cell lines A and B. The identity of C is disregarded and the model just outputs what it has seen for A and B. I can use any other cell type D, E, ... as input and the result is quite similar. I understand this was not the original intention, so maybe the objective would need to be defined differently for this kind of task, although as of now I'm not sure how.
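
One way to put a number on this observation is sketched below, assuming you already have mean expression profiles over the same genes: the mean of the predicted cells for one perturbation in line C, and the mean wild-type (control) profile of each line. The function identity_similarity is hypothetical. Both sides are centered on the average control profile across lines so that genes highly expressed everywhere do not dominate the correlation.

```python
import numpy as np


def identity_similarity(pred_profile, ctrl_profiles):
    """Correlate a mean predicted profile with each cell line's mean control profile.

    Hypothetical diagnostic. Both sides are centered on the average control
    profile across all lines, so only line-specific signal remains in the
    Pearson correlation.
    """
    lines = list(ctrl_profiles)
    ctrl = np.stack([np.asarray(ctrl_profiles[line], dtype=float) for line in lines])
    center = ctrl.mean(axis=0)
    pred = np.asarray(pred_profile, dtype=float) - center
    return {
        line: float(np.corrcoef(pred, ctrl[i] - center)[0, 1])
        for i, line in enumerate(lines)
    }


# Synthetic stand-ins for per-line mean control profiles over 200 genes.
rng = np.random.default_rng(0)
A, B, C = rng.normal(size=(3, 200))
ctrl_means = {"A": A, "B": B, "C": C}
print(identity_similarity((A + B) / 2, ctrl_means))  # a prediction that ignores C
```

If the model respected the identity of C, the prediction for C should score highest for C. If it collapses to the A/B average, as described above, then with three lines the score for C is exactly -1, since the centered prediction is a negative multiple of the centered C profile.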