scverse / scvi-tools

Deep probabilistic analysis of single-cell and spatial omics data
http://scvi-tools.org/
BSD 3-Clause "New" or "Revised" License

scvi.model.TOTALVI.train returning only NaN values #873

Closed KyleFerchen closed 3 years ago

KyleFerchen commented 3 years ago

We are having problems when we combine 6 of our CITE-seq datasets. The workflow runs fine when we combine just 2 of the datasets, but with all 6 I can never get the training step to work.

At first, scvi.model.TOTALVI.train() crashed every time I started it with all 6 datasets. I thought this was because the adata object was too big and the function was allocating more memory than our HPC job can provide (>250 GB). So I tried different training parameters, including fewer epochs and a training-set size of 0.8 instead of 0.9. After those adjustments, .train() runs (taking about 3 days), but the resulting object contains only NaN values.

Have you ever experienced this issue with scaling the size of the input?

Do you think this could be an error with how one of the 6 CITE-seq datasets is structured?

What do you think would be the best approach for me to debug this?

"""
This is adapted from : https://www.scvi-tools.org/en/stable/user_guide/notebooks/totalVI.html
"""
import os
import sys
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import scvi
import scanpy as sc

import shelve

os.chdir('/data/salomonis2/Grimes/RNA/scRNA-Seq/10x-Genomics/X202SC20092846-Z01-F001/TotalVI')

grid_search_prefix = "_200_epochs_80_percent_training_"

# Find the shared cells between RNA adata object and protein expression pandas dataframe
def filter_to_shared_cells(adata_ob, protein_exp_ob):
  return(pd.Series(adata_ob.obs.index)[pd.Series(adata_ob.obs.index).isin(pd.Series(protein_exp_ob.index))])

# Define a function to load the adata object, attaching the ADT values to the RNA
def load_adata_and_adt_files(path_adata, path_adt):
  adata_temp = sc.read(path_adata,ext='txt').transpose() 
  protein_expression_temp = pd.read_table(path_adt).set_index('UID')
  protein_expression_temp.index.name = None
  shared_cells = filter_to_shared_cells(adata_temp, protein_expression_temp)
  adata_temp = adata_temp[shared_cells,:]
  adata_temp.obsm['protein_expression'] = protein_expression_temp.loc[shared_cells]
  return(adata_temp)

# Find the shared columns, and reassign the dataframe for the ADTs to only those columns, then combine the adata objects
def combine_adata_objects(list_of_adata_objs):
  adts = pd.Series(list_of_adata_objs[0].obsm['protein_expression'].columns)
  for x in list_of_adata_objs:
    adts = adts[adts.isin(pd.Series(x.obsm['protein_expression'].columns))]
  for x in list_of_adata_objs:
    x.obsm['protein_expression'] = x.obsm['protein_expression'][adts]
  adata = list_of_adata_objs[0].concatenate(list_of_adata_objs[1:])
  return(adata)

adata_cd127 = load_adata_and_adt_files('CD127rna-filtered.txt', 'CD127_CB-umi_counts-transposed.txt')
adata_hsc = load_adata_and_adt_files('HSCrna-filtered.txt','HSC_CB-umi_counts-transposed-500.txt')
adata_KitplsI = load_adata_and_adt_files('kitplsIrna-filtered.txt','KitI_CB-umi_counts-transposed-500.txt')
adata_KitplsII = load_adata_and_adt_files('kitplsIIrna-filtered.txt','KitII_CB-umi_counts-transposed-500.txt')
adata_MultiLin = load_adata_and_adt_files('MultliLin_rna-filtered.txt','MultiLin_CB-umi_counts-transposed-500.txt')
adata_Thymus = load_adata_and_adt_files('Thymus_rna-filtered.txt','Thymus_CB-umi_counts-transposed-500.txt')

adata = combine_adata_objects([adata_cd127, adata_hsc , adata_KitplsI , adata_KitplsII , adata_MultiLin ,  adata_Thymus])

### add celltype and batch information 
####  add celltype
annofile = pd.read_csv('cellannotation_compatiblewith-totalVI.txt', delimiter="\t", header=0, index_col=0)
#### some barcodes in the adata object might not be present in the cellanno files barcodes, for that case the respective celltype and batch will be considered as undefined

adata.obs['celltype'] = 'undefined'
def clean_cell_names(x):
  return(x.split("-OutliersRemoved-")[0])
adata.obs.index = pd.Series(adata.obs.index).apply(clean_cell_names)
annofile.index = pd.Series(annofile.index).apply(clean_cell_names)
shared_cells = pd.Series(adata.obs.index)[pd.Series(adata.obs.index).isin(pd.Series(annofile.index))]
adata.obs.loc[shared_cells,'celltype'] = annofile.loc[shared_cells,'celltype']

### add batch information
adata.obs['Batch'] = 'undefined'
adata.obs.loc[shared_cells,'Batch'] = annofile.loc[shared_cells,'Batch']

adata.obsm["protein_expression"].index = adata.obs.index

adata.layers["counts"] = adata.X.copy()
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)
adata.raw = adata

sc.pp.highly_variable_genes(
    adata, 
    n_top_genes=4000, 
    flavor="seurat_v3",
    batch_key="Batch", 
    subset=True,
    layer="counts"
)

scvi.data.setup_anndata(
    adata, 
    layer="counts", 
    batch_key="Batch", 
    protein_expression_obsm_key="protein_expression"
)

os.chdir("/data/salomonis2/LabFiles/Kyle/Analysis/2020_12_06_cite_seq_total_seq_data_processing/output/all_combined/")
#### Prepare and run model
vae = scvi.model.TOTALVI(adata, use_cuda=True, latent_distribution="normal")
vae.train(n_epochs=200, train_size=0.8)

plt.plot(vae.trainer.history["elbo_test_set"], label="test")
plt.title("Negative ELBO over training epochs")
plt.ylim(1200, 1400)
plt.legend()
plt.savefig("elbo_test_plot" + grid_search_prefix + ".pdf")

#### Analyze outputs
adata.obsm["X_totalVI"] = vae.get_latent_representation()
rna, protein = vae.get_normalized_expression(
    n_samples=25, 
    return_mean=True, 
    transform_batch=["KitPlusI", "KitPlsII", "HSC", "CD127", "MultiL", "Thymus", 'undefined']
)

adata.layers["denoised_rna"], adata.obsm["denoised_protein"] = rna, protein

adata.obsm["protein_foreground_prob"] = vae.get_protein_foreground_probability(
    n_samples=25, 
    return_mean=True, 
    transform_batch=["KitPlusI", "KitPlsII", "HSC", "CD127", "MultiL", "Thymus", 'undefined']
)
parsed_protein_names = [p.split("_")[0] for p in adata.obsm["protein_expression"].columns]
adata.obsm["protein_foreground_prob"].columns = parsed_protein_names

# Compute clusters and visualize the latent space
sc.pp.neighbors(adata, use_rep="X_totalVI")
sc.tl.umap(adata, min_dist=0.4)
sc.tl.leiden(adata, key_added="leiden_totalVI")
sc.pl.umap(
    adata, 
    color=["leiden_totalVI", "Batch"], 
    frameon=False,
    ncols=1,
    save= grid_search_prefix + '_plot_without_totalVI.pdf'
)
/usr/local/python/3.7.1/lib/python3.7/site-packages/torch/cuda/__init__.py:52: UserWarning: CUDA initialization: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx (Triggered internally at  /pytorch/c10/cuda/CUDAFunctions.cpp:100.)
  return torch._C._cuda_getDeviceCount() > 0
/users/fero3l/.local/lib/python3.7/site-packages/scvi/core/distributions/_negative_binomial.py:519: UserWarning: The value argument must be within the support of the distribution
  UserWarning,
/users/fero3l/.local/lib/python3.7/site-packages/umap/umap_.py:401: UserWarning: Failed to correctly find n_neighbors for some samples.Results may be less than ideal. Try re-running withdifferent parameters.
  "Failed to correctly find n_neighbors for some samples."
Traceback (most recent call last):
  File "totalVI_200_epochs_80_percent_training.py", line 124, in <module>
    sc.tl.umap(adata, min_dist=0.4)
  File "/users/fero3l/.local/lib/python3.7/site-packages/scanpy/tools/_umap.py", line 173, in umap
    verbose=settings.verbosity > 3,
  File "/users/fero3l/.local/lib/python3.7/site-packages/umap/umap_.py", line 1038, in simplicial_set_embedding
    metric_kwds=metric_kwds,
  File "/users/fero3l/.local/lib/python3.7/site-packages/umap/spectral.py", line 281, in spectral_layout
    metric_kwds=metric_kwds,
  File "/users/fero3l/.local/lib/python3.7/site-packages/umap/spectral.py", line 168, in multi_component_layout
    metric_kwds=metric_kwds,
  File "/users/fero3l/.local/lib/python3.7/site-packages/umap/spectral.py", line 97, in component_layout
    component_centroids, metric=metric, **metric_kwds
  File "/users/fero3l/.local/lib/python3.7/site-packages/sklearn/metrics/pairwise.py", line 1752, in pairwise_distances
    return _parallel_pairwise(X, Y, func, n_jobs, **kwds)
  File "/users/fero3l/.local/lib/python3.7/site-packages/sklearn/metrics/pairwise.py", line 1348, in _parallel_pairwise
    return func(X, Y, **kwds)
  File "/users/fero3l/.local/lib/python3.7/site-packages/sklearn/metrics/pairwise.py", line 262, in euclidean_distances
    X, Y = check_pairwise_arrays(X, Y)
  File "/users/fero3l/.local/lib/python3.7/site-packages/sklearn/metrics/pairwise.py", line 137, in check_pairwise_arrays
    estimator=estimator)
  File "/users/fero3l/.local/lib/python3.7/site-packages/sklearn/utils/validation.py", line 578, in check_array
    allow_nan=force_all_finite == 'allow-nan')
  File "/users/fero3l/.local/lib/python3.7/site-packages/sklearn/utils/validation.py", line 60, in _assert_all_finite
    msg_dtype if msg_dtype is not None else X.dtype)
ValueError: Input contains NaN, infinity or a value too large for dtype('float64').

Exited with exit code 1.

Resource usage summary:

CPU time   :1297609.50 sec.
Max Memory :     49726 MB
Max Swap   :     71028 MB

Max Processes  :         5
Max Threads    :        38

The output (if any) follows:

INFO  Using batches from adata.obs["Batch"]
INFO  No label_key inputted, assuming all cells have same label
INFO  Using data from adata.layers["counts"]
INFO  Computing library size prior per batch
INFO  Using protein expression from adata.obsm['protein_expression']
INFO  Using protein names from columns of adata.obsm['protein_expression']
INFO  Found batches with missing protein expression
INFO  Successfully registered anndata object containing 63041 cells, 4000
vars, 7 batches, 1 labels, and 195 proteins. Also registered 0 extra
categorical covariates and 0 extra continuous covariates.
INFO  Please do not further modify adata until model is trained.
INFO  Training for 200 epochs.
INFO  KL warmup for 47280.75 iterations

Training...: 100%|██████████| 200/200 [51:09:39<00:00, 920.90s/it]
INFO  Training is still in warming up phase. If your applications rely on the posterior quality, consider training for more epochs or reducing the kl warmup.
INFO  Training time: 161221 s. / 200 epochs
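The final INFO messages can be sanity-checked with a little arithmetic. The logged warm-up length is exactly 0.75 × 63041 = 47280.75, so the sketch below assumes the KL warm-up lasts 0.75 × n_cells iterations. It also assumes a minibatch size of 256 and one pass over the training split per epoch; both are assumptions, not stated in this thread. Under those assumptions, 200 epochs on 80% of the cells give fewer iterations than the warm-up, which would be consistent with the "still in warming up phase" message. The helper name kl_warmup_check is hypothetical.

```python
import math

def kl_warmup_check(n_cells, train_size, n_epochs, batch_size=256, warmup_fraction=0.75):
    """Compare total training iterations with the KL warm-up length (hypothetical helper).

    Assumptions: warm-up lasts warmup_fraction * n_cells iterations, and each
    epoch runs ceil(n_train / batch_size) minibatches over the training split.
    """
    n_train = round(train_size * n_cells)              # cells in the training split
    iters_per_epoch = math.ceil(n_train / batch_size)  # minibatches per epoch
    total_iters = iters_per_epoch * n_epochs
    warmup_iters = warmup_fraction * n_cells
    return iters_per_epoch, total_iters, warmup_iters

# Values from the log above: 63041 cells, train_size=0.8, 200 epochs.
per_epoch, total, warmup = kl_warmup_check(63041, 0.8, 200)
print(per_epoch, total, warmup)  # 198 39600 47280.75
print("warm-up finished" if total >= warmup else "still warming up")
```

With a smaller minibatch size the iteration count per epoch would roughly double, so the conclusion depends on the batch-size assumption.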

adamgayoso commented 3 years ago

Hi Kyle,

I've seen this sort of issue before. I think it has something to do with the latent representations being on the simplex, with values near 0. You might try adding latent_distribution="normal" to scvi.model.TOTALVI, or using, for example, metric="correlation" in the call to scanpy's neighbors method.

adamgayoso commented 3 years ago

I also wanted to add that running totalVI on this many cells would take less than 1 hour on a GPU, and this dataset could be run on Google Colab for free.

Could you confirm that adata.obsm["X_totalVI"] actually contains NaN values?
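One quick way to answer this is to count the cells whose latent vector contains NaN or infinite values. The sketch below uses only NumPy; summarize_nonfinite is a hypothetical helper, and the small array is made-up example input.

```python
import numpy as np

def summarize_nonfinite(latent):
    """Return (number of rows containing NaN/inf, total number of rows)."""
    latent = np.asarray(latent)
    bad_rows = ~np.isfinite(latent).all(axis=1)  # True for rows with any NaN/inf
    return int(bad_rows.sum()), latent.shape[0]

# On the real object this would be, e.g.:
#   summarize_nonfinite(adata.obsm["X_totalVI"])
example = np.array([[0.1, 0.2], [np.nan, 0.3], [0.5, np.inf]])
print(summarize_nonfinite(example))  # (2, 3)
```

If every row is flagged, the problem is upstream of scanpy's neighbors/UMAP step rather than in the choice of distance metric.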

Also the fact that you get this warning:

/users/fero3l/.local/lib/python3.7/site-packages/scvi/core/distributions/_negative_binomial.py:519: UserWarning: The value argument must be within the support of the distribution
  UserWarning,

implies that your data contains values that are not counts.
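A count likelihood only supports finite, non-negative integers, so a simple check is to count the entries that violate each condition. This is a sketch that assumes a dense array; for a scipy.sparse matrix, passing its .data attribute (the stored nonzero entries) should be enough, since implicit zeros are valid counts. The helper name count_data_problems and the toy input are hypothetical.

```python
import numpy as np

def count_data_problems(values):
    """Count entries that cannot come from a count distribution (dense input)."""
    v = np.asarray(values, dtype=float).ravel()
    finite = np.isfinite(v)
    return {
        "nan_or_inf": int((~finite).sum()),
        "negative": int((v[finite] < 0).sum()),
        "non_integer": int((v[finite] != np.round(v[finite])).sum()),
    }

# e.g. count_data_problems(adata.layers["counts"]) if the layer is dense, or
#      count_data_problems(adata.obsm["protein_expression"].to_numpy()) for the ADTs
print(count_data_problems([[0, 3, 1.5], [np.nan, -2, 7]]))
```

Running it separately on the RNA layer and on the protein table shows which of the two inputs triggers the warning.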

KyleFerchen commented 3 years ago

Thank you, I will first try to get my GPU set up to run it, then try to reproduce the error, confirm the NaN values, and try the suggestions.

KyleFerchen commented 3 years ago

I set up PyTorch to work with my GTX 970 GPU. It trained much faster using the 200 epochs and 0.8 training size I used before:

vae.train(n_epochs=200, train_size=0.8)

INFO     Training for 200 epochs.                                                             
/home/kyle/anaconda3/envs/scvi-env/lib/python3.7/site-packages/scvi/core/distributions/_negative_binomial.py:532: UserWarning: The value argument must be within the support of the distribution
  UserWarning,
INFO     KL warmup for 47280.75 iterations                                                    
Training...: 100%|██████████████████████████████████████████| 200/200 [26:33<00:00,  7.97s/it]
INFO     Training is still in warming up phase. If your applications rely on the posterior quality, consider training for more epochs or reducing the kl warmup.                        
INFO     Training time:  1475 s. / 200 epochs 

I then confirmed that .get_latent_representation() returns only NaN values:

vae.get_latent_representation()

array([[nan, nan, nan, ..., nan, nan, nan],
       [nan, nan, nan, ..., nan, nan, nan],
       [nan, nan, nan, ..., nan, nan, nan],
       ...,
       [nan, nan, nan, ..., nan, nan, nan],
       [nan, nan, nan, ..., nan, nan, nan],
       [nan, nan, nan, ..., nan, nan, nan]], dtype=float32)

I was going to try setting latent_distribution to "normal", but it seems the model already uses it, as you can see when I print the vae object:

vae

TotalVI Model with the following params: 
n_latent: 20, gene_dispersion: gene, protein_dispersion: protein, gene_likelihood: nb, latent_distribution: normal
Training status: Trained

To print summary of associated AnnData, use: scvi.data.view_anndata_setup(model.adata)

I guess I'll step through the code to try to find the problem.

adamgayoso commented 3 years ago

A couple of things to look into:

  1. Found batches with missing protein expression: is it actually the case that some of your "batches" have non-identical protein features?
  2. /home/kyle/anaconda3/envs/scvi-env/lib/python3.7/site-packages/scvi/core/distributions/_negative_binomial.py:532: UserWarning: The value argument must be within the support of the distribution UserWarning: you shouldn't get this warning unless some value in your adata is continuous-valued (non-integer) or, for example, np.nan. Maybe this has to do with point (1)?
  3. Maybe first try the case where you ensure each batch has exactly the same proteins?
  4. Is vae.trainer.history["elbo_test_set"] NaN?
  5. You can test all of these things with, e.g., 20 epochs to iterate more quickly and identify the issue.
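Points 1 to 3 can be checked directly on the protein table before training. The sketch below assumes (based on reading the scvi-tools source, not on documented behavior) that a protein is treated as missing in a batch when its counts sum to zero over all cells of that batch. It reports cells with NaN protein values separately, because pandas' groupby().sum() silently skips NaN. The helper name protein_batch_report and the toy data are hypothetical.

```python
import numpy as np
import pandas as pd

def protein_batch_report(protein, batch):
    """Return (cells with any NaN protein value, batches x proteins all-zero table).

    protein: cells x proteins DataFrame, e.g. adata.obsm["protein_expression"]
    batch:   per-cell batch labels aligned with protein's rows, e.g. adata.obs["Batch"]
    """
    nan_cells = protein.index[protein.isna().any(axis=1)]
    # Per-batch totals; NaNs are skipped here, which is why they are reported above.
    per_batch_totals = protein.groupby(np.asarray(batch)).sum()
    all_zero = per_batch_totals == 0  # True: protein looks unmeasured in that batch
    return list(nan_cells), all_zero

protein = pd.DataFrame(
    {"CD3": [5, 0, np.nan, 2], "CD19": [1, 4, 0, 0]},
    index=["c1", "c2", "c3", "c4"],
)
batch = ["A", "A", "B", "B"]
nan_cells, all_zero = protein_batch_report(protein, batch)
print(nan_cells)  # ['c3']
print(all_zero)
```

Any True entry in all_zero corresponds to a batch that lacks that protein, which is what point 3 asks you to rule out first.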

It's perplexing to me that you'd get np.nan after training, because NaNs should lead to a NaN loss, which would stop training. Using 80% or 90% of the cells for training shouldn't make a difference to whether the method works.
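For point 4, a short scan of the recorded loss history shows whether and where the ELBO first became non-finite. This is a sketch; the history list may be recorded only every few epochs, so the returned value is an index into the list rather than necessarily an epoch number. The helper name and the toy values are hypothetical.

```python
import math

def first_nonfinite_index(history):
    """Return the index of the first NaN/inf value in a loss history, or None."""
    for i, value in enumerate(history):
        if not math.isfinite(value):
            return i
    return None

# e.g. first_nonfinite_index(vae.trainer.history["elbo_test_set"])
print(first_nonfinite_index([1.0, 0.9, float("nan"), float("nan")]))  # 2
print(first_nonfinite_index([1.0, 0.9]))  # None
```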

Also, @kyleferchen, I'm happy to try it out myself if you send me the input AnnData file.

KyleFerchen commented 3 years ago

I found my error! A single cell in one of the datasets had NaN values for protein expression, which I guess broke everything else.

I just added a step to remove any cells that have missing values:

select_cells = pd.Series(adata.obsm._data['protein_expression'].index)[list(adata.obsm._data['protein_expression'].isna().sum(axis=1) == 0)]
adata_unfiltered = adata
adata = adata[select_cells,:]
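The same filter can be written as a boolean mask over the public obsm accessor. This is a sketch: complete_protein_mask is a hypothetical helper, and the AnnData lines are shown as comments (not run here). Calling .copy() after subsetting is a choice made so that downstream setup works on a standalone object rather than a view.

```python
import numpy as np
import pandas as pd

def complete_protein_mask(protein):
    """Boolean mask of cells (rows) with no missing protein values."""
    return ~protein.isna().any(axis=1).to_numpy()

# Usage with AnnData (not run here):
#   mask = complete_protein_mask(adata.obsm["protein_expression"])
#   adata_unfiltered = adata
#   adata = adata[mask].copy()
protein = pd.DataFrame({"CD3": [5.0, np.nan, 2.0], "CD19": [1.0, 4.0, 0.0]})
print(complete_protein_mask(protein))  # [ True False  True]
```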

Surprisingly, I still get the "Found batches with missing protein expression" message. I guess maybe some batches don't have any ADT counts for specific antibodies, but that didn't interfere with the training. It worked after removing NaN data values from the AnnData object's protein_expression table.

adamgayoso commented 3 years ago

I guess maybe some batches don't have any ADT counts for specific antibodies, but that didn't interfere with the training. It worked after removing NaN data values from the AnnData object's protein_expression table.

We designed totalVI to handle missing proteins, but you should make sure this is what you want, and take care when looking at the denoised protein expression values.
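For proteins that were never measured in a batch, the denoised values for that batch's cells come from the model rather than from observed counts. If you would rather not interpret them, one cautious option (an illustration, not an scvi-tools feature) is to mask those entries before downstream analysis. The sketch assumes you already have a batches x proteins boolean table of missing proteins, for example from a per-batch all-zero check; the helper name and toy data are hypothetical.

```python
import numpy as np
import pandas as pd

def mask_missing_proteins(denoised, batch, all_zero):
    """Set denoised values to NaN where the protein is missing in the cell's batch.

    denoised: cells x proteins DataFrame, e.g. adata.obsm["denoised_protein"]
    batch:    per-cell batch labels aligned with denoised's rows
    all_zero: batches x proteins boolean DataFrame, True where a protein is missing
    """
    # Expand the per-batch table to one row per cell; unknown batches count as measured.
    missing = all_zero.reindex(
        index=np.asarray(batch), columns=denoised.columns, fill_value=False
    ).to_numpy()
    return denoised.mask(missing)

denoised = pd.DataFrame(
    {"CD3": [1.0, 2.0, 3.0], "CD19": [4.0, 5.0, 6.0]}, index=["c1", "c2", "c3"]
)
all_zero = pd.DataFrame({"CD3": [False, False], "CD19": [False, True]}, index=["A", "B"])
out = mask_missing_proteins(denoised, ["A", "B", "B"], all_zero)
print(out)
```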