Open SarahAsbury opened 3 weeks ago
For reference, the script is below:
# ----- libraries -----
import os
import scanpy as sc
import anndata
import torch
import scarches as sca
import matplotlib.pyplot as plt
import numpy as np
import scvi as scv
import pandas as pd
import argparse
from pathlib import Path
# ----- session params -----
# NOTE: scanpy's set_figure_params resets every option it is NOT given back to
# its scanpy default on each call, so the original chain of three calls
# (dpi+frameon, then dpi, then figsize-only) silently reverted dpi to 80 and
# frameon to True on the last call. Apply all figure options in ONE call.
sc.settings.set_figure_params(dpi=200, frameon=False, figsize=(4, 4))
# compact tensor printing for debugging
torch.set_printoptions(precision=3, sci_mode=False, edgeitems=7)
# ----- params -----
# Command-line interface. The caller supplies either --modelDir (from which the
# four model components are auto-discovered below) or each component path
# individually.
parser = argparse.ArgumentParser(description='Run scArches. Provide either refernece modelDir or path to each model component. i.e. model, modelGenes, modelProteins, and modelAdata.')
parser.add_argument("--h5", required=True,
                    help="Full path to h5 file that will be integrated to model.")
parser.add_argument("--modelDir",
                    help="Full path to directory contaning model, modelGenes, modelProtein, and modelAdata.")
parser.add_argument("--model",
                    help="Full path to the stored model for integration.")
parser.add_argument("--modelGenes",
                    help="Full path to csv containing genes used in the model.")
parser.add_argument("--modelProteins", default=None,
                    help="Full path to csv containing proteins used in the model.")
parser.add_argument("--modelAdata",
                    help="Full path to anndata used to train the model")
parser.add_argument("--outputPrefix", required=True,
                    help="Prefix to be appended to output files.")
parser.add_argument("--output", required=True,
                    help="Full path to output directory")
parser.add_argument("--imputeProtein", action="store_true",
                    help="If flag provided, proteins will be imputed. Only viable for TotalVI models.")
parser.add_argument("--sampleCol", required=True,
                    help="Name of sample column in h5/seurat data.")
parser.add_argument("--annotCol", required=True,
                    help="Name of annotation column in h5/seurat data.")
parser.add_argument("--modelSampleCol", required=True,
                    help="Name of sample column in original model adata.")
# parse the actual command line
params = parser.parse_args()
# reference model paths
# When --modelDir is given, discover the four model components by glob pattern.
# The original used bare next() on each glob, which raises an opaque
# StopIteration at module level when a component file is missing; fail with an
# explicit, actionable error instead.
if params.modelDir is not None:
    modelDir = Path(params.modelDir)

    def _find_component(pattern):
        """Return the first file in modelDir matching `pattern` as a posix path."""
        try:
            return next(modelDir.glob(pattern)).as_posix()
        except StopIteration:
            raise FileNotFoundError(
                "No file matching '{}' found in model directory: {}".format(pattern, modelDir)
            ) from None

    params.model = _find_component("*model")
    params.modelAdata = _find_component("*adata.h5ad")
    params.modelGenes = _find_component("*genes.csv")
    params.modelProteins = _find_component("*proteins.csv")
print(params)
# ----- internal params -----
# scvi settings
# totalVI default normalization settings, matching how the reference models
# were trained; note the scArches tutorial uses the reversed combination.
# https://docs.scvi-tools.org/en/stable/api/reference/scvi.module.TOTALVAE.html
arches_params = {
    "use_layer_norm": "none",
    "use_batch_norm": "both",
}
# fixed seed for reproducibility
scv.settings.seed = 21022024
# output fn
# Build the output directory layout and all output file paths up front so a bad
# --output or missing component path fails before any heavy computation.
plot_dir = os.path.join(params.output, "plots")
# exist_ok avoids the check-then-create race of the original exists()+makedirs
os.makedirs(plot_dir, exist_ok=True)
output_model_dir = os.path.join(params.output, "{}_totalvi_model".format(params.outputPrefix))
os.makedirs(output_model_dir, exist_ok=True)
# --modelGenes/--modelProteins are optional at the argparse level (they may
# come from --modelDir instead); basename(None) would raise a bare TypeError
# here, so fail with a clear message.
if params.modelGenes is None or params.modelProteins is None:
    raise ValueError("modelGenes and modelProteins must be provided, either directly or via --modelDir.")
