theislab / scarches

Reference mapping for single-cell genomics
https://docs.scarches.org/en/latest/
BSD 3-Clause "New" or "Revised" License
335 stars 51 forks source link

query and reference integration fails #254

Open SarahAsbury opened 3 weeks ago

SarahAsbury commented 3 weeks ago

I have been unable to achieve a successful integration of any of my 3 separate query datasets with my reference data. Both the query and the reference contain only T cells. My reference dataset was derived with TotalVI (i.e. it contains RNA + protein), whereas each query is scRNA-seq data only. (Image attached.)

Are there any recommendations on what might be the issue?

SarahAsbury commented 3 weeks ago

For reference, the script is below:

# ----- libraries -----
import os
import scanpy as sc
import anndata
import torch
import scarches as sca
import matplotlib.pyplot as plt
import numpy as np
import scvi as scv
import pandas as pd
import argparse
from pathlib import Path

# ----- session params -----
# Configure scanpy's figure defaults in ONE call. Every call to
# set_figure_params re-applies its own defaults for the arguments that are not
# passed, so splitting the settings over several calls silently reverts the
# earlier ones (dpi back to the default, frameon back to True).
sc.set_figure_params(dpi=200, frameon=False, figsize=(4, 4))
# Compact tensor printing: 3 decimals, no scientific notation, 7 edge items.
torch.set_printoptions(precision=3, sci_mode=False, edgeitems=7)

# ----- params -----
# Command-line interface of the script. The options are declared as a table of
# (flag, add_argument keyword arguments) and registered in table order, so the
# order shown by --help follows the table.
parser = argparse.ArgumentParser(
    description='Run scArches. Provide either refernece modelDir or path to each model component. i.e. model, modelGenes, modelProteins, and modelAdata.'
)

_ARGUMENT_SPECS = (
    # query input
    ("--h5", {
        "help": "Full path to h5 file that will be integrated to model.",
        "required": True,
    }),
    # reference model: either the directory or the individual components
    ("--modelDir", {
        "help": "Full path to directory contaning model, modelGenes, modelProtein, and modelAdata.",
    }),
    ("--model", {
        "help": "Full path to the stored model for integration.",
    }),
    ("--modelGenes", {
        "help": "Full path to csv containing genes used in the model.",
    }),
    ("--modelProteins", {
        "help": "Full path to csv containing proteins used in the model.",
        "default": None,
    }),
    ("--modelAdata", {
        "help": "Full path to anndata used to train the model",
    }),
    # outputs
    ("--outputPrefix", {
        "help": "Prefix to be appended to output files.",
        "required": True,
    }),
    ("--output", {
        "help": "Full path to output directory",
        "required": True,
    }),
    # behaviour switches and column names
    ("--imputeProtein", {
        "help": "If flag provided, proteins will be imputed. Only viable for TotalVI models.",
        "action": "store_true",
    }),
    ("--sampleCol", {
        "help": "Name of sample column in h5/seurat data.",
        "required": True,
    }),
    ("--annotCol", {
        "help": "Name of annotation column in h5/seurat data.",
        "required": True,
    }),
    ("--modelSampleCol", {
        "help": "Name of sample column in original model adata.",
        "required": True,
    }),
)

for _flag, _options in _ARGUMENT_SPECS:
    parser.add_argument(_flag, **_options)

# # for run
# Parses sys.argv; argparse exits with a usage message if a required option
# is missing.
params = parser.parse_args()

# reference model paths
# The rest of the script needs four reference components: the saved model, the
# AnnData it was trained on, the gene table and the protein table.
# When --modelDir is given it takes precedence: each component is looked up by
# filename pattern inside that directory and overrides any individually
# supplied path.
if params.modelDir is not None:
    modelDir = Path(params.modelDir)
    if not modelDir.is_dir():
        parser.error("--modelDir is not an existing directory: {}".format(modelDir))
    for _attr, _pattern in (
        ("model", "*model"),
        ("modelAdata", "*adata.h5ad"),
        ("modelGenes", "*genes.csv"),
        ("modelProteins", "*proteins.csv"),
    ):
        # sorted() makes the pick deterministic when several entries match;
        # glob() alone yields them in arbitrary order.
        _matches = sorted(modelDir.glob(_pattern))
        if not _matches:
            # An empty glob used to surface as a bare StopIteration.
            parser.error(
                "no entry matching '{}' found in --modelDir {}".format(_pattern, modelDir)
            )
        setattr(params, _attr, _matches[0].as_posix())

