Update gene set enrichment functionality

adamklie opened 3 weeks ago

New prerank implementation:

def perform_prerank(
    loadings: pd.DataFrame,
    geneset: Union[List[str], str, Dict[str, str]],
    n_jobs: int = 1,
    low_cutoff: float = -np.inf,
    n_top: Optional[int] = None,
) -> pd.DataFrame:
    """Run GSEA prerank on each gene program in the loadings matrix.

    Parameters
    ----------
    loadings : pd.DataFrame
        DataFrame of feature loadings for each program.
    geneset : str, list of str, or dict
        Gene set(s) to run GSEA on, given as a single name, a list, or a dictionary of gene sets.
    n_jobs : int
        Number of parallel jobs to run.
    low_cutoff : float
        Remove features with loadings at or below this value.
    n_top : int
        Take the top n features with the highest loadings.

    Returns
    -------
    pd.DataFrame
        DataFrame of GSEA results sorted by program name and FDR q-value. Includes the following columns:
        - program_name: name of the program
        - Term: gene set name
        - ES: enrichment score
        - NES: normalized enrichment score
        - NOM p-val: nominal p-value (from the null distribution of the gene set)
        - FDR q-val: adjusted False Discovery Rate
        - FWER p-val: Family wise error rate p-values
        - Gene %: percent of gene set before running enrichment peak (ES)
        - Lead_genes: leading edge genes (gene hits before running enrichment peak)
        - tag_before: number of genes in gene set
        - tag_after: number of genes matched to the data

    # Run GSEA prerank for each column of loadings (each cell program)
    pre_res = pd.DataFrame()
    for i in tqdm(loadings.columns, desc='Running GSEA', unit='programs'):

        # Filter out low loadings
        temp_loadings = loadings[i][(loadings[i] > low_cutoff)]

        # Take top n features if specified
        if n_top is not None:
            temp_loadings = temp_loadings.sort_values(ascending=False).head(n_top)
            if len(temp_loadings) < n_top:
                logging.warning(f"Program {i} has less than {n_top} features after filtering. Only {len(temp_loadings)} features will be used.")

        # Run GSEA prerank
        temp_res = gp.prerank(

        # Post-process results
        temp_res['Gene %'] = temp_res['Gene %'].apply(lambda x: float(x[:-1]))
        temp_res['tag_before'] = temp_res['Tag %'].apply(lambda x: int(x.split('/')[0]))
        temp_res['tag_after'] = temp_res['Tag %'].apply(lambda x: int(x.split('/')[1]))
        temp_res.drop(columns=['Tag %'], inplace=True)
        if 'Name' in temp_res.columns and temp_res['Name'][0] == "prerank":
            temp_res['Name'] = i
        temp_res.rename(columns={'Name': 'program_name'}, inplace=True)
        temp_res = temp_res.sort_values(['program_name', 'FDR q-val'])
        pre_res = pd.concat([pre_res, temp_res], ignore_index=True)

    return pre_res
New fisher implementation:

def perform_fisher_enrich(
    """Run Fisher enrichment on each gene program in the loadings matrix.

    Parameters
    ----------
    loadings : pd.DataFrame
        DataFrame of feature loadings for each program.
    geneset : dict
        Dictionary of gene sets.
    n_top : int
        Number of top features to take.

    ['Gene_set', 'Term', 'P-value', 'Adjusted P-value', 'Odds Ratio',
       'Combined Score', 'Genes', 'program_name', 'overlap_numerator',
        DataFrame of Fisher enrichment results sorted by program name and adjusted p-value. Includes the following columns:
        - program_name: name of the program
        - Term: gene set name
        - P-value: Fisher's exact test p-value
        - Adjusted P-value: adjusted p-value
        - Odds Ratio: odds ratio
        - Combined Score: combined score
        - Genes: genes in the gene set
        - overlap_numerator: number of overlapping genes
        - overlap_denominator: number of genes in the gene set
    TODO: parallelize across programs (the loop below currently processes programs one at a time).

    # Collect the set of all genes that appear in any gene set of the library (background genes)
    background_genes = set(value for sublist in geneset.values() for value in sublist)

    enr_res = pd.DataFrame()
    for i in tqdm(loadings.columns, desc='Running Fisher enrichment', unit='programs'):
        gene_list = list(loadings[i].sort_values(ascending=False).head(n_top).index)
        temp_res = gp.enrich(
        temp_res["program_name"] = i
        enr_res = pd.concat([enr_res, temp_res], ignore_index=True)
    enr_res['overlap_numerator'] = enr_res['Overlap'].apply(lambda x: int(x.split('/')[0]))
    enr_res['overlap_denominator'] = enr_res['Overlap'].apply(lambda x: int(x.split('/')[1]))
    enr_res.drop(columns=['Overlap', 'Gene_set'], inplace=True)
    enr_res = enr_res[["program_name"] + [col for col in enr_res.columns if col != "program_name"]]
    enr_res = enr_res.sort_values(['program_name', 'Adjusted P-value'])

    return enr_res
Proposed column renames:

# Harmonize the method-specific result columns to a shared schema
# (term, pval, adj_pval, enrichment, genes) so downstream code can treat
# Fisher and GSEA results uniformly. Columns not listed are left untouched.
if gene_set_enrichment_config["method"] == "fisher":
    # Fisher output: the odds ratio is used as the enrichment value.
    pre_res = pre_res.rename(
        columns={
            "Term": "term",
            "P-value": "pval",
            "Adjusted P-value": "adj_pval",
            "Odds Ratio": "enrichment",
            "Genes": "genes",
        }
    )
elif gene_set_enrichment_config["method"] == "gsea":
    # GSEA output: the normalized enrichment score (NES) is used as the
    # enrichment value and the leading-edge genes as the gene list.
    pre_res = pre_res.rename(
        columns={
            "Term": "term",
            "NOM p-val": "pval",
            "FDR q-val": "adj_pval",
            "NES": "enrichment",
            "Lead_genes": "genes",
        }
    )