output_fn = {
    # results
    "h5": os.path.join(params.output, "{}_scArches.h5ad".format(params.outputPrefix)),
    # plots
    "umap_annot": os.path.join(plot_dir, "{}_scArches_annot_umap.tiff".format(params.outputPrefix)),
    "umap_samples": os.path.join(plot_dir, "{}_scArches_samples_umap.tiff".format(params.outputPrefix)),
    "umap_integrated": os.path.join(plot_dir, "{}_scArches_integrated_umap.tiff".format(params.outputPrefix)),
    "umap_integrated_annot": os.path.join(plot_dir, "{}_scArches_integrated_annot_umap.tiff".format(params.outputPrefix)),
    # model
    "model": os.path.join(output_model_dir, "{}_totalvi_model".format(params.outputPrefix)),
    "model_adata": os.path.join(output_model_dir, "{}_model_adata.h5ad".format(params.outputPrefix)),
    "model_genes": os.path.join(output_model_dir, os.path.basename(params.modelGenes)),
    "model_proteins": os.path.join(output_model_dir, os.path.basename(params.modelProteins))
}
print(output_fn)
# ----- import data -----
# adata: the query dataset to be integrated into the reference model.
adata_import = sc.read_h5ad(params.h5)
# Preserve the original var index, then switch var_names to gene symbols so
# the query can be matched against the model's gene list (which uses symbols).
adata_import.var['original_var_names'] = adata_import.var_names
adata_import.var_names = adata_import.var['Symbol']
print(adata_import)
# genes: csv listing the genes the model was trained on
genes_import = pd.read_csv(params.modelGenes)
# proteins: csv listing the proteins the model was trained on
proteins_import = pd.read_csv(params.modelProteins)
# reference: the AnnData the model was trained on
adata_ref = sc.read_h5ad(params.modelAdata)
# ----- clean data -----
# model genes: keep only the highly-variable set the model actually uses
genes = genes_import.loc[genes_import["highly_variable"], "gene"]
# genes the model expects but the query lacks (zero-filled later)
missing_genes = genes[~genes.isin(adata_import.var['Symbol'])]
# model proteins
proteins = proteins_import["protein"]
# work on a copy so adata_import stays untouched for the metadata check below
adata = adata_import.copy()
# ----- prepare adata -----
### add missing reference genes
# missing genes variables
add_genes_var = pd.DataFrame(missing_genes).rename(columns = {"gene":"Symbol"})
add_genes_var["GENENAME"] = None
add_genes_var["SEQNAME"] = None
# missing genes X
add_genes_X = np.zeros((adata.n_obs, len(missing_genes)))
# create missing genes anndata
add_genes = anndata.AnnData(X = add_genes_X, var = add_genes_var)
add_genes.obs.index = adata.obs.index
add_genes.var.index = missing_genes
print(add_genes)
# add missing genes to original anndata
adata = anndata.concat([adata, add_genes], axis=1, join='outer')
# add metadata
if (adata_import.obs.index.equals(adata.obs.index)):
adata.obs = adata_import.obs
else:
raise ValueError("The order of observations has changed.")
### prepare adata
# filter adata to model genes
adata = adata[:, adata.var["Symbol"].isin(genes)] # filter genes
adata = adata[:, genes] # reorder
# add sample column
adata.obs[params.modelSampleCol] = adata.obs[params.sampleCol]
print(adata)
# ----- add proteins -----
# The flag is a bool from argparse's store_true; test truthiness directly
# instead of the unidiomatic `is True`.
if params.imputeProtein:
    # Attach an all-zero protein matrix so the RNA-only query matches the
    # protein layout the TotalVI model expects; real values are imputed by the
    # model downstream.
    adata.obsm["protein"] = np.zeros((adata.n_obs, len(proteins)))
    print(adata)
else:
    # Query data that already carries protein counts is not supported yet.
    raise Exception("Support has not yet been developed for andata with protein data. Please use imputeProtein if appropriate or develop this script.")
# ----- align query and reference -----
# Tag provenance, concatenate, then split back out so query and reference end
# up with an identical, aligned var space.
adata.obs["ref"] = "query"
adata_ref.obs["ref"] = "ref"
adata_full = anndata.concat([adata_ref, adata])
is_ref = adata_full.obs.ref == "ref"
adata_ref = adata_full[is_ref].copy()
adata = adata_full[~is_ref].copy()
# ----- model -----
# Align the query AnnData to the reference model's var layout (modifies adata
# in place; return_reference_var_names=True also yields the model's var names).
scv.model.TOTALVI.prepare_query_anndata(adata,
                                        params.model,
                                        return_reference_var_names=True)
