theislab / scarches

Reference mapping for single-cell genomics
https://docs.scarches.org/en/latest/
BSD 3-Clause "New" or "Revised" License

Issue in annotating cell types of unlabelled query data by scPoli #226

Open xr-song opened 8 months ago

xr-song commented 8 months ago

First of all, thank you for developing this impressive toolkit. Much appreciated!

I encountered a problem when using scPoli for cell-type annotation. The first time, I trained the model on my reference data and then loaded it for prediction on the query in the same run; the annotation looked fairly reasonable, with 8 cell types in total (the cell types I expect in my query data). The second time, I loaded the already trained model directly and called the classify function, and the result was completely different, with only 2 annotated cell types in total. I double-checked and reran it with slight changes to the parameters, but the issue was not resolved. Could you please help me identify the cause? Here is the main part of my code:

```python
# Imports assumed from context; they were not part of the original snippet.
import scanpy as sc
from scarches.models.scpoli import scPoli

print('Reading h5ad...')
ref_adata = sc.read_h5ad(path_ref+ref_file)
query_adata = process_query(sc.read(path_query))

# Restrict both datasets to the genes they share
common_genes = list(set(query_adata.var_names).intersection(ref_adata.var_names))
ref_adata = ref_adata[:,common_genes]
query_adata = query_adata[:,common_genes]

print('Normalizing reference data...')
sc.pp.normalize_total(ref_adata, target_sum=1e6)
sc.pp.log1p(ref_adata, base=2)
print(ref_adata)

ref_adata.obs['batch'] = ref_adata.obs.sample_id
cell_type_key = ['cell_type','supercluster_term']
condition_key = 'batch'

early_stopping_kwargs = {
    "early_stopping_metric": "val_prototype_loss",
    "mode": "min",
    "threshold": 0,
    "patience": 20,
    "reduce_lr": True,
    "lr_patience": 13,
    "lr_factor": 0.1,
}

scpoli_model = scPoli(
    adata=ref_adata,
    condition_keys=condition_key,
    cell_type_keys=cell_type_key,
    embedding_dims=5,
    recon_loss='nb',
)

print('Start training...')
scpoli_model.train(
    n_epochs=50,
    pretraining_epochs=40,
    early_stopping_kwargs=early_stopping_kwargs,
    eta=5,
)

print('Saving model...')
scpoli_model.save(model_dir, overwrite=True, save_anndata=True)

# Process query data
query_adata.obs['cell_type'] = 'unlabeled'
query_adata.obs['supercluster_term'] = 'unlabeled'
print('Normalizing query data...')
sc.pp.normalize_total(query_adata, target_sum=1e6)
sc.pp.log1p(query_adata, base=2)
print(query_adata)

# Load query data and model
scpoli_query = scPoli.load_query_data(
    adata=query_adata,
    reference_model=model_dir,
    labeled_indices=[],
)

# Train on query (currently disabled)
#scpoli_query.train(
#    n_epochs=30,
#    pretraining_epochs=20,
#    eta=10
#)

# Classification
results_dict = scpoli_query.classify(query_adata, scale_uncertainties=True)

# Get latent representation of query data
data_latent = scpoli_query.get_latent(
    query_adata,
    mean=True
)

adata_latent = sc.AnnData(data_latent)
adata_latent.obs = query_adata.obs.copy()

# Attach predictions and uncertainties for both label levels
adata_latent.obs['cell_type_pred'] = results_dict['cell_type']['preds'].tolist()
adata_latent.obs['cell_type_uncert'] = results_dict['cell_type']['uncert'].tolist()
adata_latent.obs['supercluster_term_pred'] = results_dict['supercluster_term']['preds'].tolist()
adata_latent.obs['supercluster_term_uncert'] = results_dict['supercluster_term']['uncert'].tolist()

# Flag predictions whose uncertainty is below 0.2
adata_latent.obs['cell_type_uncert_pass'] = ['T' if x < 0.2 else 'F' for x in adata_latent.obs['cell_type_uncert']]
adata_latent.obs['supercluster_term_uncert_pass'] = ['T' if x < 0.2 else 'F' for x in adata_latent.obs['supercluster_term_uncert']]

print('Predicted cell types of query:')
print(set(adata_latent.obs.cell_type_pred))
```
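
One thing that may be worth ruling out, offered as an assumption rather than a confirmed diagnosis: `list(set(...))` returns strings in an order that can change between Python sessions, because string hashing is randomized per process by default. If the gene order differs between the session that trained the model and a later session that only loads it, the query columns may no longer line up with what the model was trained on. A minimal sketch of an order-preserving intersection follows; the helper name `ordered_common_genes` is hypothetical, and plain lists stand in for `ref_adata.var_names` and `query_adata.var_names`.

```python
def ordered_common_genes(ref_genes, query_genes):
    """Return the genes present in both inputs, in the reference's order."""
    query_set = set(query_genes)  # the set is used only for fast membership tests
    return [gene for gene in ref_genes if gene in query_set]


# Toy gene lists standing in for ref_adata.var_names and query_adata.var_names.
ref_genes = ["GAPDH", "ACTB", "SOX2", "PAX6"]
query_genes = ["PAX6", "GAPDH", "NEUROD1", "SOX2"]

common_genes = ordered_common_genes(ref_genes, query_genes)
print(common_genes)  # ['GAPDH', 'SOX2', 'PAX6']
```

In the script above, this would take the place of the `list(set(...))` line, so that every run subsets both objects in the same column order.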

Another question: how do (1) normalization, log transformation, and subsetting to highly variable genes, and (2) the optional step of training on the unlabelled query data, influence the resulting model?

Thank you in advance!
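
On the normalization question, one point that may be worth checking, stated as an assumption about this setup: a negative binomial reconstruction loss (`recon_loss='nb'`) models integer counts, whereas the script applies `normalize_total` and `log1p` before training. The sketch below tests whether values still look like raw counts. The helper `looks_like_counts` is hypothetical and works on a plain iterable of numbers; for an AnnData object, one would pass a flattened sample of `adata.X`.

```python
import math


def looks_like_counts(values, tolerance=1e-8):
    """Return True if every value is a non-negative integer (within tolerance)."""
    for value in values:
        if value < 0 or abs(value - round(value)) > tolerance:
            return False
    return True


# One toy cell with raw counts for six genes.
raw = [0, 3, 12, 1, 0, 7]

# Mimic normalize_total(target_sum=1e6) followed by log1p(base=2) on that cell.
total = sum(raw)
log_normalized = [math.log2(1 + value / total * 1e6) for value in raw]

print(looks_like_counts(raw))             # True
print(looks_like_counts(log_normalized))  # False
```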

yojetsharma commented 1 month ago

Hi, I have a question about the code above: I’m trying to map a query whose obs columns differ from those of the reference. How should I go about it? The only columns that roughly correspond are Leiden (in the query) and CellClass (in the reference).
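
The script earlier in this thread suggests one pattern for a query without matching annotation columns: create the reference's cell-type columns in the query and fill them with a placeholder such as 'unlabeled'. Below is a minimal sketch of that step, assuming the reference model was trained with `CellClass` as its cell-type key. The helper `add_placeholder_labels` is hypothetical, and a plain dict stands in for `query_adata.obs`; the same `in` test and item assignment also work on a pandas DataFrame.

```python
def add_placeholder_labels(obs, cell_type_keys, placeholder="unlabeled"):
    """Add each missing cell-type column to obs, filled with a placeholder."""
    if isinstance(cell_type_keys, str):
        cell_type_keys = [cell_type_keys]
    for key in cell_type_keys:
        if key not in obs:  # leave existing columns untouched
            obs[key] = placeholder
    return obs


# A dict standing in for query_adata.obs, which only has Leiden clusters.
query_obs = {"Leiden": ["0", "1", "1", "2"]}
add_placeholder_labels(query_obs, "CellClass")
print(sorted(query_obs))  # ['CellClass', 'Leiden']
```

The Leiden clusters stay in the query as ordinary metadata and do not need to match anything in the reference. Presumably the query also needs the column named by the reference's condition key (`batch` in the script above).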