# Fail early with a usage message instead of a TypeError further down, where
# the paths are passed to os.path.basename / pd.read_csv / sc.read_h5ad.
_unset = [
    "--" + _name
    for _name in ("model", "modelAdata", "modelGenes", "modelProteins")
    if getattr(params, _name) is None
]
if _unset:
    parser.error(
        "provide --modelDir or every individual model path; missing: {}".format(
            ", ".join(_unset)
        )
    )

print(params)

# ----- internal params -----
# scvi settings
# Normalisation options, see
# https://docs.scvi-tools.org/en/stable/api/reference/scvi.module.TOTALVAE.html
# They are set to the totalVI defaults because that is what was used to
# generate the models; note that the scArches tutorial uses the reverse.
# NOTE(review): arches_params is not referenced anywhere else in the visible
# script - confirm whether it should be passed to a model call.
arches_params = {
    "use_layer_norm": "none",
    "use_batch_norm": "both",
}
# fixed seed for scvi
scv.settings.seed = 21022024

# output fn
# Output layout under params.output:
#   plots/                      -> UMAP figures
#   <prefix>_totalvi_model/     -> updated model, its AnnData, gene/protein tables
# exist_ok=True replaces the separate "exists?" check: it is race-free and
# still raises if the path exists but is not a directory.
plot_dir = os.path.join(params.output, "plots")
os.makedirs(plot_dir, exist_ok=True)

output_model_dir = os.path.join(params.output, "{}_totalvi_model".format(params.outputPrefix))
os.makedirs(output_model_dir, exist_ok=True)

# Every file the script writes, keyed by purpose.
output_fn = {
    # results
    "h5": os.path.join(params.output, "{}_scArches.h5ad".format(params.outputPrefix)),
    # plots
    "umap_annot": os.path.join(plot_dir, "{}_scArches_annot_umap.tiff".format(params.outputPrefix)),
    "umap_samples": os.path.join(plot_dir, "{}_scArches_samples_umap.tiff".format(params.outputPrefix)),
    "umap_integrated": os.path.join(plot_dir, "{}_scArches_integrated_umap.tiff".format(params.outputPrefix)),
    "umap_integrated_annot": os.path.join(plot_dir, "{}_scArches_integrated_annot_umap.tiff".format(params.outputPrefix)),
    # model
    "model": os.path.join(output_model_dir, "{}_totalvi_model".format(params.outputPrefix)),
    "model_adata": os.path.join(output_model_dir, "{}_model_adata.h5ad".format(params.outputPrefix)),
    # gene / protein tables keep their original file names
    "model_genes": os.path.join(output_model_dir, os.path.basename(params.modelGenes)),
    "model_proteins": os.path.join(output_model_dir, os.path.basename(params.modelProteins)),
}
print(output_fn)

# ----- import data -----
# Query AnnData: remember the incoming feature identifiers in a var column,
# then index the features by the values of the 'Symbol' column.
adata_import = sc.read_h5ad(params.h5)
_incoming_var_names = adata_import.var_names
adata_import.var['original_var_names'] = _incoming_var_names
adata_import.var_names = adata_import.var['Symbol']
print(adata_import)

# Reference gene table, then reference protein table (read in that order).
genes_import, proteins_import = (
    pd.read_csv(_table_path)
    for _table_path in (params.modelGenes, params.modelProteins)
)

# AnnData the reference model was trained on.
adata_ref = sc.read_h5ad(params.modelAdata)

# ----- clean data -----
# Genes: rows of the gene table flagged in 'highly_variable', as a Series of
# the 'gene' column.
_hvg_rows = genes_import[genes_import["highly_variable"]]
genes = _hvg_rows['gene']

# Model genes whose name does not occur in the query's 'Symbol' column.
_found_in_query = genes.isin(adata_import.var['Symbol'])
missing_genes = genes[~_found_in_query]

