omegahh / DeepHiC

A GAN-based method for Enhancing Hi-C data

How to detect TAD boundaries using the insulation score #12

Closed · caragraduate closed this 5 months ago

caragraduate commented 5 months ago

Hi there,

Thank you for this great tool! I have a quick question: the paper mentions that you applied an insulation score to detect TAD boundaries. Could you share exactly how to run it on the output of your model? I could not find the code or script on the GitHub page.

Thanks a lot for the help!

omegahh commented 5 months ago
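# TAD-boundary analysis: compute boundary-intensity and insulation-score
# tracks on Hi-C matrices, call boundaries at the zero crossings of the
# tracks' delta curves, and compare boundary sets between datasets.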
import os, sys
import bisect
import time, argparse
import multiprocessing
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

from scipy.stats import zscore
from scipy.sparse import coo_matrix

from all_configs import *
from utils.io import compactM

chr_number = lambda x: x[x.find('chr')+3:x.rfind('_')]  # chromosome id from a filename, e.g. 'chr1_10kb.npz' -> '1'

strip_nan = lambda x: x[~np.isnan(x)]  # drop NaN entries from a 1-D array

def zscore_hic(mat):
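    # z-score the nonzero contacts only; zeros stay zero via the sparse COO view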
    mat_coo = coo_matrix(mat)
    mat_data_zs = zscore(mat_coo.data)
    mat = coo_matrix((mat_data_zs, (mat_coo.row, mat_coo.col)), shape=mat.shape).toarray()
    return mat

def boundary_intensity(mat, chunk):
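    # boundary intensity at bin i: within-domain contacts of the chunk-sized
    # blocks up- and downstream of i (A, B) minus the contacts between the two
    # blocks (C); boundaries appear as local maxima of this track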
    h, w = mat.shape
    intensity = np.zeros(h) + np.nan
    for i in range(chunk, h-chunk):
        a = mat[i-chunk:i, i-chunk:i]
        b = mat[i:i+chunk, i:i+chunk]
        c = mat[i-chunk:i-1, i+1:i+chunk]
        A = np.sum(np.triu(a, 1))
        B = np.sum(np.triu(b, 1))
        C = np.sum(c)
        intensity[i] = A+B-C
    return intensity

def insulation_score(mat, offset, chunk):
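    # insulation score at bin i: sum over a chunk x chunk window placed
    # `offset` bins off the diagonal, i.e. contacts crossing bin i;
    # boundaries appear as local minima of this track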
    h, w = mat.shape
    scores = np.zeros(h) + np.nan
    for i in range(offset+chunk-1, h-offset-chunk):
        lb_i, lb_j = i-offset, i+offset # left bottom
        lt_i, lt_j = lb_i-chunk+1, lb_j # left top point (aka. startpoint)
        rb_i, rb_j = lt_i+chunk, lt_j+chunk
        scores[i] = np.sum(mat[lt_i:rb_i, lt_j:rb_j])
    return scores

def score_delta(scores, width):
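    # delta at bin i: mean score over the `width` bins to the right minus the
    # mean over the `width` bins to the left; zero crossings mark local extrema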
    delta = np.zeros_like(scores) + np.nan
    for i in range(width, len(scores)-width):
        delta[i] = np.mean(scores[i+1:i+width+1]) - np.mean(scores[i-width:i])  # `width` bins on each side
    return delta

def zero_points(arr, orient='descent'):
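    # sign changes of the delta curve: 'descent' (+ to -) picks local maxima of
    # the underlying score, 'ascent' (- to +) picks local minima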
    points = []
    if orient == 'descent':
        for i in range(len(arr)-1):
            if arr[i]>0 and arr[i+1]<0:
                zero_p = i if (np.abs(arr[i])<np.abs(arr[i+1])) else i+1
                points.append(zero_p)
    if orient == 'ascent':
        for i in range(len(arr)-1):
            if arr[i]<0 and arr[i+1]>0:
                zero_p = i if (np.abs(arr[i])<np.abs(arr[i+1])) else i+1
                points.append(zero_p)
    return np.array(points)

def min_dist(r1_ind, r2_ind):
    # distance from each point in r2_ind to its nearest point in r1_ind
    dists = [np.min(np.abs(r1_ind - i)) for i in r2_ind]
    return np.array(dists)

def count_overlap(base_ind, comp_ind):
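    # for each interval between consecutive comp boundaries: how many base
    # boundaries fall inside it, and the best Jaccard index against the
    # overlapping base intervals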
    result = []
    for i in range(len(comp_ind)-1):
        comp_left, comp_right = comp_ind[i], comp_ind[i+1]
        loc_left = bisect.bisect_left(base_ind, comp_left)
        loc_right = bisect.bisect_left(base_ind, comp_right)
        cross_num = loc_right - loc_left
        hit_end = loc_right == len(base_ind)  # comp_right lies beyond the last base boundary
        if hit_end: loc_right = loc_right - 1
        if loc_left > 0: loc_left = loc_left - 1
        max_jaccard = 0.
        for loc in range(loc_left, loc_right):
            base_left, base_right = base_ind[loc], base_ind[loc+1]
            intervals = sorted([comp_left, comp_right, base_left, base_right])
            jaccard_index = (intervals[2]-intervals[1])/(intervals[3]-intervals[0])
            max_jaccard = np.max([max_jaccard, jaccard_index])
        result.append([cross_num, max_jaccard])
        if hit_end: break
    return np.array(result)

def tad_caller(file, key, out_dir, bi_size=17, offset=2, is_size=10, delta_width=5, s=1400, e=1600):
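    # per-chromosome pipeline: load and normalize the matrix, compute both
    # score tracks and their deltas, call boundaries at the zero crossings,
    # then plot the [s, e) region and save all tracks and boundaries to disk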
    chrn = chr_number(os.path.basename(file))
    data = np.load(file)
    print(f'Reading {file}')
    hic = compactM(data[key], data['compact'])
    hic = zscore_hic(hic)
    bdry_inten = boundary_intensity(hic, bi_size)
    bdry_inten_delta = score_delta(bdry_inten, delta_width)
    bdry_inten_zeros = zero_points(bdry_inten_delta, orient='descent')

    insu_score = insulation_score(hic, offset, is_size)
    insu_score_delta = score_delta(insu_score, delta_width)
    insu_score_zeros = zero_points(insu_score_delta, orient='ascent')

    # start plotting
    fig = plt.figure(figsize = [9, 3])
    ax = fig.add_subplot(1, 2, 1)
    ax.plot(range(s,e), bdry_inten[s:e], label='Boundary Intensity')
    ax.plot(range(s,e), insu_score[s:e], label='Insulation Scores')
    ax.set(title=f'TAD analysis in [chr{chrn}]', xlabel='Genomic location (x10kb)', ylabel='scores')
    ax.legend(loc='lower left')
    ax = fig.add_subplot(1, 2, 2)
    ax.plot(range(s,e), bdry_inten_delta[s:e], label='Boundary Intensity')
    ax.plot(range(s,e), insu_score_delta[s:e], label='Insulation Scores')
    ax.plot(range(s, e), np.zeros(e-s), color='gray', linestyle='--')
    ax.set(title=f'TAD analysis in [chr{chrn}]', xlabel='Genomic location (x10kb)', ylabel='delta of scores')
    ax.legend(loc='lower left')
    plt.tight_layout()
    print(f"Plot to {os.path.join(out_dir, f'tad_analysis_{chrn}.png')}")
    fig.savefig(os.path.join(out_dir, f'tad_analysis_{chrn}.png'), format='png', dpi=100)
    # saving data
    all_result = np.column_stack((bdry_inten, bdry_inten_delta, insu_score, insu_score_delta))
    np.savetxt(os.path.join(out_dir, f'results_chr{chrn}.csv'), all_result, fmt='%.6f', delimiter='\t', header='boundary_intensity\tboundary_intensity_delta\tinsulation_score\tinsulation_score_delta')
    np.savetxt(os.path.join(out_dir, f'bi_zeropoints_chr{chrn}.csv'), bdry_inten_zeros.T, fmt='%d', delimiter='\t', header='bi_zeropoints')
    np.savetxt(os.path.join(out_dir, f'is_zeropoints_chr{chrn}.csv'), insu_score_zeros.T, fmt='%d', delimiter='\t', header='is_zeropoints')

def tad_comparison(base_dir, cmpd_dir, chrn, out_dir):
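    # compare boundaries called on an enhanced/downsampled matrix (cmpd_dir)
    # against the baseline calls (base_dir; the 10kb data in the main block
    # below): nearest-boundary distances and interval Jaccard indices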
    base_bi_zeros = np.loadtxt(os.path.join(base_dir, f'bi_zeropoints_chr{chrn}.csv'))
    cmpd_bi_zeros = np.loadtxt(os.path.join(cmpd_dir, f'bi_zeropoints_chr{chrn}.csv'))

    bi_dists = min_dist(base_bi_zeros, cmpd_bi_zeros)
    bi_jacard = count_overlap(base_bi_zeros, cmpd_bi_zeros)

    base_is_zeros = np.loadtxt(os.path.join(base_dir, f'is_zeropoints_chr{chrn}.csv'))
    cmpd_is_zeros = np.loadtxt(os.path.join(cmpd_dir, f'is_zeropoints_chr{chrn}.csv'))

    is_dists = min_dist(base_is_zeros, cmpd_is_zeros)
    is_jacard = count_overlap(base_is_zeros, cmpd_is_zeros)

    # start plotting
    fig = plt.figure(figsize = [9, 3])
    ax = fig.add_subplot(1, 2, 1)
    ax.hist([bi_dists, is_dists], label=['Boundary Intensity', 'Insulation Scores'])
    ax.set(title=f'TAD comparison in [chr{chrn}]', xlabel='Distance to nearest boundary (x10kb)', ylabel='frequency')
    ax.legend(loc='lower left')
    ax = fig.add_subplot(1, 2, 2)
    ax.hist([bi_jacard[:,1], is_jacard[:,1]], label=['Boundary Intensity', 'Insulation Scores'])
    ax.set(title=f'TAD comparison in [chr{chrn}]', xlabel='Jaccard index', ylabel='frequency')
    ax.legend(loc='lower left')
    plt.tight_layout()
    print(f"Plot to {os.path.join(out_dir, f'tad_comparison_{chrn}.png')}")
    fig.savefig(os.path.join(out_dir, f'tad_comparison_{chrn}.png'), format='png', dpi=100)
    np.savetxt(os.path.join(out_dir, f'cmpd_bi_dists_chr{chrn}.csv'), bi_dists.T, fmt='%d', delimiter='\t', header='distance')
    np.savetxt(os.path.join(out_dir, f'cmpd_is_dists_chr{chrn}.csv'), is_dists.T, fmt='%d', delimiter='\t', header='distance')
    np.savetxt(os.path.join(out_dir, f'cmpd_bi_jaccard_chr{chrn}.csv'), bi_jacard, fmt='%.6f', delimiter='\t', header='class_num\tjaccard_index')
    np.savetxt(os.path.join(out_dir, f'cmpd_is_jaccard_chr{chrn}.csv'), is_jacard, fmt='%.6f', delimiter='\t', header='class_num\tjaccard_index')

if __name__ == '__main__':
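    # score every chromosome of the chosen dataset in parallel, then (for
    # enhanced datasets) compare the called boundaries against the 10kb baseline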
    pool_num = 23 if multiprocessing.cpu_count() > 23 else multiprocessing.cpu_count()

    parser = argparse.ArgumentParser(description='Arguments for Visualizing TAD Analysis')
    parser.add_argument('-c', dest='cell_line', help='Cell line folder for input', required=True)
    parser.add_argument('-d', dest='dataset', help='REQUIRED: dataset to analyze (e.g. deephic)',
                        default='10kb', required=True, choices=['10kb', '40kb', 'deephic', 'hicplus', 'hicnn', 'boost'])

    args = parser.parse_args(sys.argv[1:])

    cell_line = args.cell_line
    dataset = args.dataset

    input_dir = os.path.join(INPUT_DIR[dataset], cell_line)

    if dataset == '10kb':
        files = [os.path.join(input_dir, f) for f in os.listdir(input_dir) if f.find('10kb') >= 0]
    else:
        files = [os.path.join(input_dir, f) for f in os.listdir(input_dir) if f.find('40kb') >= 0]

    output_dir = os.path.join(RESULT_DIR[dataset], cell_line, 'tad_analysis')
    mkdir(output_dir)

    key = 'hic' if (dataset=='10kb' or dataset=='40kb') else dataset

    start = time.time()
    pool = multiprocessing.Pool(processes=pool_num)
    print(f'Start a multiprocess pool with process_num = {pool_num} for calculating TAD boundaries')
    kwargs = {'bi_size':17, 'offset':2, 'is_size':10, 'delta_width':5, 's':2000, 'e':2270}
    for file in files:
        pool.apply_async(tad_caller, (file, key, output_dir,), kwargs)
    pool.close()
    pool.join()
    print(f'All TAD scores calculated. Running cost is {(time.time()-start)/60:.1f} min')

    if dataset != '10kb':
        print('Also performing comparisons...')
        chr_list = set_dict['all']

        orig_file_dir = os.path.join(RESULT_DIR['10kb'], cell_line, 'tad_analysis')
        cmpd_file_dir = os.path.join(RESULT_DIR[dataset], cell_line, 'tad_analysis')

        start = time.time()
        pool = multiprocessing.Pool(processes=pool_num)
        print(f'Start a multiprocess pool with process_num = {pool_num} for comparing TAD boundaries')
        for chrn in chr_list:
            pool.apply_async(tad_comparison, (orig_file_dir, cmpd_file_dir, chrn, output_dir,))
        pool.close()
        pool.join()
        print(f'All TAD boundaries compared. Running cost is {(time.time()-start)/60:.1f} min')

Hi, this is the code I used for this analysis before. Hope it helps!
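
In case it is useful, here is a tiny synthetic check of the scoring functions above. The matrix and parameter values are invented for illustration, and it assumes the functions from the script are already defined in scope:

import numpy as np

# two "domains" of strong contacts (1.0) meeting at bin 100, weak (0.1) between them
n = 200
mat = np.full((n, n), 0.1)
mat[:100, :100] = 1.0
mat[100:, 100:] = 1.0

scores = insulation_score(mat, offset=2, chunk=10)  # dips at the domain boundary
delta = score_delta(scores, width=5)
print(zero_points(delta, orient='ascent'))  # should report the boundary near bin 100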