theislab / cellrank

CellRank: dynamics from multi-view single-cell data
https://cellrank.org
BSD 3-Clause "New" or "Revised" License

Proper use of the "Lineages" Parameter for Lineage Drivers #1182

Open jwalewski opened 7 months ago

jwalewski commented 7 months ago

Hello,

I am attempting to run compute_lineage_drivers(), but no matter what value I set for "lineages", I keep getting an IndexError:

  File "<string>", line 421, in combined_kernel_analysis
  File "/home/jw2894/.local/lib/python3.10/site-packages/cellrank/estimators/mixins/_lineage_drivers.py", line 137, in compute_lineage_drivers
    _ = fate_probs[lineages]
  File "/home/jw2894/.local/lib/python3.10/site-packages/cellrank/_utils/_lineage.py", line 305, in __getitem__
    obj = self.__getitem(item)
  File "/home/jw2894/.local/lib/python3.10/site-packages/cellrank/_utils/_lineage.py", line 424, in __getitem
    obj = super().__getitem__(item)
IndexError: only integers, slices (`:`), ellipsis (`...`), numpy.newaxis (`None`) and integer or boolean arrays are valid indices

The tutorial sets it to lineages=["Delta"], for example, but when I pass an array of strings containing the names of states in my cell population, I still get the same error. I've also tried the slice [:] on my entire array, and other ways of setting the indices.

What's wrong with the way I'm doing it? Any insight is appreciated.
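The last line of the traceback is NumPy's generic message for a key it cannot use as an index, which suggests the lineage names reached plain `ndarray` indexing without being translated into column positions. A minimal sketch with a bare NumPy array (not a CellRank `Lineage` object) reproduces that kind of error and makes the name-to-position step explicit; `resolve_lineages`, `probs`, and `names` are hypothetical stand-ins for the fate probabilities and their lineage names.

```python
import numpy as np


def resolve_lineages(names, requested):
    """Map lineage names to column positions (hypothetical helper, not CellRank API)."""
    names = [str(n) for n in names]
    # Accept a single name, a list, or a NumPy array of names
    if isinstance(requested, str):
        requested = [requested]
    else:
        requested = [str(r) for r in requested]
    missing = [r for r in requested if r not in names]
    if missing:
        raise KeyError(f"Unknown lineages {missing}; valid names are {names}")
    return [names.index(r) for r in requested]


# Toy fate probabilities: rows are cells, columns are lineages
probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.8, 0.1]])
names = np.array(["Delta", "Beta", "Alpha"])

try:
    probs[:, ["Delta"]]  # a plain ndarray cannot index columns by name
except IndexError as err:
    print("IndexError:", err)

cols = resolve_lineages(names, np.array(["Delta", "Alpha"]))
print(cols)  # [0, 2]
print(probs[:, cols])
```

If the names come from a NumPy array, trying a plain Python list (`list(names)`) is cheap; that CellRank 2.0.0 treats the two differently is an assumption, not something this thread confirms. The requested names also have to match the terminal-state names exactly as the estimator reports them, which can differ from the raw `Cell_type` categories.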

WeilerP commented 7 months ago

@jwalewski please provide a code snippet to understand the exact workflow and the versions of the Python packages.

jwalewski commented 7 months ago

The exact code snippet would be:

drivers_per_cluster = current_model.compute_lineage_drivers()

Package versions:

cellrank==2.0.0 scanpy==1.9.4 anndata==0.11.0.dev31+g49ca3bd numpy==1.24.4 numba==0.57.1 scipy==1.11.2 pandas==2.1.0 pygpcca==1.0.4 scikit-learn==1.3.0 statsmodels==0.14.0 python-igraph==0.10.8 scvelo==0.2.5 pygam==0.8.0 matplotlib==3.7.1 seaborn==0.12.2

WeilerP commented 7 months ago

Thanks, @jwalewski. We need a complete code snippet of the workflow, i.e., especially the CellRank part; ideally, you can also share how you processed the data. Could you please also update CellRank to the latest version? And is there a reason why you use a developer installation of AnnData?

jwalewski commented 7 months ago

Understood.

I'll address the version questions first, as the code block is long. I created a new conda environment with the most up-to-date version of CellRank (and ONLY that); however, importing it fails with a scanpy error:

(CellRank240328)[jw2894@r813u09n09.mccleary ~]$ python3.11 -c 'import cellrank as cr; cr.logging.print_versions()' 
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/home/jw2894/.conda/envs/CellRank240328/lib/python3.11/site-packages/cellrank/__init__.py", line 3, in <module>
    from cellrank import datasets, estimators, kernels, logging, models, pl
  File "/home/jw2894/.conda/envs/CellRank240328/lib/python3.11/site-packages/cellrank/pl/__init__.py", line 2, in <module>
    from cellrank.pl._circular_projection import circular_projection
  File "/home/jw2894/.conda/envs/CellRank240328/lib/python3.11/site-packages/cellrank/pl/_circular_projection.py", line 17, in <module>
    from scanpy._utils import deprecated_arg_names
ImportError: cannot import name 'deprecated_arg_names' from 'scanpy._utils' (/home/jw2894/.conda/envs/CellRank240328/lib/python3.11/site-packages/scanpy/_utils/__init__.py)

So naturally, I tried upgrading scanpy, and conda then wanted to downgrade CellRank back to 2.0.0. Any idea what's going on here?
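Because `import cellrank` itself fails here, `cr.logging.print_versions()` cannot run. A minimal sketch using only the standard library reads installed versions from package metadata without importing the packages, and checks whether a module exposes a given name (such as the private `scanpy._utils.deprecated_arg_names` that the traceback shows CellRank 2.0.3 importing). The helper names and package list are illustrative.

```python
import importlib
from importlib import metadata


def installed_versions(packages):
    """Read installed versions from package metadata without importing the packages."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None  # not installed in this environment
    return versions


def module_exposes(module_name, attr):
    """Return True if `module_name` imports cleanly and defines `attr`."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return False
    return hasattr(module, attr)


if __name__ == "__main__":
    for pkg, ver in installed_versions(["cellrank", "scanpy", "anndata", "scvelo"]).items():
        print(f"{pkg}: {ver}")
```

Inside the `CellRank240328` environment, `module_exposes("scanpy._utils", "deprecated_arg_names")` returning `False` would confirm the mismatch behind the `ImportError` above; the way out is then either a scanpy release that still ships the helper or a CellRank release that no longer imports it. Which releases those are is not established in this thread.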

(CellRank240328)[jw2894@r813u09n09.mccleary ~]$ conda install scanpy
Retrieving notices: ...working... done
Channels:
 - conda-forge
 - bioconda
 - defaults
Platform: linux-64
Collecting package metadata (repodata.json): done
Solving environment: done
# All requested packages already installed.
(CellRank240328)[jw2894@r813u09n09.mccleary ~]$ conda update -n CellRank240328  scanpy
Channels:
 - conda-forge
 - bioconda
 - defaults
Platform: linux-64
Collecting package metadata (repodata.json): done
Solving environment: done
## Package Plan ##
  environment location: /home/jw2894/.conda/envs/CellRank240328
  added / updated specs:
    - scanpy
The following packages will be downloaded:
    package                    |            build
    ---------------------------|-----------------
    blas-1.1                   |         openblas           1 KB  conda-forge
    comm-0.2.2                 |     pyhd8ed1ab_0          12 KB  conda-forge
    croniter-2.0.3             |     pyhd8ed1ab_0          37 KB  conda-forge
    dnspython-2.6.1            |     pyhd8ed1ab_1         165 KB  conda-forge
    hyperopt-0.1.2             |             py_0          85 KB  conda-forge
    intel-openmp-2023.1.0      |   hdb19cb5_46306        17.2 MB
    ipython-8.22.2             |     pyh707e725_0         580 KB  conda-forge
    ipywidgets-8.1.2           |     pyhd8ed1ab_0         111 KB  conda-forge
    jupyterlab_widgets-3.0.10  |     pyhd8ed1ab_0         183 KB  conda-forge
    libhwloc-2.10.0            |default_h2fb2949_1000         2.3 MB  conda-forge
    lightning-2.2.1            |     pyhd8ed1ab_0         1.3 MB  conda-forge
    mkl-2023.1.0               |   h213fc3f_46344       171.5 MB
    nomkl-3.0                  |                0          46 KB
    openblas-0.3.26            |pthreads_h7a3da1a_0         5.5 MB  conda-forge
    prompt-toolkit-3.0.42      |     pyha770c72_0         264 KB  conda-forge
    pydantic-2.6.4             |     pyhd8ed1ab_0         265 KB  conda-forge
    pydantic-core-2.16.3       |  py311h46250e7_0         1.6 MB  conda-forge
    pymongo-4.6.3              |  py311hb755f60_0         1.7 MB  conda-forge
    pytorch-2.1.2              |cpu_generic_py311h1584bb0_3        28.0 MB  conda-forge
    scikit-learn-1.4.1.post1   |  py311hc009520_0         9.9 MB  conda-forge
    scvi-tools-0.11.0          |     pyhdfd78af_0         123 KB  bioconda
    starsessions-2.1.3         |     pyhd8ed1ab_0          18 KB  conda-forge
    tbb-2021.7.0               |       h924138e_0         2.0 MB  conda-forge
    widgetsnbextension-4.0.10  |     pyhd8ed1ab_0         866 KB  conda-forge
    ------------------------------------------------------------
                                           Total:       243.6 MB

The following NEW packages will be INSTALLED:

  asttokens          conda-forge/noarch::asttokens-2.4.1-pyhd8ed1ab_0 
  blas               conda-forge/linux-64::blas-1.1-openblas 
  comm               conda-forge/noarch::comm-0.2.2-pyhd8ed1ab_0 
  decorator          conda-forge/noarch::decorator-5.1.1-pyhd8ed1ab_0 
  dnspython          conda-forge/noarch::dnspython-2.6.1-pyhd8ed1ab_1 
  executing          conda-forge/noarch::executing-2.0.1-pyhd8ed1ab_0 
  hyperopt           conda-forge/noarch::hyperopt-0.1.2-py_0 
  intel-openmp       pkgs/main/linux-64::intel-openmp-2023.1.0-hdb19cb5_46306 
  ipython            conda-forge/noarch::ipython-8.22.2-pyh707e725_0 
  ipywidgets         conda-forge/noarch::ipywidgets-8.1.2-pyhd8ed1ab_0 
  jedi               conda-forge/noarch::jedi-0.19.1-pyhd8ed1ab_0 
  jupyterlab_widgets conda-forge/noarch::jupyterlab_widgets-3.0.10-pyhd8ed1ab_0 
  libgomp            conda-forge/linux-64::libgomp-13.2.0-h807b86a_5 
  matplotlib-inline  conda-forge/noarch::matplotlib-inline-0.1.6-pyhd8ed1ab_0 
  nomkl              pkgs/main/linux-64::nomkl-3.0-0 
  openblas           conda-forge/linux-64::openblas-0.3.26-pthreads_h7a3da1a_0 
  parso              conda-forge/noarch::parso-0.8.3-pyhd8ed1ab_0 
  pickleshare        conda-forge/noarch::pickleshare-0.7.5-py_1003 
  prompt-toolkit     conda-forge/noarch::prompt-toolkit-3.0.42-pyha770c72_0 
  pure_eval          conda-forge/noarch::pure_eval-0.2.2-pyhd8ed1ab_0 
  pymongo            conda-forge/linux-64::pymongo-4.6.3-py311hb755f60_0 
  stack_data         conda-forge/noarch::stack_data-0.6.2-pyhd8ed1ab_0 
  toml               conda-forge/noarch::toml-0.10.2-pyhd8ed1ab_0 
  widgetsnbextension conda-forge/noarch::widgetsnbextension-4.0.10-pyhd8ed1ab_0 

The following packages will be UPDATED:

  _openmp_mutex                              4.5-2_kmp_llvm --> 4.5-2_gnu 
  croniter                               1.4.1-pyhd8ed1ab_0 --> 2.0.3-pyhd8ed1ab_0 
  libhwloc                      2.9.3-default_h554bfaf_1009 --> 2.10.0-default_h2fb2949_1000 
  lightning                        2.0.9.post0-pyhd8ed1ab_0 --> 2.2.1-pyhd8ed1ab_0 
  pydantic                               2.1.1-pyhd8ed1ab_0 --> 2.6.4-pyhd8ed1ab_0 
  pydantic-core                       2.4.0-py311h46250e7_0 --> 2.16.3-py311h46250e7_0 
  scikit-learn                        1.1.3-py311h3b52e38_1 --> 1.4.1.post1-py311hc009520_0 
  starsessions                           1.3.0-pyhd8ed1ab_0 --> 2.1.3-pyhd8ed1ab_0 

The following packages will be SUPERSEDED by a higher-priority channel:

  mkl                conda-forge::mkl-2023.2.0-h84fe81f_50~ --> pkgs/main::mkl-2023.1.0-h213fc3f_46344 
  scvelo             conda-forge::scvelo-0.3.2-pyhd8ed1ab_1 --> bioconda::scvelo-0.2.5-pyhdfd78af_0 
  scvi-tools         conda-forge::scvi-tools-1.1.2-pyhd8ed~ --> bioconda::scvi-tools-0.11.0-pyhdfd78af_0 

The following packages will be DOWNGRADED:

  cellrank                               2.0.3-pyhd8ed1ab_0 --> 2.0.0-pyhd8ed1ab_0 
  libtorch                       2.1.2-cpu_mkl_hcefb67d_103 --> 2.1.2-cpu_generic_ha017de0_3 
  pytorch                   2.1.2-cpu_mkl_py311hc5c8824_103 --> 2.1.2-cpu_generic_py311h1584bb0_3 
  suitesparse                             5.10.1-h5a4f163_3 --> 5.10.1-h9e50725_1 
  tbb                                  2021.11.0-h00ab1b0_1 --> 2021.7.0-h924138e_0 

Proceed ([y]/n)? n

CondaSystemExit: Exiting.
(CellRank240328)[jw2894@r813u09n09.mccleary ~]$ 

Here's the entire code block for my CellRank analysis; the portion about CellRank's combined kernels is at the end.
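`mergeadata` below pairs Seurat and velocyto cells by walking two sorted barcode lists in lockstep and deleting entries that do not line up, which relies on both lists sorting into the same order and stops at the first IndexError. A minimal sketch of the same pairing as a pandas inner join follows, assuming the normalization rules in `mergeadata` are correct; `seurat_key`, `velocyto_key`, and `match_barcodes` are hypothetical names, and the barcodes are made up.

```python
import re

import pandas as pd


def seurat_key(name):
    # Same rules as mergeadata: drop a "-<n>_<n>" suffix, then the "T_" prefix
    return re.sub(r"-\d+_\d+", "", name).replace("T_", "")


def velocyto_key(name):
    # Same rules as mergeadata: keep the part after ":" and drop the trailing "x"
    return name.split(":")[1].replace("x", "")


def match_barcodes(seurat_names, velocyto_names):
    """Pair Seurat and velocyto cell names that normalize to the same barcode."""
    seurat = pd.DataFrame({"seurat_name": list(seurat_names)})
    seurat["key"] = seurat["seurat_name"].map(seurat_key)
    velocyto = pd.DataFrame({"velocyto_name": list(velocyto_names)})
    velocyto["key"] = velocyto["velocyto_name"].map(velocyto_key)
    # Inner join keeps only cells present in both objects;
    # validate="one_to_one" raises MergeError if a barcode appears twice on either side
    return seurat.merge(velocyto, on="key", how="inner", validate="one_to_one")


pairs = match_barcodes(
    ["T_AAACCTGA-1_1", "T_GGGTCTAC-1_1", "T_CCCATTGA-1_1"],  # made-up Seurat names
    ["s1:GGGTCTACx", "s1:AAACCTGAx", "s1:TTTTGGCAx"],        # made-up velocyto names
)
print(pairs)
```

The two name columns can then subset each AnnData in matching order (for example `adata_seurat[pairs["seurat_name"].to_numpy()]`) before calling `scv.utils.merge`. This covers the matching step only; it does not reproduce the prefix renaming that `mergeadata` applies afterwards.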

conda init bash

#first portion: finding all loom files
VELOCYTOPATH_BMET="/gpfs/gibbs/pi/hafler/jw2894/Projects_Post_24_2_20/CellRank/Data_Input/10X_NSCLC_BM/CellRanger/v7.0.1" #Note: the files are a few folders further down but the paths diverge
VELOCYTOPATH_GBM="/gpfs/gibbs/pi/hafler/jw2894/Projects_Post_24_2_20/CellRank/Data_Input/10X_GBM" 
bmet_loom_files=$(find $VELOCYTOPATH_BMET -name "*.loom") #Finds all loom files in these folders
gbm_loom_files=$(find $VELOCYTOPATH_GBM -name "*.loom")
#for now it seems as if there's only one seurat object
seurat_object_file="/gpfs/gibbs/pi/hafler/jw2894/Data/RNA_Velocity_Pipeline/H5AD_Files/Brain_tumors_Tcells_harmonized_TUMOR_CD4.h5ad"
#This list then needs to be passed into python as an argument

#echo "These are the bmet loom files: "$bmet_loom_files
#echo "These are the GBM loom files: "$gbm_loom_files

conda activate CellRank240328
python3.11 -c '
import sys
import numpy as np
import os
from scipy import io
from scipy.sparse import coo_matrix, csr_matrix
import pandas as pd
import cellrank as cr
from cellrank.estimators import GPCCA
import scvelo as scv
import scanpy as sc
import warnings
import re
import gdown
import anndata as ad
from anndata.experimental.multi_files import AnnCollection
from sklearn.model_selection import train_test_split
from sklearn.decomposition import IncrementalPCA
#For palantir
import matplotlib
import matplotlib.pyplot as plt

stochastic_or_dynamical="dynamical"
date="24_03_28"
outputfilepath="/gpfs/gibbs/pi/hafler/jw2894/Projects_Post_24_2_20/CellRank/Data_Output/"
plotoutputfilepath=outputfilepath+"Plots/"
csvoutputfilepath=outputfilepath+"CSVs/"
objectoutputfilepath=outputfilepath+"Objects/"

import warnings
warnings.simplefilter("ignore", category=UserWarning)
cr.logging.print_versions() #comment on and off as necessary

def readdata(loom_input_data_list, seurat_file_path):
    #read in velocyto data (count matrices) and create a list of anndata objects
    adata_loom_list=[]
    #read in loom files and create sequential anndatas
    #create a new list to store all adata objects that we instantiate
    for input_type in loom_input_data_list[:]: #will do all elements in indices 2 and then 3
        #read each file in and convert to adata
        # print("this is input_type: ", input_type)
        # print("this is input_types type: ", type(input_type))
        input_type = input_type.split()
        # print("this is input_type: ", input_type)
        # print("this is input_types type: ", type(input_type))
        for file_path in input_type:
            #print("this is file_path: ", file_path)
            adata_velocyto = scv.read(file_path, cache=True, backed="r") #Reads the count matrices generated by velocyto (backed, read-only)
            #before appending it each barcode needs to be made unique across all patients - example: <barcode>"_"<OR date>"_"<Tumor type>
            #splitting the filepath by "/"
            items_in_file_path = file_path.split("/")
            tumor_type = items_in_file_path[9][4:] #component 9 is the dataset folder (e.g. "10X_GBM"); [4:] drops the "10X_" prefix. Index 7 is "CellRank", which gave "Rank"
            #append to list
            obs_barcode_array = []
            for barcode in adata_velocyto.obs.index:
                date = barcode[:6] #Everything up to the 7th character, exclusive (ironically, this disregards the barcode)
                barcode_suffix = tumor_type + "_" + date
                barcode = barcode[15:] #take off the first information
                barcode = barcode[:-1] #remove the "x" at the end
                #print("we are adding: ", barcode_suffix, " to: ", barcode)
                barcode += barcode_suffix #add the suffix
                obs_barcode_array.append(barcode)
            adata_velocyto.obs["Patient_ID"] = obs_barcode_array #With this addition we can keep track of where each cell came from by directly accessing patient ID!
            obs_barcode_array = [] #array must be reset
            adata_loom_list.append(adata_velocyto)
    #print("This is adata_loom_list: ", adata_loom_list)
    adata_seurat = scv.read(seurat_file_path, cache=True)
    return adata_loom_list, adata_seurat

def mergeloom(file_list):
    #whatever instructions needed to focus on merging files
    combined_loom_file = AnnCollection(file_list, join_obs=None, join_vars="outer", label="dataset")
    print("this is the combined object: ", combined_loom_file)
    return combined_loom_file

def mergeadata(adata_seurat, adata_velocyto): #Assumes one AnnCollection can be merged with one Seurat object
    #merge along the observation axis
    # Gather data
    adata_seurats_barcodes_unsorted = adata_seurat.obs.index
    adata_velocytos_barcodes_unsorted = adata_velocyto.obs.index

    adata_seurats_barcodes = sorted(adata_seurats_barcodes_unsorted)
    adata_velocytos_barcodes = sorted(adata_velocytos_barcodes_unsorted)
    #print("these are the sorted adata seurats barcodes: ", adata_seurats_barcodes)
    #print("these are the sorted adata velocytos barcodes: ", adata_velocytos_barcodes)

    #Create list to store old indices for each adata
    #These are used for subsetting
    old_seurat_indices=[]
    old_velocyto_indices=[]

    #Create list to store new indices for each adata
    new_seurat_indices=[]
    new_velocyto_indices=[]
    #Create tracking variable for number of velocyto cells removed for QC
    num_vc_cells_removed = 0
    num_se_cells_removed = 0

    while_loop_range = len(adata_velocytos_barcodes)
    i=0
    while i < while_loop_range:
        try:
            # Extract and store the filename from the velocytos index
            old_seurats_index = adata_seurats_barcodes[i]
            old_velocytos_index = adata_velocytos_barcodes[i]

            seurats_index = old_seurats_index
            velocytos_index = old_velocytos_index

            filename_prefix = velocytos_index.split(":")[0].replace("_", "")
            #print("this is seurats_index BEFORE regex: ", seurats_index)
            # Remove extra from the seurats index
            seurats_index = re.sub(r"-\d+_\d+", "", seurats_index)
            seurats_index = seurats_index.replace("T_", "")

            # Remove "x" from the velocytos index
            velocytos_index = velocytos_index.split(":")[1].replace("x", "")

            #At this point, check if they are equal (if not, we are on a velocyto cell that did not pass seurat QC). If equal, proceed. Otherwise, increment just the velocyto index by one.    
            #print("this is seurats_index after regex: ", seurats_index)
            #print("this is velocytos_index after regex: ", velocytos_index)
            if (seurats_index == velocytos_index):
                # Rename the indices
                seurats_index = seurats_index + "_" + filename_prefix
                velocytos_index = velocytos_index + "_" + filename_prefix
                #Add them to the lists
                old_seurat_indices.append(old_seurats_index)
                old_velocyto_indices.append(old_velocytos_index)
                new_seurat_indices.append(seurats_index)
                new_velocyto_indices.append(velocytos_index)
                #Increment index variable i
                i+=1
            else: #they were not equal
                #skip the velocyto index, as this is a cell that did not pass QC
                if(seurats_index < velocytos_index):
                    del adata_seurats_barcodes[i]
                    num_se_cells_removed+=1
                else:
                    del adata_velocytos_barcodes[i]
                    num_vc_cells_removed+=1
        except IndexError:
            print("The loop would have broken due to an IndexError, but that is now accounted for")
            break
            #gets us out of the loop no matter what

    if(len(adata_velocytos_barcodes) > len(adata_seurats_barcodes)):
        del adata_velocytos_barcodes[-1]
        num_vc_cells_removed +=1
        print("Velocyto cell removed for not passing seurat QC")

    print(num_vc_cells_removed, " Velocyto cells were removed due to not passing Seurat QC metrics")
    print(num_se_cells_removed, " Seurat cells were removed for unknown reasons")

    #Subset the old adatas with the old index lists
    adata_seurat_subsetted = adata_seurat[old_seurat_indices]
    adata_velocyto_subsetted = adata_velocyto[old_velocyto_indices]

    #update the index variables with the new lists
    adata_seurat_subsetted.obs.index = pd.Index(new_seurat_indices)
    adata_velocyto_subsetted.obs.index = pd.Index(new_velocyto_indices)

    #Call the scVelo merge function (scv.utils.merge)
    adata = scv.utils.merge(adata_velocyto_subsetted, adata_seurat_subsetted)
    #if this function fails, try instantiating an AnnCollection with an outer join

    print("This is the COMBINED adata object: ", adata)
    return adata

def sort_and_merge(input_seurat_object, input_loom_object_list): #The temporary fix until I learn of a more convenient way to do this, if one exists
    # Pseudocode: 
    # Input: takes in list of loom objects & seurat object
    # Initialize list of seurat subsets
    input_seurat_object_list = []
    #Initialize list of unique patient IDs
    unique_patient_ids = []
    for obs_instance in input_seurat_object: 
        #print("This is obs_instance: ", obs_instance)
        id = obs_instance.obs["Patient"].iloc[0]
        if id not in unique_patient_ids:
            unique_patient_ids.append(id)
            #print("this is unique_patient_ids: ", unique_patient_ids) #check what it looks like each time we append
            #print("this is the type of unique_patient_ids: ", type(unique_patient_ids)) #check what it looks like each time we append
    #subset seurat object by patient id
    for id in unique_patient_ids:
        # subset seurat object with patient id == id
        #print("this is the value of id: ", id)
        current_subset = subset(input_seurat_object, "Patient", id)
        # Add each to list
        input_seurat_object_list.append(current_subset)    

    #Meanwhile, the loom objects need their patient IDs updated
    for current_loom_object in input_loom_object_list:
        previous_id = current_loom_object.obs["Patient_ID"].iloc[0]
        print("this is previous id:", previous_id)
        date = previous_id[-6:]
        print("this is the extracted date: ", date)
        # Remove the matched numbers (date) from the end of the previous ID
        remaining_previous_id = previous_id[:len(previous_id) - len(date)]
        # Extract and remove the tumor type as well as the date
        tumor_type = remaining_previous_id[-5:]
        remaining_previous_id = previous_id[:len(previous_id) - len(tumor_type)]
        # Reconstruct the previous ID with the date at the front
        reconstructed_id = date #+ tumor_type + remaining_previous_id[:-6] #removes tumor type from end of the string #For now just make the ID the date, as apparently even previous ID is "rank" instead of tumor type
        print("This is the reconstructed id: ", reconstructed_id)
        #set the Patient_ID of the first cell to the new string (positional, without chained assignment)
        current_loom_object.obs.iat[0, current_loom_object.obs.columns.get_loc("Patient_ID")] = reconstructed_id
        print("this is the new value of current_loom_object.obs[Patient_ID][0]: ", current_loom_object.obs["Patient_ID"].iloc[0])
    # Sort list of seurat objects by patient id
    sorted_adata_seurat_list = merge_sort(input_seurat_object_list, "Patient")
    # Sort list of loom objects by patient id    
    sorted_adata_loom_list = merge_sort(input_loom_object_list, "Patient_ID")
    #verifying everything went well
    print("this is the length of the sorted seurat list: ", len(sorted_adata_seurat_list), " and this is the length of the sorted loom list: ",  len(sorted_adata_loom_list))
    for i in range(len(sorted_adata_loom_list)):
        #checking patient IDs
        #took out the actual entries for now so that I can see if they are sorted appropriately
        print("this is sorted adata_seurat_list entry: ", i,  "and here is its .obs[Patient][0] ", sorted_adata_seurat_list[i].obs["Patient"].iloc[0]) #sorted_adata_seurat_list[i],
        print("this is sorted adata_loom_list entry: " , i , "and here is its .obs[Patient_ID][0] ", sorted_adata_loom_list[i].obs["Patient_ID"].iloc[0]) #sorted_adata_loom_list[i],
        #checking barcodes
        print("Barcodes: this is sorted adata_seurat_list entry: ", i, sorted_adata_seurat_list[i], "and here are its obs_names ", sorted_adata_seurat_list[i].obs_names)
        print("Barcodes: this is sorted adata_loom_list entry: " , i ,sorted_adata_loom_list[i], "and here are its obs_names ", sorted_adata_loom_list[i].obs_names)

    merged_object_array = []
    # For (length of either list - they are the same length): 
    loom_index_incrementer = 0
    seurat_index_incrementer = 0
    while loom_index_incrementer <(len(sorted_adata_loom_list)): 
        # Merge objects located at index i in their respective lists
        # Add this merged object to a third list (of merged objects)
        print("Line 249 prints")
        if sorted_adata_loom_list[loom_index_incrementer].obs["Patient_ID"].iloc[0] in sorted_adata_seurat_list[seurat_index_incrementer].obs["Patient"].iloc[0]: #Flipped the condition around as the loom patients date should match the seurat patient ID
            merged_object_array.append(mergeadata(sorted_adata_seurat_list[seurat_index_incrementer], sorted_adata_loom_list[loom_index_incrementer]))
            print("the patients should be equal. Here is the seurat patient:", sorted_adata_seurat_list[seurat_index_incrementer].obs["Patient"].iloc[0], " and here is the loom patient: ", sorted_adata_loom_list[loom_index_incrementer].obs["Patient_ID"].iloc[0])
            seurat_index_incrementer+=1
            loom_index_incrementer+=1
        #^Commented out for ease of reading for now
        else: #for now I am assuming this means that the seurat patient pairs to a future loom patient
            print("the patients should NOT be equal. Here is the seurat patient:", sorted_adata_seurat_list[seurat_index_incrementer].obs["Patient"].iloc[0], " and here is the loom patient: ", sorted_adata_loom_list[loom_index_incrementer].obs["Patient_ID"].iloc[0])
            seurat_index_incrementer+=1

    print("This is merged_object_array: ", merged_object_array)
    # Declare an AnnCollection of the 3rd list
    #complete_object = AnnCollection(merged_object_array)
    #trying concat() on them
    complete_object = ad.concat(merged_object_array)
    # Return the complete concatenated AnnData object
    return complete_object

def merge_sort(anndata_object_list, key):
    if len(anndata_object_list) <= 1:
        return anndata_object_list

    # Split the list into two halves
    mid = len(anndata_object_list) // 2
    left_half = anndata_object_list[:mid]
    right_half = anndata_object_list[mid:]

    # Recursively sort each half
    left_half = merge_sort(left_half,key)
    right_half = merge_sort(right_half,key)

    # Merge the sorted halves
    sorted_anndata_object_list = merge(left_half, right_half,key)

    return sorted_anndata_object_list

def merge(left, right,key):
    merged = []
    left_index = 0
    right_index = 0

    while left_index < len(left) and right_index < len(right):
        if left[left_index].obs[key].iloc[0] <= right[right_index].obs[key].iloc[0]:
            merged.append(left[left_index])
            left_index += 1
        else:
            merged.append(right[right_index])
            right_index += 1

    # Append remaining elements from left or right sublist
    merged.extend(left[left_index:])
    merged.extend(right[right_index:])

    return merged

def subset(adata, index, desired_value):
    # Gather data
    adata_total_barcodes = adata.obs.index #This allows us to keep every observation 
    #Create list to store old indices for each adata
    #These are used for subsetting
    old_adata_indices=[]

    #Create list to store new indices for each adata
    new_adata_indices=[]
    new_adata_indicesII=[]
    #Create tracking variable for number of non CD4T cells removed
    num_adata_cells_removed = 0

    #test - sorting based on cell type
    sorted_indices = adata.obs[index].argsort()

    while_loop_range = len(adata_total_barcodes)
    i=0
    while i < while_loop_range:
        current_barcode = adata_total_barcodes[i]
        #print("this is adata.obs[index].iloc[i]: ", adata.obs[index].iloc[i], "this is the desired value: ", desired_value, "and this is the barcode: ", current_barcode)
        if (adata.obs[index].iloc[i] == desired_value): #read the value straight from obs instead of building a one-cell view
            new_adata_indices.append(current_barcode)
        i+=1
    adata_subsetted = adata[new_adata_indices]
    #print("this is new_adata_indices: ", new_adata_indices)
    #print("the new subseted object is: ", adata_subsetted)
    return adata_subsetted

def recover_and_velocity(adata, velocity_mode = "stochastic"):
    #adata: the anndata object
    #velocity_mode: stochastic or dynamical
    scv.tl.recover_dynamics(adata, n_jobs=8) 
    scv.tl.velocity(adata, mode=velocity_mode)
    scv.tl.velocity_graph(adata)
    return adata

def rank_and_write_velocity_genes(adata, diffkinetics=False):
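    if (diffkinetics == False): #do not use differential kinetics
        scv.tl.rank_velocity_genes(adata, groupby="Cell_type", min_likelihood=0) #Ranking of top genes by cell type #Experiment: explicitly saying the number of genes to rank
    else:
        scv.tl.differential_kinetic_test(adata, groupby="Cell_type") #test genes for differential kinetics between cell types (needs recover_dynamics first)
        scv.tl.velocity(adata, diff_kinetics=True) #recompute velocities using the per-cell-type kinetics
        scv.tl.rank_velocity_genes(adata, groupby="Cell_type", min_likelihood=0) #fills adata.uns["rank_velocity_genes"], which is read below
    df_name = scv.DataFrame(adata.uns["rank_velocity_genes"]["names"]) 
    print(df_name)#If I need to print a specific column I can do so by saying print (df_name[0:][<column of interest>])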
    #velocity gene scores - will change this over to adata subsets by cluster
    # adata_CD4T = adata.obs["CD_8_T"]
    # df_velocity_gene_scores_CD4T = adata.varm["velocity_score"]
    return df_name

def runDK(adata):
    #Will be implemented if needed. But, with only one cell type, this seems less likely. Maybe by tumor type?
    return adata

def plot_RNA_Velocity_results(adata):
    print("This is adata once it is read into plot_RNA_Velocity_results: ", adata)
    whole_population_array = ["CXCL13","KLF7","NKG7","TBX21","IGHD","MAL","TRAT1","SPI1","C1orf21","CD8A","ATP10A","IL4R","HLA-DRB1","GK5","ANXA1","COTL1","AL138963.4","CTLA4","APLP2","FAM30A","S1PR5","LRP1","DAPK1","CD163","KIR3DL1","SOX5","CST3"]
    genes_of_interest_24_2_29 = ["LEF1", "FMN1", "DNAJB1", "HSP90AA1", "ANKRD55", "ANXA1", "FAAH2", "FOSB", "CCR7", "CYSLTR1", "TBXAS1", "PDE7B", "TOX", "AGFG1", "HIPK2"]
    four_significant_genes = ["IGHG1", "AL591845.1", "VCAM1", "AL138963.4"] #Genes with likelihood >.1
    scv.pl.velocity(adata, genes_of_interest_24_2_29, ncols=5, save=plotoutputfilepath+"TEST_BRAIN"+date+"_phase_map_"+stochastic_or_dynamical+".png")
    scv.pl.velocity(adata, four_significant_genes, save=plotoutputfilepath+"TEST_BRAIN"+date+"_phase_map_likely_genes_"+stochastic_or_dynamical+".png")
    scv.pl.hist(adata.var["fit_likelihood"], axvline=.05, kde=True, exclude_zeros=True, colors=["grey"], save=plotoutputfilepath+"TEST_BRAIN"+date+"_likelihood_histogram_no_0_"+stochastic_or_dynamical+".png")
    scv.pl.hist(adata.var["fit_likelihood"], axvline=.05, kde=True, exclude_zeros=False, colors=["grey"], save=plotoutputfilepath+"TEST_BRAIN"+date+"_likelihood_histogram_with_0_"+stochastic_or_dynamical+".png")
    #plot_projection(adata, plotoutputfilepath+"TEST_BRAIN"+date+"_projection_"+stochastic_or_dynamical+".png", "Cell_type") #Deprecated now, handled in the combined kernel function
    vk = cr.kernels.VelocityKernel(adata)
    vk.compute_transition_matrix()
    ck = cr.kernels.ConnectivityKernel(adata)
    ck.compute_transition_matrix()
    return vk, ck

def pseudotime(adata):
    #Computing palantir can go up here

    #DPT pseudotime
    sc.tl.diffmap(adata)
    print(adata.obsm["X_diffmap"][:, 3].argmax())
    root_ixs = adata.obsm["X_diffmap"][:, 3].argmax()
    scv.pl.scatter(adata, basis="diffmap", c=["Cell_type", root_ixs], legend_loc="right", components=["2, 3"],save=plotoutputfilepath+"TEST_DIFFMAP"+date+stochastic_or_dynamical+".png")
    adata.uns["iroot"] = root_ixs
    #compute DPT
    sc.tl.dpt(adata)
    print(adata.obsm["X_diffmap"])
    sc.pl.embedding(adata,basis="umap",color=["dpt_pseudotime"],color_map="gnuplot2", save=plotoutputfilepath+"TEST_DPT_Embedding"+date+stochastic_or_dynamical+".png")#"palantir_pseudotime"

    #plotting violin plots of trajectories:
    trajectory_cell_types = ["T01.CD4_IL7R", "T05.CD4_CD40LG", "T06.CD4_Treg", "T07.CD4_Treg", "T08.CD4_KLF2","T12.CD4_CXCL13"] #as many as wanted
    # plot type 1
    mask = np.isin(adata.obs["Cell_type"], trajectory_cell_types)
    sc.pl.violin(adata[mask], keys=["dpt_pseudotime"], groupby="Cell_type", rotation=-90, order=trajectory_cell_types, save=plotoutputfilepath+"TEST_PSEUDOTIME_VIOLIN"+date+stochastic_or_dynamical+".png") #"palantir_pseudotime"
    pk_dpt = cr.kernels.PseudotimeKernel(adata, time_key="dpt_pseudotime") #By having different kernels returned, it is possible to then combine the kernels later and average out the pseudotimes
    #pk_palantir = cr.kernels.PseudotimeKernel(adata, time_key="palantir_pseudotime") 
    pk_dpt.compute_transition_matrix()
    #pk_palantir.compute_transition_matrix()
    return pk_dpt#, pk_palantir #uncomment when ready

def combine_kernels(RNA_Velocity_Kernel, RNA_Velocity_Fraction, Connectivity_Kernel, Connectivity_Fraction, DPT_Pseudotime_kernel, DPT_Pseudotime_Fraction):
    combined_kernel= RNA_Velocity_Kernel*RNA_Velocity_Fraction + Connectivity_Kernel*Connectivity_Fraction + DPT_Pseudotime_kernel*DPT_Pseudotime_Fraction
    print(combined_kernel)
    combined_kernel.compute_transition_matrix()
    combined_kernel.plot_projection(basis="umap", recompute=True, save=plotoutputfilepath+"Projection_With_"+str(RNA_Velocity_Fraction*100)+"% RNA Velocity"+str(Connectivity_Fraction*100)+"% Connectivity"+ str(DPT_Pseudotime_Fraction*100)+"% DPT Pseudotime"+date+stochastic_or_dynamical+".png", color="Cell_type") #may have to make this more specific in a bit
    #combined_kernel.plot_random_walks(start_ixs={"Cell_type": "T01.CD4_IL7R"}, max_iter=200, seed=0, save="Projection_With_"+str(RNA_Velocity_Fraction*100)+"% RNA Velocity"+str(Connectivity_Fraction*100)+"% Connectivity"+ str(DPT_Pseudotime_Fraction*100)+"% DPT Pseudotime"+date+stochastic_or_dynamical+".png")
    return combined_kernel

def combined_kernel_analysis(combined_kernel, adata, RNA_Velocity_Fraction, Connectivity_Fraction, DPT_Pseudotime_Fraction):
    #Initialize
    current_model = GPCCA(combined_kernel)
    print(current_model)

    #predict and plot all states
    current_model.fit(n_states=10, cluster_key="Cell_type")
    current_model.plot_macrostates(which="all", save=plotoutputfilepath+"_All_States_of_"+str(RNA_Velocity_Fraction*100)+"% RNA Velocity"+str(Connectivity_Fraction*100)+"% Connectivity"+ str(DPT_Pseudotime_Fraction*100)+"% DPT Psuedotime"+date+stochastic_or_dynamical+".png") #Have to check if theres a save method
    #plot terminal states
    current_model.predict_terminal_states(method="top_n", n_states=6)
    current_model.plot_macrostates(which="terminal",  save=plotoutputfilepath+"_Terminal_States_of_"+str(RNA_Velocity_Fraction*100)+"% RNA Velocity"+str(Connectivity_Fraction*100)+"% Connectivity"+ str(DPT_Pseudotime_Fraction*100)+"% DPT Psuedotime"+date+stochastic_or_dynamical+".png") #have to see if theres a save method

    #compute & plot fate probabilities
    current_model.compute_fate_probabilities()
    current_model.plot_fate_probabilities(legend_loc="right",  save=plotoutputfilepath+"_Fate_Probabilities_of_"+str(RNA_Velocity_Fraction*100)+"% RNA Velocity"+str(Connectivity_Fraction*100)+"% Connectivity"+ str(DPT_Pseudotime_Fraction*100)+"% DPT Psuedotime"+date+stochastic_or_dynamical+".png")
    #cr.pl.circular_projection(adata, keys="Cell_type", legend_loc="right", save=plotoutputfilepath+"_Circular_Projection_"+str(RNA_Velocity_Fraction*100)+"% RNA Velocity"+str(Connectivity_Fraction*100)+"% Connectivity"+ str(DPT_Pseudotime_Fraction*100)+"% DPT Psuedotime"+date+stochastic_or_dynamical+".png") #Will see how to fix cricular projection later
    #mono_drivers = current_model.compute_lineage_drivers(lineages="Mono_1_1")
    #^Convert this to for loop for each cluster
    lineage_array=[]
    print(current_model) #learn what the structure of this object is
    for lineage in current_model.terminal_states:#current_model.macrostates: #I think it's this? #kernel, adata, or whatever it is
        #drivers_per_cluster.to_csv() #may have to cast pending on type
        lineage_array.append(lineage)
    print("This is lineage_array: ", lineage_array)
    print("This is current_model.terminal_states ", current_model.terminal_states, " and here is its type: ", type(current_model.terminal_states))
    print("This is current_model.terminal_states.categories ", current_model.terminal_states.categories, " and here is its type: ", type(current_model.terminal_states.categories))
    drivers_per_cluster= current_model.compute_lineage_drivers(lineages=current_model.terminal_states.categories)
    print(type(drivers_per_cluster))
    #Visualize expression trends
    model = cr.models.GAM(adata) #GAM instead of GAMR to stay in Python
    gene_array = ["CXCL13", "CD34", "ANXA1", "CD14", "KLF2", "IL7R"]
    pseudotime_type="dpt_pseudotime"
    cr.pl.gene_trends( adata, model=model, data_key="X", genes=gene_array, same_plot=True, ncols=2, time_key=pseudotime_type, hide_cells=True, save=plotoutputfilepath+"_Gene_Trends_"+str(RNA_Velocity_Fraction*100)+"% RNA Velocity"+str(Connectivity_Fraction*100)+"% Connectivity"+ str(DPT_Pseudotime_Fraction*100)+"% DPT Psuedotime"+date+stochastic_or_dynamical+".png") #have to see what to replace magic imputed data with
    cr.pl.heatmap( adata, model=model, data_key="X", genes=gene_array, lineages=lineage_array, time_key=pseudotime_type, cbar=False, show_all_genes=True, save=plotoutputfilepath+"_Gene_Heatmap_"+str(RNA_Velocity_Fraction*100)+"% RNA Velocity"+str(Connectivity_Fraction*100)+"% Connectivity"+ str(DPT_Pseudotime_Fraction*100)+"% DPT Psuedotime"+date+stochastic_or_dynamical+".png") #also, I will have to see if I have to write a save function for these too

def main(argument_list):
    #scv.logging.print_versions() #Uncomment when needed for bug reports
    print("This gets printed before anything happens in Python")
    seurat_sample, bmet_samples_to_iterate, gbm_samples_to_iterate = argument_list[1], argument_list[2], argument_list[3] #index 0 is just "-c"
    #argument_list[:len(argument_list/2)] #take first two elements (indcies 0 and 1) #We demand a list of loom files to be paired with a seurat object file (so then we can just take 1/2 of the list) for each sample category (example: bmet and gbm)
    #argument_list[len(argument_list/2):-1] #second half of list is bmet stuff, except for last item
    try: #accessing whole_object from today
        whole_object = scv.read(objectoutputfilepath+"TEST"+date+".h5ad", compression="gzip", cache=True)
        print("Started from previously written anndata")
    except Exception:
        print("Generating new anndata object")
        sample_list = [bmet_samples_to_iterate, gbm_samples_to_iterate]
        velocyto_adata_list, seurat_object = readdata(sample_list, seurat_sample)
        collected_adata = sort_and_merge(seurat_object, velocyto_adata_list)
        whole_object = ad.AnnData(collected_adata)
        whole_object.write(objectoutputfilepath+"TEST"+date+".h5ad", compression="gzip")
    try:
        open(csvoutputfilepath+date+"genes_and_likelihood.csv")
        print("velocity calculations skipped")
    except Exception:
        print("Performing velocity calculations")
        whole_object = recover_and_velocity(whole_object)
        top_gene_df = rank_and_write_velocity_genes(whole_object) #Dataframe containing top genes per cluster
        #whole_object = computeDEGs(whole_object) #Likely not needed - the previous function takes care of this
        top_gene_df.to_csv(csvoutputfilepath+date+"top_genes_by_cluster.csv")
        #Next up: Add this, which will show likelihoods
        df_likelihood = whole_object.var
        df_likelihood = df_likelihood[df_likelihood["velocity_genes"] == True] #df_likelihood[(df_likelihood["fit_likelihood"] > .1) & #Extra from a comparison to only show genes above .1, which I obviously don't want atm
        # kwargs = dict(xscale="log", fontsize=16) #commenting out for now so I can get everything else
        # with scv.GridSpec(ncols=3) as pl:
        #     pl.hist(df_likelihood["fit_alpha"], xlabel="transcription rate", **kwargs) # I think its giving an error based on indices being Strs...
        #     pl.hist(df_likelihood["fit_beta"] * df_likelihood["fit_scaling"], xlabel="splicing rate", xticks=[.1, .4, 1], **kwargs)
        #     pl.hist(df_likelihood["fit_gamma"], xlabel="degradation rate", xticks=[.1, .4, 1], **kwargs)
        scv.get_df(whole_object, "fit*", dropna=True).to_csv(csvoutputfilepath+date+"genes_and_likelihood.csv")
    #whole_object = runDK(whole_object) #Currently not using, so I am deprecating it. Ask Ben before removing.
        whole_object.write(objectoutputfilepath+"TEST"+date+".h5ad", compression="gzip")
    try:
        RNA_Velocity_kernel = cr.kernels.PrecomputedKernel.read(objectoutputfilepath+"Kernels/"+"TEST"+date+"RNA_Velocity_Kernel.h5ad")
        connectivity_kernel = cr.kernels.PrecomputedKernel.read(objectoutputfilepath+"Kernels/"+"TEST"+date+"Connectivity_Kernel.h5ad")
        print("RNA Velocity kernel instantiation skipped")
    except Exception:
        print("Performing RNA Velocity kernel instantiation")
        RNA_Velocity_kernel, connectivity_kernel = plot_RNA_Velocity_results(whole_object)
        RNA_Velocity_kernel.write(objectoutputfilepath+"Kernels/"+"TEST"+date+"RNA_Velocity_Kernel.h5ad")
        connectivity_kernel.write(objectoutputfilepath+"Kernels/"+"TEST"+date+"Connectivity_Kernel.h5ad")
    try:
        dpt_pseudotime_kernel = cr.kernels.PrecomputedKernel.read(objectoutputfilepath+"Kernels/"+"TEST"+date+"DPT_Pseudotime_Kernel.h5ad")
        #palantir_pseudotime_kernel = cr.kernels.PrecomputedKernel.read(objectoutputfilepath+"Kernels/"+"TEST"+date+"Palantir_Pseudotime_Kernel.h5ad")
        print("pseudotime kernel instantiation skipped")
    except Exception:
        print("Performing pseudotime kernel instantiation")
        dpt_pseudotime_kernel = pseudotime(whole_object) #add palantir_pseudotime_kernel when ready
        dpt_pseudotime_kernel.write(objectoutputfilepath+"Kernels/"+"TEST"+date+"DPT_Pseudotime_Kernel.h5ad")
    print("Combined Kernel analysis starting")
    current_combined_kernel = combine_kernels(RNA_Velocity_kernel, .45, connectivity_kernel, .1, dpt_pseudotime_kernel, .45) #, palantir_psuedotime_kernel, 0 #add in when ready
    combined_kernel_analysis(current_combined_kernel, whole_object, .45, .1, .45)
    whole_object.write(objectoutputfilepath+"TEST"+date+".h5ad", compression="gzip")
    print("The program executed successfully!")

main(sys.argv)
' "$seurat_object_file" "$bmet_loom_files" "$gbm_loom_files"
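For reference, `combine_kernels` above builds `0.45 * VelocityKernel + 0.1 * ConnectivityKernel + 0.45 * PseudotimeKernel` (the same expression appears in the `GPCCA[...]` repr later in the thread). The sketch below illustrates in plain Python what such a weighted combination does to transition matrices, under the assumption that the combined matrix is the weighted sum of the individual row-stochastic matrices. The toy 3-cell matrices and both helper names are made up for illustration; this is not CellRank's implementation.

```python
from typing import List, Sequence

Matrix = List[List[float]]


def combine_transition_matrices(matrices: Sequence[Matrix], weights: Sequence[float]) -> Matrix:
    """Weighted sum of square matrices (illustrative sketch, not CellRank code)."""
    if len(matrices) != len(weights):
        raise ValueError("need exactly one weight per matrix")
    n = len(matrices[0])
    combined = [[0.0] * n for _ in range(n)]
    for matrix, weight in zip(matrices, weights):
        for i in range(n):
            for j in range(n):
                combined[i][j] += weight * matrix[i][j]
    return combined


def is_row_stochastic(matrix: Matrix, tol: float = 1e-9) -> bool:
    """True if every entry is non-negative and every row sums to 1 (within tol)."""
    return all(
        abs(sum(row) - 1.0) <= tol and all(x >= 0 for x in row) for row in matrix
    )


# Toy 3-cell transition matrices standing in for the velocity,
# connectivity and pseudotime kernels (made-up numbers).
velocity = [[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 1.0]]
connectivity = [[0.0, 0.5, 0.5], [0.5, 0.0, 0.5], [0.5, 0.5, 0.0]]
pseudotime = [[0.2, 0.8, 0.0], [0.0, 0.3, 0.7], [0.0, 0.0, 1.0]]

combined = combine_transition_matrices(
    [velocity, connectivity, pseudotime], [0.45, 0.1, 0.45]
)
print(is_row_stochastic(combined))  # prints True
```

Because 0.45 + 0.1 + 0.45 = 1, every row of the combined matrix still sums to 1; weights that do not sum to 1 would break that property unless they are rescaled.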
erzakiev commented 7 months ago

Also facing the issue ImportError: cannot import name 'deprecated_arg_names' from 'scanpy._utils' when trying to import cellrank:

Code:

```python
import scvelo as scv
import scanpy as sc
import cellrank as cr
import numpy as np
import pandas as pd
import anndata as ad

---------------------------------------------------------------------------
ImportError                               Traceback (most recent call last)
Cell In[14], line 3
      1 import scvelo as scv
      2 import scanpy as sc
----> 3 import cellrank as cr
      4 import numpy as np
      5 import pandas as pd

File ~/miniforge3/envs/scFates/lib/python3.11/site-packages/cellrank/__init__.py:3
      1 from importlib import metadata
----> 3 from cellrank import datasets, estimators, kernels, logging, models, pl
      4 from cellrank._utils._lineage import Lineage
      5 from cellrank.settings import settings

File ~/miniforge3/envs/scFates/lib/python3.11/site-packages/cellrank/pl/__init__.py:2
      1 from cellrank.pl._aggregate_fate_probs import aggregate_fate_probabilities
----> 2 from cellrank.pl._circular_projection import circular_projection
      3 from cellrank.pl._cluster_trends import cluster_trends
      4 from cellrank.pl._gene_trend import gene_trends

File ~/miniforge3/envs/scFates/lib/python3.11/site-packages/cellrank/pl/_circular_projection.py:17
     14 from matplotlib.colors import LinearSegmentedColormap, LogNorm
     16 from anndata import AnnData
---> 17 from scanpy._utils import deprecated_arg_names
     19 from cellrank import logging as logg
     20 from cellrank._utils import Lineage

ImportError: cannot import name 'deprecated_arg_names' from 'scanpy._utils' (/Users/administrateur/miniforge3/envs/scFates/lib/python3.11/site-packages/scanpy/_utils/__init__.py)
```
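When an import fails like this, it can help to record the exact installed versions and to check whether the name CellRank imports still exists in the installed scanpy. A small standard-library sketch; the helper names `installed_versions` and `module_has_attribute` are hypothetical, and only `importlib` and `importlib.metadata` are used:

```python
import importlib
from importlib import metadata
from typing import Dict, Iterable, Optional


def installed_versions(packages: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if missing."""
    versions: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def module_has_attribute(module_name: str, attribute: str) -> bool:
    """True if `module_name` imports cleanly and exposes `attribute`."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:  # also covers ModuleNotFoundError
        return False
    return hasattr(module, attribute)
```

For example, `installed_versions(["cellrank", "scanpy", "anndata"])` lists the versions to paste into a bug report, and `module_has_attribute("scanpy._utils", "deprecated_arg_names")` returning False in an environment like the one above would be consistent with the traceback: the installed scanpy no longer provides the helper that this CellRank release imports.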
jwalewski commented 7 months ago

Seems like we are running into the issue described in https://github.com/theislab/cellrank/issues/1183.

Meanwhile, the "lineages" parameter question remains open.

michalk8 commented 7 months ago

Thanks a lot for spotting this, @jwalewski; I will take a look.

jwalewski commented 7 months ago

Hey, checking in on this again. After updating everything to the newest CellRank version, the driver genes can now be written to a CSV. However, the gene trend plots still fail with the following error:

Traceback (most recent call last):
  File "<string>", line 91, in <module>
  File "<string>", line 87, in main
  File "<string>", line 78, in combined_kernel_analysis
  File "/home/jw2894/.conda/envs/CellRank240328/lib/python3.11/site-packages/cellrank/_utils/_utils.py", line 1586, in _genesymbols
    return wrapped(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jw2894/.conda/envs/CellRank240328/lib/python3.11/site-packages/cellrank/pl/_gene_trend.py", line 184, in gene_trends
    probs = Lineage.from_adata(adata, backward=backward)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jw2894/.conda/envs/CellRank240328/lib/python3.11/site-packages/cellrank/_utils/_lineage.py", line 850, in from_adata
    raise KeyError(f"Unable to find lineage data in `adata.obsm[{key!r}]`.")
KeyError: "Unable to find lineage data in `adata.obsm['lineages_fwd']`."

I've tried hunting down this bug too, and I think it ultimately comes down to the _write_fate_probabilities function.

Here's the code:

import sys
sys.path.append("/home/jw2894/.conda/envs/CellRank240328/bin/")
import numpy as np
import os
from scipy import io
from scipy.sparse import coo_matrix, csr_matrix
import pandas as pd
import cellrank as cr
from cellrank.estimators import GPCCA
import scvelo as scv
import scanpy as sc
import warnings
import re
import gdown
import anndata as ad
from anndata.experimental.multi_files import AnnCollection
from sklearn.model_selection import train_test_split
from sklearn.decomposition import IncrementalPCA
#For palantir
import matplotlib
import matplotlib.pyplot as plt

stochastic_or_dynamical="dynamical"
date="24_03_28"
outputfilepath="/gpfs/gibbs/pi/hafler/jw2894/Projects_Post_24_2_20/CellRank/Data_Output/"
plotoutputfilepath=outputfilepath+"Plots/"
csvoutputfilepath=outputfilepath+"CSVs/"
objectoutputfilepath=outputfilepath+"Objects/"

import warnings
warnings.simplefilter("ignore", category=UserWarning)
cr.logging.print_versions() #comment on and off as necessary

def combined_kernel_analysis(combined_kernel, adata, RNA_Velocity_Fraction, Connectivity_Fraction, DPT_Pseudotime_Fraction):
    #Initialize
    current_model = GPCCA(combined_kernel)
    print(current_model)

    #predict and plot all states
    current_model.fit(n_states=10, cluster_key="Cell_type")
    current_model.plot_macrostates(which="all", save=plotoutputfilepath+"_All_States_of_"+str(RNA_Velocity_Fraction*100)+"% RNA Velocity"+str(Connectivity_Fraction*100)+"% Connectivity"+ str(DPT_Pseudotime_Fraction*100)+"% DPT Psuedotime"+date+stochastic_or_dynamical+".png") #Have to check if theres a save method
    #plot terminal states
    print("about to predict terminal states")
    print(adata)
    current_model.predict_terminal_states(method="top_n", n_states=6)
    current_model.set_terminal_states(["T05.CD4_CD40LG_1", "T05.CD4_CD40LG_2", "T12.CD4_CXCL13_1", "T12.CD4_CXCL13_2", "T06.CD4_Treg_2", "T06.CD4_Treg_3"])
    print("just predicted terminal states")
    print(adata)
    current_model.plot_macrostates(which="terminal",  save=plotoutputfilepath+"_Terminal_States_of_"+str(RNA_Velocity_Fraction*100)+"% RNA Velocity"+str(Connectivity_Fraction*100)+"% Connectivity"+ str(DPT_Pseudotime_Fraction*100)+"% DPT Psuedotime"+date+stochastic_or_dynamical+".png") #have to see if theres a save method

    #compute & plot fate probabilities
    current_model.compute_fate_probabilities()
    current_model.plot_fate_probabilities(legend_loc="right",  save=plotoutputfilepath+"_Fate_Probabilities_of_"+str(RNA_Velocity_Fraction*100)+"% RNA Velocity"+str(Connectivity_Fraction*100)+"% Connectivity"+ str(DPT_Pseudotime_Fraction*100)+"% DPT Psuedotime"+date+stochastic_or_dynamical+".png")
    #cr.pl.circular_projection(adata, keys="Cell_type", legend_loc="right", save=plotoutputfilepath+"_Circular_Projection_"+str(RNA_Velocity_Fraction*100)+"% RNA Velocity"+str(Connectivity_Fraction*100)+"% Connectivity"+ str(DPT_Pseudotime_Fraction*100)+"% DPT Psuedotime"+date+stochastic_or_dynamical+".png") #Will see how to fix cricular projection later
    #mono_drivers = current_model.compute_lineage_drivers(lineages="Mono_1_1")
    #^Convert this to for loop for each cluster
    lineage_array=[]
    print(current_model) #learn what the structure of this object is
    for lineage in current_model.terminal_states:#current_model.macrostates: #I think it's this? #kernel, adata, or whatever it is
        #drivers_per_cluster.to_csv() #may have to cast pending on type
        if not pd.isna(lineage):
            if lineage not in lineage_array:
                lineage_array.append(lineage)
    #print("This is lineage_array: ", lineage_array)
    #print("This is the type of lineage_array: ", type(lineage_array))
    #print("This is current_model.terminal_states ", current_model.terminal_states, " and here is its type: ", type(current_model.terminal_states))
    #print("This is current_model.terminal_states.categories ", current_model.terminal_states.categories, " and here is its type: ", type(current_model.terminal_states.categories))
    drivers_per_cluster= current_model.compute_lineage_drivers(lineages=["T05.CD4_CD40LG_1", "T05.CD4_CD40LG_2", "T12.CD4_CXCL13_1", "T12.CD4_CXCL13_2", "T06.CD4_Treg_2", "T06.CD4_Treg_3"])
    print(type(drivers_per_cluster)) #<class pandas.core.frame.DataFrame>
    drivers_per_cluster.to_csv(csvoutputfilepath+date+"_Driver_Genes_"+str(RNA_Velocity_Fraction*100)+"% RNA Velocity"+str(Connectivity_Fraction*100)+"% Connectivity"+ str(DPT_Pseudotime_Fraction*100)+"% DPT Psuedotime"+date+stochastic_or_dynamical+".csv")
    #Visualize expression trends
    model = cr.models.GAM(adata) #GAM instead of GAMR to stay in Python
    gene_array = ["CXCL13", "ANXA1", "CD14", "KLF2", "IL7R"] #"CD34" not found
    pseudotime_type="dpt_pseudotime"
    print("this is adata.obsm: ", adata.obsm)
    #print("this is model.obsm: ", model.obsm)
    cr.pl.gene_trends( adata, model=model, data_key="X", genes=gene_array, same_plot=True, ncols=2, time_key=pseudotime_type, hide_cells=True, save=plotoutputfilepath+"_Gene_Trends_"+str(RNA_Velocity_Fraction*100)+"% RNA Velocity"+str(Connectivity_Fraction*100)+"% Connectivity"+ str(DPT_Pseudotime_Fraction*100)+"% DPT Psuedotime"+date+stochastic_or_dynamical+".png") #have to see what to replace magic imputed data with
    #cr.pl.heatmap( adata, model=model, data_key="X", genes=gene_array, lineages=lineage_array, time_key=pseudotime_type, cbar=False, show_all_genes=True, save=plotoutputfilepath+"_Gene_Heatmap_"+str(RNA_Velocity_Fraction*100)+"% RNA Velocity"+str(Connectivity_Fraction*100)+"% Connectivity"+ str(DPT_Pseudotime_Fraction*100)+"% DPT Psuedotime"+date+stochastic_or_dynamical+".png") #also, I will have to see if I have to write a save function for these too

def main(argument_list):
    print("This gets printed before anything happens in Python")
    print("Combined Kernel analysis starting")
    whole_object = scv.read(objectoutputfilepath+"TEST"+date+".h5ad", compression="gzip", cache=True)
    current_combined_kernel = cr.kernels.PrecomputedKernel.read(objectoutputfilepath+"Kernels/"+"TEST"+date+"Combined_Kernel.h5ad")
    combined_kernel_analysis(current_combined_kernel, whole_object, .45, .1, .45)
    whole_object.write(objectoutputfilepath+"TEST"+date+".h5ad", compression="gzip")
    print("The program executed successfully!")

main(sys.argv)

And the _write_fate_probabilities (with print statements added):

    @logger
    @shadow
    def _write_fate_probabilities(
        self: FateProbsProtocol,
        fate_probs: Optional[Lineage],
        params: Mapping[str, Any] = types.MappingProxyType({}),
    ) -> str:
        # fmt: off
        key = Key.obsm.fate_probs(self.backward)
        print("write fate probabilites. Key is: ", key)
        self._set("_fate_probabilities", self.adata.obsm, key=key, value=fate_probs)
        self._write_lineage_priming(None, log=False)
        print("self.adata.obsm is (after setting fate probabilites): ", self.adata.obsm)
        self.params[key] = dict(params)
        # fmt: on
        print("line 517 reached")
        return "This is a test return string"
        #return (f"Adding `adata.obsm[{key!r}]`\n       `.fate_probabilities`\n    Finish")
        #return f"Adding `adata.obsm[{key!r}]`\n" f"       `.fate_probabilities`\n" f"    Finish" #apparently this line is invalid python syntax, which is not surprising

Here's the output:

cellrank==2.0.3 scanpy==1.10.0 anndata==0.10.6 numpy==1.26.4 numba==0.59.1 scipy==1.12.0 pandas==2.2.1 pygpcca==1.0.4 scikit-learn==1.1.3 statsmodels==0.14.1 scvelo==0.0.0 pygam==0.9.1 matplotlib==3.8.3 seaborn==0.13.2
This gets printed before anything happens in Python
Combined Kernel analysis starting
GPCCA[kernel=(0.45 * VelocityKernel[n=22062] + 0.1 * ConnectivityKernel[n=22062] + 0.45 * PseudotimeKernel[n=22062]), initial_states=None, terminal_states=None]
write fate probabilites. Key is:  lineages_fwd
self.adata.obsm is (after setting fate probabilites):  AxisArrays with keys: X_harmony, X_pca, X_umap, X_diffmap, T_fwd_umap, schur_vectors_fwd
line 517 reached
write fate probabilites. Key is:  lineages_fwd
self.adata.obsm is (after setting fate probabilites):  AxisArrays with keys: schur_vectors_fwd
line 517 reached
write fate probabilites. Key is:  lineages_fwd
self.adata.obsm is (after setting fate probabilites):  AxisArrays with keys: schur_vectors_fwd
line 517 reached
write fate probabilites. Key is:  lineages_fwd
self.adata.obsm is (after setting fate probabilites):  AxisArrays with keys: schur_vectors_fwd
line 517 reached
write fate probabilites. Key is:  lineages_fwd
self.adata.obsm is (after setting fate probabilites):  AxisArrays with keys: schur_vectors_fwd
line 517 reached
write fate probabilites. Key is:  lineages_fwd
self.adata.obsm is (after setting fate probabilites):  AxisArrays with keys: schur_vectors_fwd
line 517 reached
write fate probabilites. Key is:  lineages_fwd
self.adata.obsm is (after setting fate probabilites):  AxisArrays with keys: schur_vectors_fwd
line 517 reached
write fate probabilites. Key is:  lineages_fwd
self.adata.obsm is (after setting fate probabilites):  AxisArrays with keys: schur_vectors_fwd
line 517 reached
write fate probabilites. Key is:  lineages_fwd
self.adata.obsm is (after setting fate probabilites):  AxisArrays with keys: X_harmony, X_pca, X_umap, X_diffmap, T_fwd_umap, schur_vectors_fwd
line 517 reached
write fate probabilites. Key is:  lineages_fwd
self.adata.obsm is (after setting fate probabilites):  AxisArrays with keys: schur_vectors_fwd
line 517 reached
write fate probabilites. Key is:  lineages_fwd
self.adata.obsm is (after setting fate probabilites):  AxisArrays with keys: schur_vectors_fwd
line 517 reached
write fate probabilites. Key is:  lineages_fwd
self.adata.obsm is (after setting fate probabilites):  AxisArrays with keys: schur_vectors_fwd
line 517 reached
write fate probabilites. Key is:  lineages_fwd
self.adata.obsm is (after setting fate probabilites):  AxisArrays with keys: schur_vectors_fwd
line 517 reached
write fate probabilites. Key is:  lineages_fwd
self.adata.obsm is (after setting fate probabilites):  AxisArrays with keys: schur_vectors_fwd
line 517 reached
write fate probabilites. Key is:  lineages_fwd
self.adata.obsm is (after setting fate probabilites):  AxisArrays with keys: schur_vectors_fwd
line 517 reached
write fate probabilites. Key is:  lineages_fwd
self.adata.obsm is (after setting fate probabilites):  AxisArrays with keys: schur_vectors_fwd
line 517 reached
saving figure to file /gpfs/gibbs/pi/hafler/jw2894/Projects_Post_24_2_20/CellRank/Data_Output/Plots/_All_States_of_45.0% RNA Velocity10.0% Connectivity45.0% DPT Psuedotime24_03_28dynamical.png
about to predict terminal states
AnnData object with n_obs × n_vars = 22062 × 2000
    obs: 'Patient_ID', 'initial_size_unspliced', 'initial_size_spliced', 'initial_size', 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'seq_folder', 'nUMI', 'nGene', 'barcode', 'ID', 'cdr3_TRA', 'cdr3_nt_TRA', 'cdr3_TRB', 'cdr3_nt_TRB', 'Chains', 'clonotype', 'n.exp.hkgenes', 'Patient', 'Tissue', 'Type', 'OR_date', 'log10GenesPerUMI', 'mitoRatio', 'TCR', 'Processing', 'batch_10X', 'batch_seq', 'RNA_snn_res.0.5', 'seurat_clusters', 'TCR_genes1', 'RNA_snn_res.0.8', 'RNA_snn_res.1', 'RNA_snn_res.1.2', 'RNA_snn_res.2', 'Lineage', 'Cell_type'
    obsm: 'X_harmony', 'X_pca', 'X_umap'
    layers: 'ambiguous', 'matrix', 'spliced', 'unspliced'
write fate probabilites. Key is:  lineages_fwd
self.adata.obsm is (after setting fate probabilites):  AxisArrays with keys: X_harmony, X_pca, X_umap, X_diffmap, T_fwd_umap, schur_vectors_fwd, macrostates_fwd_memberships
line 517 reached
write fate probabilites. Key is:  lineages_fwd
self.adata.obsm is (after setting fate probabilites):  AxisArrays with keys: schur_vectors_fwd, macrostates_fwd_memberships
line 517 reached
write fate probabilites. Key is:  lineages_fwd
self.adata.obsm is (after setting fate probabilites):  AxisArrays with keys: schur_vectors_fwd, macrostates_fwd_memberships
line 517 reached
write fate probabilites. Key is:  lineages_fwd
self.adata.obsm is (after setting fate probabilites):  AxisArrays with keys: schur_vectors_fwd, macrostates_fwd_memberships
line 517 reached
write fate probabilites. Key is:  lineages_fwd
self.adata.obsm is (after setting fate probabilites):  AxisArrays with keys: schur_vectors_fwd, macrostates_fwd_memberships
line 517 reached
write fate probabilites. Key is:  lineages_fwd
self.adata.obsm is (after setting fate probabilites):  AxisArrays with keys: schur_vectors_fwd, macrostates_fwd_memberships
line 517 reached
write fate probabilites. Key is:  lineages_fwd
self.adata.obsm is (after setting fate probabilites):  AxisArrays with keys: schur_vectors_fwd, macrostates_fwd_memberships
line 517 reached
write fate probabilites. Key is:  lineages_fwd
self.adata.obsm is (after setting fate probabilites):  AxisArrays with keys: schur_vectors_fwd, macrostates_fwd_memberships
line 517 reached
Self.set_terminal_states is:  GPCCA[kernel=(0.45 * VelocityKernel[n=22062] + 0.1 * ConnectivityKernel[n=22062] + 0.45 * PseudotimeKernel[n=22062]), initial_states=None, terminal_states=['T05.CD4_CD40LG_1', 'T05.CD4_CD40LG_2', 'T06.CD4_Treg_2', 'T06.CD4_Treg_3', 'T12.CD4_CXCL13_1', 'T12.CD4_CXCL13_2']]
write fate probabilites. Key is:  lineages_fwd
self.adata.obsm is (after setting fate probabilites):  AxisArrays with keys: X_harmony, X_pca, X_umap, X_diffmap, T_fwd_umap, schur_vectors_fwd, macrostates_fwd_memberships, term_states_fwd_memberships
line 517 reached
write fate probabilites. Key is:  lineages_fwd
self.adata.obsm is (after setting fate probabilites):  AxisArrays with keys: schur_vectors_fwd, macrostates_fwd_memberships, term_states_fwd_memberships
line 517 reached
write fate probabilites. Key is:  lineages_fwd
self.adata.obsm is (after setting fate probabilites):  AxisArrays with keys: schur_vectors_fwd, macrostates_fwd_memberships, term_states_fwd_memberships
line 517 reached
write fate probabilites. Key is:  lineages_fwd
self.adata.obsm is (after setting fate probabilites):  AxisArrays with keys: schur_vectors_fwd, macrostates_fwd_memberships, term_states_fwd_memberships
line 517 reached
write fate probabilites. Key is:  lineages_fwd
self.adata.obsm is (after setting fate probabilites):  AxisArrays with keys: schur_vectors_fwd, macrostates_fwd_memberships, term_states_fwd_memberships
line 517 reached
write fate probabilites. Key is:  lineages_fwd
self.adata.obsm is (after setting fate probabilites):  AxisArrays with keys: schur_vectors_fwd, macrostates_fwd_memberships, term_states_fwd_memberships
line 517 reached
write fate probabilites. Key is:  lineages_fwd
self.adata.obsm is (after setting fate probabilites):  AxisArrays with keys: schur_vectors_fwd, macrostates_fwd_memberships, term_states_fwd_memberships
line 517 reached
write fate probabilites. Key is:  lineages_fwd
self.adata.obsm is (after setting fate probabilites):  AxisArrays with keys: schur_vectors_fwd, macrostates_fwd_memberships, term_states_fwd_memberships
line 517 reached
write fate probabilites. Key is:  lineages_fwd
self.adata.obsm is (after setting fate probabilites):  AxisArrays with keys: X_harmony, X_pca, X_umap, X_diffmap, T_fwd_umap, schur_vectors_fwd, macrostates_fwd_memberships, term_states_fwd_memberships
line 517 reached
write fate probabilites. Key is:  lineages_fwd
self.adata.obsm is (after setting fate probabilites):  AxisArrays with keys: schur_vectors_fwd, macrostates_fwd_memberships, term_states_fwd_memberships
line 517 reached
write fate probabilites. Key is:  lineages_fwd
self.adata.obsm is (after setting fate probabilites):  AxisArrays with keys: schur_vectors_fwd, macrostates_fwd_memberships, term_states_fwd_memberships
line 517 reached
write fate probabilites. Key is:  lineages_fwd
self.adata.obsm is (after setting fate probabilites):  AxisArrays with keys: schur_vectors_fwd, macrostates_fwd_memberships, term_states_fwd_memberships
line 517 reached
write fate probabilites. Key is:  lineages_fwd
self.adata.obsm is (after setting fate probabilites):  AxisArrays with keys: schur_vectors_fwd, macrostates_fwd_memberships, term_states_fwd_memberships
line 517 reached
write fate probabilites. Key is:  lineages_fwd
self.adata.obsm is (after setting fate probabilites):  AxisArrays with keys: schur_vectors_fwd, macrostates_fwd_memberships, term_states_fwd_memberships
line 517 reached
write fate probabilites. Key is:  lineages_fwd
self.adata.obsm is (after setting fate probabilites):  AxisArrays with keys: schur_vectors_fwd, macrostates_fwd_memberships, term_states_fwd_memberships
line 517 reached
write fate probabilites. Key is:  lineages_fwd
self.adata.obsm is (after setting fate probabilites):  AxisArrays with keys: schur_vectors_fwd, macrostates_fwd_memberships, term_states_fwd_memberships
line 517 reached
just predicted terminal states
AnnData object with n_obs × n_vars = 22062 × 2000
    obs: 'Patient_ID', 'initial_size_unspliced', 'initial_size_spliced', 'initial_size', 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'seq_folder', 'nUMI', 'nGene', 'barcode', 'ID', 'cdr3_TRA', 'cdr3_nt_TRA', 'cdr3_TRB', 'cdr3_nt_TRB', 'Chains', 'clonotype', 'n.exp.hkgenes', 'Patient', 'Tissue', 'Type', 'OR_date', 'log10GenesPerUMI', 'mitoRatio', 'TCR', 'Processing', 'batch_10X', 'batch_seq', 'RNA_snn_res.0.5', 'seurat_clusters', 'TCR_genes1', 'RNA_snn_res.0.8', 'RNA_snn_res.1', 'RNA_snn_res.1.2', 'RNA_snn_res.2', 'Lineage', 'Cell_type'
    obsm: 'X_harmony', 'X_pca', 'X_umap'
    layers: 'ambiguous', 'matrix', 'spliced', 'unspliced'
saving figure to file /gpfs/gibbs/pi/hafler/jw2894/Projects_Post_24_2_20/CellRank/Data_Output/Plots/_Terminal_States_of_45.0% RNA Velocity10.0% Connectivity45.0% DPT Psuedotime24_03_28dynamical.png

  0%|          | 0/6 [00:00<?, ?/s]
 33%|███▎      | 2/6 [00:00<00:00,  9.04/s]
100%|██████████| 6/6 [00:00<00:00, 17.41/s]
100%|██████████| 6/6 [00:00<00:00, 15.92/s]
Traceback (most recent call last):
  File "<string>", line 91, in <module>
  File "<string>", line 87, in main
  File "<string>", line 78, in combined_kernel_analysis
  File "/home/jw2894/.conda/envs/CellRank240328/lib/python3.11/site-packages/cellrank/_utils/_utils.py", line 1586, in _genesymbols
    return wrapped(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jw2894/.conda/envs/CellRank240328/lib/python3.11/site-packages/cellrank/pl/_gene_trend.py", line 184, in gene_trends
    probs = Lineage.from_adata(adata, backward=backward)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jw2894/.conda/envs/CellRank240328/lib/python3.11/site-packages/cellrank/_utils/_lineage.py", line 850, in from_adata
    raise KeyError(f"Unable to find lineage data in `adata.obsm[{key!r}]`.")
KeyError: "Unable to find lineage data in `adata.obsm['lineages_fwd']`."
Line 238 is run
write fate probabilites. Key is:  lineages_fwd
self.adata.obsm is (after setting fate probabilites):  AxisArrays with keys: X_harmony, X_pca, X_umap, X_diffmap, T_fwd_umap, schur_vectors_fwd, macrostates_fwd_memberships, term_states_fwd_memberships, lineages_fwd
line 517 reached
write fate probabilites. Key is:  lineages_fwd
self.adata.obsm is (after setting fate probabilites):  AxisArrays with keys: schur_vectors_fwd, macrostates_fwd_memberships, term_states_fwd_memberships, lineages_fwd
line 517 reached
write fate probabilites. Key is:  lineages_fwd
self.adata.obsm is (after setting fate probabilites):  AxisArrays with keys: schur_vectors_fwd, macrostates_fwd_memberships, term_states_fwd_memberships, lineages_fwd
line 517 reached
write fate probabilites. Key is:  lineages_fwd
self.adata.obsm is (after setting fate probabilites):  AxisArrays with keys: schur_vectors_fwd, macrostates_fwd_memberships, term_states_fwd_memberships, lineages_fwd
line 517 reached
saving figure to file /gpfs/gibbs/pi/hafler/jw2894/Projects_Post_24_2_20/CellRank/Data_Output/Plots/_Fate_Probabilities_of_45.0% RNA Velocity10.0% Connectivity45.0% DPT Psuedotime24_03_28dynamical.png
GPCCA[kernel=(0.45 * VelocityKernel[n=22062] + 0.1 * ConnectivityKernel[n=22062] + 0.45 * PseudotimeKernel[n=22062]), initial_states=None, terminal_states=['T05.CD4_CD40LG_1', 'T05.CD4_CD40LG_2', 'T06.CD4_Treg_2', 'T06.CD4_Treg_3', 'T12.CD4_CXCL13_1', 'T12.CD4_CXCL13_2']]
line 166 reached
<class 'pandas.core.frame.DataFrame'>
this is adata.obsm:  AxisArrays with keys: X_harmony, X_pca, X_umap
michalk8 commented 7 months ago

@jwalewski I assume this is the line that fails:

drivers_per_cluster = current_model.compute_lineage_drivers(lineages=current_model.terminal_states.categories)

The Lineage object only knows how to handle basic index types such as int/str/list/tuple/numpy.ndarray, but current_model.terminal_states.categories is a pandas.Index. Converting it to, e.g., a list will work: lineages=list(current_model.terminal_states.categories).
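A minimal sketch of the list conversion described above. It does not touch CellRank. A pandas Categorical stands in for the estimator's terminal states, built from some of the terminal-state names in the GPCCA summary in the log. Its .categories attribute is a pandas.Index, the same kind of object that triggered the IndexError. The compute_lineage_drivers call at the end is left as a comment because it assumes a fitted estimator named current_model, as in the thread.

```python
import pandas as pd

# Stand-in for the estimator's terminal states (hypothetical data; names taken
# from the GPCCA summary in the log). Missing values (None) are not categories.
terminal_states = pd.Categorical(
    ["T05.CD4_CD40LG_1", "T06.CD4_Treg_2", None, "T12.CD4_CXCL13_1"]
)

categories = terminal_states.categories  # a pandas.Index, not a list
assert isinstance(categories, pd.Index)

# Convert to a plain Python list of str before passing it as `lineages`.
lineages = list(categories)  # equivalently: categories.tolist()
print(lineages)  # ['T05.CD4_CD40LG_1', 'T06.CD4_Treg_2', 'T12.CD4_CXCL13_1']

# With a fitted estimator (assumed to be `current_model`), this would become:
# drivers = current_model.compute_lineage_drivers(lineages=lineages)
```

Inferred categories are sorted lexically, which is why the printed order does not depend on the order of the input values.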
jwalewski commented 7 months ago

Okay, this works now, but I get a similar error from gene_trends:

Traceback (most recent call last):
  File "<string>", line 487, in <module>
  File "<string>", line 483, in main
  File "<string>", line 424, in combined_kernel_analysis
  File "/home/jw2894/.conda/envs/CellRank240328/lib/python3.11/site-packages/cellrank/_utils/_utils.py", line 1586, in _genesymbols
    return wrapped(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jw2894/.conda/envs/CellRank240328/lib/python3.11/site-packages/cellrank/pl/_gene_trend.py", line 184, in gene_trends
    probs = Lineage.from_adata(adata, backward=backward)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jw2894/.conda/envs/CellRank240328/lib/python3.11/site-packages/cellrank/_utils/_lineage.py", line 850, in from_adata
    raise KeyError(f"Unable to find lineage data in `adata.obsm[{key!r}]`.")
KeyError: "Unable to find lineage data in `adata.obsm['lineages_fwd']`."

This happens even though my debug output shows the key is present:

self.adata.obsm is (after setting fate probabilites):  AxisArrays with keys: schur_vectors_fwd, macrostates_fwd_memberships, term_states_fwd_memberships, lineages_fwd
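The logs show two different key sets. The estimator's own AnnData (self.adata.obsm) contains lineages_fwd. The adata.obsm printed earlier in the log lists only X_harmony, X_pca and X_umap. A minimal diagnostic sketch follows. It assumes that the AnnData passed to gene_trends may be a different object from the estimator's current_model.adata. The helper missing_obsm_keys is hypothetical, and plain dicts built from the logged key names stand in for the two .obsm containers.

```python
from collections.abc import Mapping


def missing_obsm_keys(obsm: Mapping, required=("lineages_fwd",)) -> list:
    """Return the required keys that are absent from an .obsm-like mapping (hypothetical helper)."""
    return [key for key in required if key not in obsm]


# Key sets copied from the debug output; the values are placeholders.
estimator_obsm = dict.fromkeys(
    ["schur_vectors_fwd", "macrostates_fwd_memberships",
     "term_states_fwd_memberships", "lineages_fwd"]
)
plotting_obsm = dict.fromkeys(["X_harmony", "X_pca", "X_umap"])

print(missing_obsm_keys(estimator_obsm))  # []
print(missing_obsm_keys(plotting_obsm))   # ['lineages_fwd']
```

In a real session the same check could be run on adata.obsm and current_model.adata.obsm just before plotting. If only the former lacks lineages_fwd, the plotting call is probably receiving a different AnnData object (for example a copy), and passing the estimator's own AnnData is worth trying.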

Should I keep discussing this here, or create a new issue?