KrishnaswamyLab / scprep

A collection of scripts and tools for loading, processing, and handling single cell data.
MIT License
72 stars 19 forks source link

Helper function to run scanpy louvain / leiden and produce a desired number of clusters #102

Status: Open — opened by dburkhardt 4 years ago

dburkhardt commented 4 years ago

Is your feature request related to a problem? Please describe. I want to use Louvain and Leiden clustering to produce a specific number of clusters.

Describe the solution you'd like

def find_n_clusters(adata, n_clusters, method='louvain', r_start=0.01, r_stop=10, tol=0.001):
    '''
    Run scanpy Louvain or Leiden clustering, bisecting on the resolution
    parameter until the requested number of clusters is obtained.

    Requires scanpy to be installed (referenced here as the global ``sc``).

    The search assumes the number of clusters grows with resolution, so that
    ``r_start`` yields too few (or exactly enough) clusters and ``r_stop`` too
    many (or exactly enough). NOTE(review): louvain/leiden are not strictly
    monotonic in resolution, so bisection is a heuristic.

    Parameters
    ----------
    adata : AnnData
        Passed straight to ``sc.tl.louvain`` / ``sc.tl.leiden``. Each trial
        overwrites ``adata.obs[method]`` in place, so on return that column
        holds the clustering whose labels are returned.
    n_clusters : int
        Desired number of distinct cluster labels.
    method : {'louvain', 'leiden'}, default 'louvain'
        Clustering algorithm. This is also the ``adata.obs`` column that the
        labels are read from.
    r_start, r_stop : float
        Lower and upper bounds of the resolution search interval.
    tol : float, default 0.001
        The search stops once the interval width drops below this value.

    Returns
    -------
    pandas.Series
        ``adata.obs[method]`` cast to ``int``.

    Raises
    ------
    ValueError
        If ``method`` is unknown, if ``r_start`` already yields more than
        ``n_clusters`` clusters, or if ``r_stop`` yields fewer.

    Warns
    -----
    UserWarning
        If the interval shrinks below ``tol`` without an exact match. In that
        case the clustering at the final lower bound is returned, as before.
    '''
    import warnings

    if method == 'louvain':
        cluster_func = sc.tl.louvain
    elif method == 'leiden':
        cluster_func = sc.tl.leiden
    else:
        raise ValueError('No such clustering method: {}'.format(method))

    def _n_clusters_at(resolution):
        # Cluster in place, then count the distinct labels written to adata.obs[method].
        cluster_func(adata, resolution=resolution)
        return len(np.unique(adata.obs[method]))

    def _labels():
        # Labels of the most recent clustering run.
        return adata.obs[method].astype(int)

    # Degenerate interval: nothing to search, return the clustering at r_start.
    if r_stop - r_start < tol:
        cluster_func(adata, resolution=r_start)
        return _labels()

    # Check the lower bound.
    n_start = _n_clusters_at(r_start)
    if n_start == n_clusters:
        return _labels()
    if n_start > n_clusters:
        raise ValueError('r_start is too large. Got: {}'.format(r_start))

    # Check the upper bound.
    n_end = _n_clusters_at(r_stop)
    if n_end == n_clusters:
        return _labels()
    if n_end < n_clusters:
        raise ValueError('r_stop is too small. Got: {}'.format(r_stop))

    # Iterative bisection. Invariant: `lo` gives too few clusters and `hi`
    # too many. Each step runs exactly one clustering, and `tol` stays in
    # effect for the whole search. (The earlier recursive version reset tol
    # to its default and re-clustered both endpoints at every level.)
    lo, hi = r_start, r_stop
    while hi - lo >= tol:
        mid = lo + (hi - lo) / 2
        n_mid = _n_clusters_at(mid)
        if n_mid == n_clusters:
            return _labels()
        if n_mid < n_clusters:
            lo = mid
        else:
            hi = mid

    # No exact match within tol. Fall back to the lower bound, as the original
    # recursion did, but tell the caller instead of failing silently.
    n_final = _n_clusters_at(lo)
    warnings.warn(
        'Could not reach exactly {} clusters within tol={}; returning {} '
        'clusters at resolution={}'.format(n_clusters, tol, n_final, lo)
    )
    return _labels()