LiQian-XC / sctour

A deep learning architecture for robust inference and accurate prediction of cellular dynamics
https://sctour.readthedocs.io
MIT License

Reusing trained models #4

Closed GWMcElfresh closed 2 years ago

GWMcElfresh commented 2 years ago

Hi, thank you for writing this package. It's extremely easy to get up and running and seems to do wonderfully on our data.

Admittedly, I'm new to PyTorch and I usually stick to using R, so there might be something obvious I'm missing. I'd like to train a model on one dataset (say, sorted cells) and then apply that model to predict the latent space/pseudotime for another dataset.

In the tutorials you show how to apply a trained model to new data, but could you point me in the right direction on how to load a model after saving it using tnode.save_model()?

LiQian-XC commented 2 years ago

@GWMcElfresh thanks for using scTour. Please see below the code to reuse the saved model.

First, train a model on the training data and then save the trained model (for example, using the 'nb' mode):

    import sctour as sct

    tnode = sct.train.Trainer(train_adata, loss_mode='nb')
    tnode.train()
    train_adata.obs['ptime'] = tnode.get_time()  # please run get_time() before saving the model
    tnode.save_model('./', 'test_model')  # the first parameter is the directory, the second is the model name prefix

When reusing the saved model to predict the properties of a different dataset:

You need to initialise the Trainer class first, providing the training adata and the training mode, so that you can use the predict function within it:

    tnode = sct.train.Trainer(train_adata, loss_mode='nb')

Then you can use the saved model on another dataset (for example test_adata). Make sure this dataset contains all the genes that are included in the training data:

    pred_t, mix_zs, zs, pred_zs = tnode.predict_time(test_adata, get_ltsp=True, model='./test_model.pth')
    test_adata.obs['ptime'] = pred_t
    test_adata.obsm['X_TNODE'] = mix_zs
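The requirement that the new dataset contain every training gene can be checked before calling predict_time. The following is a minimal sketch, not part of scTour: the helper name `missing_training_genes` and the example gene lists are hypothetical, and the function works on plain sequences of gene names.

```python
def missing_training_genes(train_genes, test_genes):
    """Return the training genes that are absent from the test dataset.

    Hypothetical helper (not an scTour function); the result keeps the
    order of the training gene list.
    """
    available = set(test_genes)
    return [gene for gene in train_genes if gene not in available]


# Illustrative gene names only; real data would supply its own lists.
train_genes = ["Sox2", "Pax6", "Neurod1"]
test_genes = ["Pax6", "Sox2", "Gapdh"]

missing = missing_training_genes(train_genes, test_genes)
if missing:
    print(f"{len(missing)} training gene(s) missing from the test data: {missing}")
```

With AnnData objects, gene names are stored in `var_names`, so the call would likely be `missing_training_genes(train_adata.var_names, test_adata.var_names)`. An empty result means the test data covers all training genes.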

Hope it's clear to you. Please let me know if it does not work.

GWMcElfresh commented 2 years ago

That worked! Thank you.

Reading through the documentation, I thought:

model – The model used to predict the pseudotime. Only provided when using the saved model.

referred to some kind of object loaded in memory, but now I see the string specification in the argument list. Thank you again!
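Because the model argument is a file path rather than an in-memory object, it can help to build that path from the same directory and prefix that were passed to save_model(). The sketch below assumes the '.pth' suffix seen in the example in this thread; `saved_model_path` is a hypothetical helper, not an scTour function.

```python
import os


def saved_model_path(save_dir, prefix, suffix=".pth"):
    """Build the path of a saved model file from its directory and name prefix.

    Hypothetical helper; the '.pth' suffix is an assumption based on the
    './test_model.pth' path used in this thread.
    """
    return os.path.join(save_dir, prefix + suffix)


model_path = saved_model_path("./", "test_model")

# Fail early with a readable message if the file is not where we expect it.
if not os.path.isfile(model_path):
    print(f"No saved model found at {model_path}")
```

The resulting string can then be passed as the model argument of predict_time.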