bunnech / cellot

Learning Single-Cell Perturbation Responses using Neural Optimal Transport
BSD 3-Clause "New" or "Revised" License
2nd Generating prediction after the model is trained. #13

Closed by uddamvathanak 6 months ago

uddamvathanak commented 6 months ago

Dear Author,

I have taken an interest in the CellOT package and find it very interesting. After trying it for a while, I can't find a function that generates predictions from the trained model.

For example, I want to use a different split for testing and make predictions on that split instead of a random split.

Is there such a function?

Best regards,

Rom Uddamvathanak

bunnech commented 6 months ago

In general, you can follow the code in the evaluation script (https://github.com/bunnech/cellot/blob/main/scripts/evaluate.py). Make sure that you are loading the right model, i.e., if you use an autoencoder embedding, load both the trained autoencoder and the trained CellOT model. Also make sure that your configs are set correctly, i.e., the config file contains the correct path to the trained autoencoder. Encoding the data is then handled within the load_data function, specifically in this line, so your test data of interest needs to be passed here. Once you have loaded all your inputs, including models and data, you transport your cells of interest into the perturbed state as done here, and lastly decode your predictions as done in this line.

Hope this helps!
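
The three steps described above (encode, transport, decode) can be sketched as follows. This is only a toy illustration in NumPy, not the code from scripts/evaluate.py: `encode`, `decode`, `transport`, `basis`, `A` and `b` are hypothetical stand-ins for the trained autoencoder and the trained CellOT map, and the random matrix stands in for the held-out cells.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-in for a trained autoencoder: an orthonormal linear map
# from 10 "genes" to 3 latent dimensions (not the CellOT/scGen architecture).
basis, _ = np.linalg.qr(rng.normal(size=(10, 3)))


def encode(x):
    """Project expression profiles into the latent space."""
    return x @ basis


def decode(z):
    """Map latent points back to expression space."""
    return z @ basis.T


# Hypothetical stand-in for the trained transport map: the gradient of the
# convex quadratic potential 0.5 * z^T A z + b^T z, which is A z + b.
A = np.diag([1.5, 1.0, 0.5])
b = np.array([0.2, -0.1, 0.0])


def transport(z):
    """Push latent control cells toward the perturbed state."""
    return z @ A + b


control_cells = rng.normal(size=(5, 10))  # toy held-out cells of interest
latent = encode(control_cells)            # 1. encode
latent_pred = transport(latent)           # 2. transport in the latent space
predicted = decode(latent_pred)           # 3. decode the predictions
print(predicted.shape)
```

The point of the sketch is the ordering: the transport map only ever sees latent coordinates, and the decoder is applied once, at the very end.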

uddamvathanak commented 6 months ago

Hi bunnech,

Thanks for the prompt reply. I'm currently using just the CellOT model, trained on log-normalised gene expression. I found model.pt, defined a g model, and loaded the trained state into it to obtain the trained g model.

I noticed that the range (min and max) of the output values is quite large.

Here is the code I used to generate the perturbed gene expression. Please correct me if I am wrong.

# `model` is the checkpoint loaded from model.pt (loading step not shown here)
g = ICNN(input_dim=1000, hidden_units=[64, 64, 64, 64])
f = ICNN(input_dim=1000, hidden_units=[64, 64, 64, 64])
f.load_state_dict(model["f_state"])
g.load_state_dict(model["g_state"])
g.eval()

# prepare the control cells that will be transported into the perturbed state
# (CPUTensor is not defined in this snippet)
train_data = CPUTensor(adata[adata.obs["condition"] == "ctrl"].X)
train_data.requires_grad_(True)  # the input must track gradients before calling transport

train_results = g.transport(train_data).detach().numpy()

bunnech commented 6 months ago

For gene expression data, you always need to project your data into an autoencoder or PCA space. See the config files for sciplex3 data, for example. Then you run CellOT in the latent space.

uddamvathanak commented 6 months ago

Thanks for the clarification. Is the autoencoder you mentioned the latent space from the scGen model?

python ./scripts/train.py \
    --outdir ./results/scrna-sciplex3/drug-${drug}/model-${model} \
    --config ./configs/tasks/sciplex3.yaml \
    --config ./configs/models/${model}.yaml \
    --config.data.target $drug \
    --config.data.ae_emb.path ./results/scrna-sciplex3/drug-${drug}/model-scgen

bunnech commented 6 months ago

Yes, exactly. It can be any autoencoder, but we used the same autoencoder architecture as scGen to allow a better comparison.

uddamvathanak commented 6 months ago

Thank you so much for the clear explanation. I will try the PCA projection and map back to gene space after the mapping is done.

It seems to me that I need to save the PCA in adata.obsm["X_pca"].

Thanks for developing a great tool :D
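
The PCA round trip mentioned above (project, map in the reduced space, then reverse) could look roughly like this. It is a minimal NumPy sketch under stated assumptions: `fit_pca`, `to_pca` and `from_pca` are hypothetical helper names, the plain dict `obsm` stands in for `adata.obsm`, the random matrix stands in for real expression data, and the transport step is left as a placeholder.

```python
import numpy as np


def fit_pca(x, n_components):
    """Return the feature means and the top principal axes of x."""
    mean = x.mean(axis=0)
    # Rows of vt are the principal axes, ordered by explained variance.
    _, _, vt = np.linalg.svd(x - mean, full_matrices=False)
    return mean, vt[:n_components]


def to_pca(x, mean, components):
    """Project cells into PCA space."""
    return (x - mean) @ components.T


def from_pca(z, mean, components):
    """Map PCA coordinates back to gene space."""
    return z @ components + mean


rng = np.random.default_rng(0)
expression = rng.normal(size=(50, 8))  # toy cells x genes matrix

mean, components = fit_pca(expression, n_components=8)
obsm = {"X_pca": to_pca(expression, mean, components)}  # stand-in for adata.obsm

mapped = obsm["X_pca"]  # placeholder: the trained transport map would be applied here
recovered = from_pca(mapped, mean, components)
```

With all components kept, as here, the inverse is exact; with fewer components it is only an approximation. The `mean` and `components` fitted on the training cells need to be kept so that the same projection can be applied to the test split.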