# Proteins: the 'protein' column of the protein table.
proteins = proteins_import["protein"]

# Work on a copy so the imported query object stays untouched.
adata = adata_import.copy()

# ----- prepare adata -----
### add missing reference genes
# Model genes absent from the query are appended as all-zero columns so the
# query ends up with the same feature set as the model.

# var table for the missing genes; only 'Symbol' carries information
add_genes_var = pd.DataFrame(missing_genes).rename(columns={"gene": "Symbol"})
add_genes_var["GENENAME"] = None
add_genes_var["SEQNAME"] = None

# zero matrix: one row per query cell, one column per missing gene
add_genes_X = np.zeros((adata.n_obs, len(missing_genes)))

# AnnData holding only the padding columns, aligned to the query cells
add_genes = anndata.AnnData(X=add_genes_X, var=add_genes_var)
add_genes.obs.index = adata.obs.index
add_genes.var.index = missing_genes
print(add_genes)

# column-wise concatenation (axis=1): query genes followed by padded genes
adata = anndata.concat([adata, add_genes], axis=1, join='outer')

# Re-attach the cell metadata from the imported object, but only if the cell
# order is unchanged.
if not adata_import.obs.index.equals(adata.obs.index):
    raise ValueError("The order of observations has changed.")
# .copy() so later edits to adata.obs cannot leak into adata_import.obs
adata.obs = adata_import.obs.copy()

### prepare adata
# keep only model genes, then put them in the model's gene order
adata = adata[:, adata.var["Symbol"].isin(genes)]  # filter genes
# Materialise the subset with .copy(): the slices above are AnnData views, and
# the obs assignment below would otherwise modify a view (implicit copy plus
# ImplicitModificationWarning).
adata = adata[:, genes].copy()  # reorder

# expose the query's sample column under the name used by the reference
adata.obs[params.modelSampleCol] = adata.obs[params.sampleCol]

print(adata)

# ----- add proteins -----
# Only the "impute proteins" path is implemented: without the flag the script
# stops here.
if params.imputeProtein is not True:
    raise Exception ("Support has not yet been developed for andata with protein data. Please use imputeProtein if appropriate or develop this script.")

# The query gets an all-zero matrix with one column per entry of `proteins`,
# stored under obsm["protein"].
data = np.zeros((adata.n_obs, len(proteins)))
adata.obsm["protein"] = data
print(adata)

# ----- align query and reference -----
# Tag every cell with its origin, stack reference and query into one object,
# then split them again by that tag so both halves are slices of the
# concatenated object.
for _dataset, _origin_label in ((adata, "query"), (adata_ref, "ref")):
    _dataset.obs["ref"] = _origin_label

adata_full = anndata.concat([adata_ref, adata])

_origin = adata_full.obs.ref
adata_ref = adata_full[_origin == "ref"].copy()
adata = adata_full[_origin == "query"].copy()

# ----- model -----

# Align the query's variables with the reference model in place.
# return_reference_var_names=True was dropped: with that flag the function
# only returns the reference var names and leaves `adata` untouched, and the
# returned value was discarded, so the call did nothing.
scv.model.TOTALVI.prepare_query_anndata(adata, params.model)
print(adata)

# setup model
# Load the stored reference model together with the query data.
vae_q = sca.models.TOTALVI.load_query_data(
    adata,
    params.model,
    inplace_subset_query_vars=True,
    freeze_expression=True,  # freezes neurons in the first layer for TL
)

# # run model
# 200 epochs, weight decay disabled
vae_q.train(200, plan_kwargs=dict(weight_decay=0.0))

# ----- query data results ------
### latent space
# latent representation of the query cells from the updated model
adata.obsm["X_totalVI"] = vae_q.get_latent_representation()

### umaps
# neighbour graph and UMAP computed on the latent space
sc.pp.neighbors(adata, use_rep="X_totalVI")
sc.tl.umap(adata, min_dist=0.4)

