theislab / scarches

Reference mapping for single-cell genomics
https://docs.scarches.org/en/latest/
BSD 3-Clause "New" or "Revised" License

loading scArches from local and train with new data #19

Closed: notimenocall closed this issue 3 years ago

notimenocall commented 3 years ago

Hi team, I trained a network successfully following the tutorial and saved it locally. When I reload the network and train it on a query dataset with the train() function, my kernel restarts after a couple of warning messages from TensorFlow, such as: "The name tf.assign is deprecated. Please use tf.compat.v1.assign instead." I'm not sure whether this is related to scArches, but have you seen it before? Since there is no tutorial for loading a network from local storage, I would also like to confirm that what I did is correct. Any advice is appreciated!

```python
import scarches as sca

# adata and condition_key are defined earlier in the session (not shown in this post).
config_path = './models/scArches_v4_sse/IBD_epi_reference/scArches.json'

# Rebuild the reference network from its saved config, then restore its weights.
pre_trained_scArches = sca.models.scArches.from_config(config_path, construct=True, compile=True)
pre_trained_scArches.model_path = './models/scArches_v4_sse/IBD_epi_reference/'
pre_trained_scArches.task_name = 'IBD_epi_reference'
pre_trained_scArches.restore_model_weights(compile=True)

# Extend the network with the query conditions.
target_conditions = adata.obs[condition_key].unique().tolist()
new_network = sca.operate(pre_trained_scArches, new_conditions=target_conditions, new_task_name="IBD_epi_UC_CD")
new_network.model_path = './models/scArches_v4_sse/IBD_epi_UC_CD/'

# This is where the kernel restarts.
new_network.train(adata, condition_key=condition_key, batch_size=128, n_epochs=100)
```
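Before restoring a saved network, it can help to confirm that the config file is where the code expects it and that it parses. The following is a minimal sketch using only the standard library; `check_saved_config` is a hypothetical helper, not part of scArches, and it assumes only that the config is a JSON file, as its `.json` extension suggests.

```python
import json
from pathlib import Path


def check_saved_config(config_path):
    """Return the parsed config, or raise a clear error if it is missing.

    Hypothetical helper for illustration; not part of scArches.
    """
    path = Path(config_path)
    if not path.is_file():
        raise FileNotFoundError(f"No config file at {path.resolve()}")
    with path.open() as fh:
        # json.load raises json.JSONDecodeError if the file is corrupt.
        return json.load(fh)


# Usage with the path from the post above (left commented out here):
# config = check_saved_config('./models/scArches_v4_sse/IBD_epi_reference/scArches.json')
# print(sorted(config))
```

A check like this only rules out a wrong path or a damaged file; it says nothing about whether the weights themselves restore correctly.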

M0hammadL commented 3 years ago

Hi,

I have not seen this problem. Does it happen only when you train on the query data, or also when you train the reference?

I just tried training a model, saving it, restoring it, and then adding query data, and I did not get such an error:

```python
network = sca.models.scArches(
    task_name='pancreas_reference',
    x_dimension=adata.shape[1],
    z_dimension=10,
    architecture=[128, 128],
    gene_names=adata.var_names.tolist(),
    conditions=reference_batch_labels,
    alpha=0.001,
    loss_fn='nb',
    model_path="./models/scArches/",
    seed=20,
)

# Train on the reference data, then restore the saved weights.
network.train(reference_adata, n_epochs=10, condition_key=condition_key, batch_size=128)
network.restore_model_weights()

# Add the query conditions and train on the query data.
new_network = sca.operate(network, new_task_name="pancreas_query", new_conditions=target_conditions)
new_network.train(query_adata, condition_key=condition_key, batch_size=128, n_epochs=40)
```

notimenocall commented 3 years ago

Thank you very much for testing this. I just tried a couple of things, and the issue seems to be related to the loss_fn option. With loss_fn = "nb", it works fine. With loss_fn = "sse", the issue comes up: the first network.train() on the reference data (31k cells) works, but the second training on the query data (2.7k cells) fails.
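When a run restarts the kernel instead of raising an error, one way to compare settings such as the two loss functions is to run each attempt in a child process, so that a hard crash ends only the child. The sketch below uses only the standard library; `run_isolated` and the script name `train_query.py` are hypothetical and not part of scArches.

```python
import subprocess
import sys


def run_isolated(script_args, timeout=None):
    """Run Python with the given arguments in a child process.

    Hypothetical helper: if the child crashes hard, the calling
    notebook kernel survives and can inspect the exit code and stderr.
    """
    result = subprocess.run(
        [sys.executable, *script_args],
        capture_output=True,
        text=True,
        timeout=timeout,
    )
    return result.returncode, result.stderr


# Usage sketch; train_query.py is a hypothetical script that takes the loss name:
# for loss_fn in ("nb", "sse"):
#     code, err = run_isolated(["train_query.py", loss_fn])
#     print(loss_fn, code, err[-500:])
```

On POSIX systems, a negative return code means the child was terminated by a signal (for example, -11 for SIGSEGV or -9 for SIGKILL), which helps distinguish a crash from an ordinary Python exception.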

notimenocall commented 3 years ago

There may also be a bug when running with loss_fn = "zinb": network.train() fails with the error 'Activation' object has no attribute '__name__'.
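As a general Python note, and not a diagnosis of the scArches code: an error of this shape typically arises when code reads `__name__` from a class instance rather than from a plain function. The sketch below reproduces the pattern with a stand-in `Activation` class (hypothetical, not the Keras layer) and shows a defensive lookup; `callable_name` is likewise an illustrative helper.

```python
def relu(x):
    return max(x, 0.0)


class Activation:
    """Stand-in callable object; NOT the Keras layer, just an illustration."""

    def __init__(self, fn):
        self.fn = fn

    def __call__(self, x):
        return self.fn(x)


def callable_name(obj):
    """Return a readable name for a function or a callable instance."""
    # Functions carry __name__; instances do not, so fall back to the class name.
    return getattr(obj, "__name__", type(obj).__name__)


act = Activation(relu)
print(relu.__name__)       # relu
print(callable_name(act))  # Activation
# Reading act.__name__ directly would raise:
# AttributeError: 'Activation' object has no attribute '__name__'
```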