print(adata)
# setup model: load the trained TotalVI reference model with the query
# attached (scArches "surgery" step)
vae_q = sca.models.TOTALVI.load_query_data(
    adata,
    params.model,
    inplace_subset_query_vars = True,  # subset query vars to the model's vars in place
    freeze_expression=True # freezes neurons in the first layer for TL
)
# run model: fine-tune for 200 epochs on the query; weight_decay=0.0
# presumably to keep reference weights from shrinking during the update —
# matches the scArches TotalVI tutorial, TODO confirm
vae_q.train(200, plan_kwargs=dict(weight_decay=0.0))
# ----- query data results ------
### latent space: embed the query through the fine-tuned model
adata.obsm["X_totalVI"] = vae_q.get_latent_representation()

def _plot_umap(ad, color, title, out_fn):
    """Draw a single-column scanpy UMAP of `ad` colored by `color` and save it
    to `out_fn` as an LZW-compressed TIFF (same call pattern for every plot)."""
    sc.pl.umap(
        ad,
        color=[color],
        frameon=False,
        ncols=1,
        title=title
    )
    plt.savefig(out_fn, format="tiff", dpi=300, bbox_inches="tight",
                pil_kwargs={"compression": "tiff_lzw"})

### umaps: neighbors graph on the totalVI latent space, then UMAP
sc.pp.neighbors(adata, use_rep="X_totalVI")
sc.tl.umap(adata, min_dist=0.4)
# plot annot
_plot_umap(adata, params.annotCol, "Query Annotations", output_fn["umap_annot"])
# plot samples
_plot_umap(adata, params.modelSampleCol, "Query Samples", output_fn["umap_samples"])
# ----- integrated results -----
### latent space: embed reference + query together through the updated model
adata_full_new = adata_full.copy()
adata_full_new.obsm["X_totalVI"] = vae_q.get_latent_representation(adata_full_new)
### umaps
sc.pp.neighbors(adata_full_new, use_rep="X_totalVI")
sc.tl.umap(adata_full_new, min_dist=0.3)
# plot query/ref: color by provenance tag set before concatenation
sc.pl.umap(
    adata_full_new,
    color=["ref"],
    frameon=False,
    ncols=1,
    title="Reference and query"
)
plt.savefig(output_fn["umap_integrated"], format = "tiff", dpi=300, bbox_inches="tight", pil_kwargs={"compression": "tiff_lzw"})
# plot annot: color by the cell-type annotation column
sc.pl.umap(
    adata_full_new,
    color=[params.annotCol],
    frameon=False,
    ncols=1,
    title="Reference and query"
)
plt.savefig(output_fn["umap_integrated_annot"], format = "tiff", dpi=300, bbox_inches="tight", pil_kwargs={"compression": "tiff_lzw"})
### get imputed and/or normalized proteins
# find non-zero reference samples
# NOTE(review): assumes obsm["protein"] survived the anndata.concat at the
# query/reference merge and that query rows are all-zero there — confirm;
# otherwise this sample list could include query samples.
ref_protein_samples = adata_full_new[adata_full_new.obsm["protein"].sum(axis = 1) != 0].obs[params.modelSampleCol].unique()
# Mean denoised protein expression over 25 posterior samples, decoding each
# cell as if it came from the protein-bearing reference batches.
_, imp_norm_proteins = vae_q.get_normalized_expression(
    adata_full_new,
    n_samples=25,
    return_mean=True,
    transform_batch=ref_protein_samples
)
imp_norm_proteins.columns = proteins
adata_full_new.obsm["impute_or_norm_protein"] = imp_norm_proteins
# ----- save results -----
# BUG FIX: the original assigned str(adata.obs["discard"]) — the repr of the
# entire *query* Series — into every row of adata_full_new (and read from
# `adata` despite checking `adata_full_new`). Cast the checked object's own
# column to string element-wise instead, which is what h5ad serialization of a
# mixed-type object column needs.
if 'discard' in adata_full_new.obs.columns:
    adata_full_new.obs["discard"] = adata_full_new.obs["discard"].astype(str)
adata_full_new.write(output_fn["h5"])
# ----- save model -----
# updated model (overwrite any previous run's model directory)
vae_q.save(output_fn["model"], overwrite=True)
# updated model adata
adata_full_new.write(output_fn["model_adata"])

def _symlink_force(src, dst):
    """Symlink src -> dst, replacing an existing link so reruns do not crash
    (bare os.symlink raises FileExistsError on a second run)."""
    try:
        os.symlink(src, dst)
    except FileExistsError:
        os.remove(dst)
        os.symlink(src, dst)

# symlink the model gene/protein csvs next to the saved model
_symlink_force(params.modelGenes, output_fn["model_genes"])
_symlink_force(params.modelProteins, output_fn["model_proteins"])
I have been unable to achieve successful integration with my reference data across 3 separate query datasets. Both the query and reference contain only T cells. My reference dataset was derived with TotalVI (i.e. it contained RNA + protein), whereas each query is scRNA-seq data only.
Are there any recommendations on what might be the issue?