# plot annot
# show=False keeps the figure open so plt.savefig below writes the plot that
# was just drawn; without it scanpy may show and discard the figure first,
# leaving an empty figure to be saved.
sc.pl.umap(
    adata,
    color=[params.annotCol],
    frameon=False,
    ncols=1,
    title="Query Annotations",
    show=False,
)
plt.savefig(output_fn["umap_annot"], format="tiff", dpi=300, bbox_inches="tight", pil_kwargs={"compression": "tiff_lzw"})
plt.close()  # release the figure once it is on disk

# plot samples
sc.pl.umap(
    adata,
    color=[params.modelSampleCol],
    frameon=False,
    ncols=1,
    title="Query Samples",
    show=False,
)
plt.savefig(output_fn["umap_samples"], format="tiff", dpi=300, bbox_inches="tight", pil_kwargs={"compression": "tiff_lzw"})
plt.close()

# ----- integrated results -----
### latent space
# latent representation of reference + query cells from the updated model
adata_full_new = adata_full.copy()
adata_full_new.obsm["X_totalVI"] = vae_q.get_latent_representation(adata_full_new)

### umaps
sc.pp.neighbors(adata_full_new, use_rep="X_totalVI")
sc.tl.umap(adata_full_new, min_dist=0.3)

# plot query/ref
# show=False keeps the figure open so plt.savefig below writes the plot that
# was just drawn; without it scanpy may show and discard the figure first,
# leaving an empty figure to be saved.
sc.pl.umap(
    adata_full_new,
    color=["ref"],
    frameon=False,
    ncols=1,
    title="Reference and query",
    show=False,
)
plt.savefig(output_fn["umap_integrated"], format="tiff", dpi=300, bbox_inches="tight", pil_kwargs={"compression": "tiff_lzw"})
plt.close()  # release the figure once it is on disk

# plot annot
sc.pl.umap(
    adata_full_new,
    color=[params.annotCol],
    frameon=False,
    ncols=1,
    title="Reference and query",
    show=False,
)
plt.savefig(output_fn["umap_integrated_annot"], format="tiff", dpi=300, bbox_inches="tight", pil_kwargs={"compression": "tiff_lzw"})
plt.close()

### get imputed and/or normalized proteins
# Samples (values of the model sample column) that have at least one cell whose
# obsm["protein"] row does not sum to zero; they are used as transform_batch.
_protein_row_sums = adata_full_new.obsm["protein"].sum(axis = 1)
ref_protein_samples = (
    adata_full_new[_protein_row_sums != 0]
    .obs[params.modelSampleCol]
    .unique()
)

# The call returns a pair; only the second element is kept here
# (presumably RNA first, protein second - confirm against the scvi-tools docs).
_expression_kwargs = dict(
    n_samples=25,
    return_mean=True,
    transform_batch=ref_protein_samples,
)
_, imp_norm_proteins = vae_q.get_normalized_expression(
    adata_full_new,
    **_expression_kwargs
)

# label the columns with the model's protein names and store next to the cells
imp_norm_proteins.columns = proteins
adata_full_new.obsm["impute_or_norm_protein"] = imp_norm_proteins

# ----- save results -----
# Store 'discard' as strings before writing. The conversion is element-wise on
# the combined object's own column: str(<Series>) would produce the printed
# representation of the whole query Series and write that same text into
# every row.
if 'discard' in adata_full_new.obs.columns:
    adata_full_new.obs["discard"] = adata_full_new.obs["discard"].astype(str)

adata_full_new.write(output_fn["h5"])

# ----- save model -----
# updated model
vae_q.save(output_fn["model"], overwrite = True)

# updated model adata
adata_full_new.write(output_fn["model_adata"])

# symlink genes and proteins next to the updated model
for _source, _link in (
    (params.modelGenes, output_fn["model_genes"]),
    (params.modelProteins, output_fn["model_proteins"]),
):
    # Nothing to do when the link location already resolves to the source
    # (same file, or a link left by a previous run); this also guarantees the
    # source itself is never removed below.
    if os.path.realpath(_link) == os.path.realpath(_source):
        continue
    # Re-runs must not fail with FileExistsError (the model save above already
    # overwrites); lexists also catches a dangling link.
    if os.path.lexists(_link):
        os.remove(_link)
    # Absolute target: a relative target would be resolved relative to the
    # link's directory, not the working directory, and could dangle.
    os.symlink(os.path.abspath(_source), _link)