tskit-dev / tskit

Simpler, more efficient interface for pair coalescence rates #2904

Open nspope opened 7 months ago

nspope commented 7 months ago

I'm thinking that the CoalescenceTimeDistribution class would be best moved to a separate package (where it can be numba-fied), and that instead we'll want to put an efficient algorithm for calculating pair coalescence rates within time windows into tskit.

The only computationally intensive part of these statistics -- that would need to go into the C library -- is calculating the number of pairs that coalesce in a given node, where each pair-node combination is weighted by its span along the sequence.
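
To make that quantity concrete, here is a minimal single-tree sketch in plain Python. The `parent` list encoding and the function name `pairs_coalescing_per_node` are hypothetical, chosen for illustration only, and are not part of tskit. It counts, for each node, the sample pairs whose most recent common ancestor is that node; multiplying by the tree's span gives the weighted value described above.

```python
from math import comb


def pairs_coalescing_per_node(parent, samples):
    """Count sample pairs whose MRCA is each node of a single tree.

    parent[u] is the parent of node u, or -1 for a root (a toy
    encoding assumed for this sketch, not tskit's table format).
    """
    num_nodes = len(parent)
    below = [0] * num_nodes  # number of samples at or below each node
    for s in samples:
        u = s
        while u != -1:
            below[u] += 1
            u = parent[u]
    # All pairs below u, minus pairs already joined below one of u's children.
    pairs = [comb(n, 2) for n in below]
    for u in range(num_nodes):
        if parent[u] != -1:
            pairs[parent[u]] -= comb(below[u], 2)
    return pairs


# Tree ((0,1)4,(2,3)5)6 over four samples, spanning 10 units of sequence.
parent = [4, 4, 5, 5, 6, 6, -1]
counts = pairs_coalescing_per_node(parent, samples=[0, 1, 2, 3])
span = 10.0
weighted = [span * c for c in counts]
```

Summing `counts` over nodes recovers the total number of sample pairs, which is a handy sanity check on any faster implementation.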

Here's a particularly simple edge-based algorithm for doing this, which could easily be extended to include cross-coalescence between sample subsets, windows along the genome, etc.:

import numpy as np
import tskit


def pair_coalescence_counts(ts):
    """
    Calculate the span-weighted number of pair coalescence events per node,
    using incremental updates across trees.
    """

    edges_child = ts.edges_child
    edges_parent = ts.edges_parent

    nodes_parent = np.full(ts.num_nodes, tskit.NULL)
    sample_counts = np.zeros(ts.num_nodes)
    coalescing_pairs = np.zeros(ts.num_nodes)

    for c in ts.samples():
        sample_counts[c] = 1

    for ed in ts.edge_diffs():
        edges_in = [e.id for e in ed.edges_in]
        edges_out = [e.id for e in ed.edges_out]

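        # Credit each change with the span from this tree's left edge to the end
        # of the sequence; the opposite update when an edge leaves cancels the excess.
        right_span = ts.sequence_length - ed.interval.left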

        for e in edges_out:
            c, p = edges_child[e], edges_parent[e]
            nodes_parent[c] = tskit.NULL
            update = sample_counts[c]
            while p != tskit.NULL:
                coalescing_pairs[p] -= update * (sample_counts[p] - sample_counts[c]) * right_span
                c, p = p, nodes_parent[p]
            p = edges_parent[e]
            while p != tskit.NULL:
                sample_counts[p] -= update
                p = nodes_parent[p]

        for e in edges_in:
            c, p = edges_child[e], edges_parent[e]
            nodes_parent[c] = p
            update = sample_counts[c]
            while p != tskit.NULL:
                sample_counts[p] += update
                p = nodes_parent[p]
            p = edges_parent[e]
            while p != tskit.NULL:
                coalescing_pairs[p] += update * (sample_counts[p] - sample_counts[c]) * right_span
                c, p = p, nodes_parent[p]

    return coalescing_pairs
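
The `right_span` factor in the function above relies on a telescoping sum: each change to a per-tree value is credited with the full distance to the end of the sequence, and a later opposite change removes the part that was over-counted. A minimal sketch of that bookkeeping in isolation, using a hypothetical helper `span_weighted_total` and made-up numbers:

```python
def span_weighted_total(changes, sequence_length):
    """Accumulate sum(value * span) from (left, delta) change records.

    changes: list of (left, delta) pairs, where delta is the change in the
    per-tree value that takes effect at position left.
    """
    total = 0.0
    for left, delta in changes:
        # Credit the change all the way to the end of the sequence.
        total += delta * (sequence_length - left)
    return total


# Value is 3 on [0, 4) and 5 on [4, 10): one change of +3 at 0, one of +2 at 4.
changes = [(0.0, 3.0), (4.0, 2.0)]
total = span_weighted_total(changes, sequence_length=10.0)

# Direct computation, tree by tree, for comparison.
direct = 3.0 * (4.0 - 0.0) + 5.0 * (10.0 - 4.0)
```

Both `total` and `direct` come to 42.0 here, without ever needing the right coordinate of an interval.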

A crude correctness check is to verify the relationship between coalescing pair counts and branch diversity:

import msprime

ts = msprime.sim_ancestry(
    samples=1000, ploidy=1, population_size=1e4,
    recombination_rate=1e-8, sequence_length=1e6, random_seed=1
)

w = pair_coalescence_counts(ts)
print(
    2 * np.sum(ts.nodes_time * w) / np.sum(w),
    ts.diversity(mode='branch'),
)
# 21577.957500882745 21577.957500882756
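
A brief justification for this check, assuming every sample is at time zero and every tree has a single root (both hold for this simulation). Write $w_v$ for the span-weighted number of pairs coalescing in node $v$, $t_v$ for its time, $n$ for the number of samples and $L$ for the sequence length. Two samples that coalesce in $v$ are separated by a total branch length of $2 t_v$, so:

```latex
\pi_{\mathrm{branch}}
  = \frac{1}{\binom{n}{2} L} \sum_v 2\, t_v\, w_v
  = \frac{2 \sum_v t_v w_v}{\sum_v w_v},
\qquad \text{since } \sum_v w_v = \binom{n}{2} L .
```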

nspope commented 7 months ago

Loosely, I'm thinking that we'd have a method,

ts.pairwise_coalescences(sample_sets, indexes=None, windows=None, ...)

for calculating cross-coalescence, that mirrors ts.divergence in how indexes is used to specify pairs of sample sets.

Given the output of this method, it's a pretty trivial calculation to get average coalescence rates within time windows. We could add a second method for this purpose:

ts.pairwise_coalescence_rates(sample_sets, indexes=None, time_windows=None, windows=None, ...)

that invokes ts.pairwise_coalescences under the hood.
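
As a rough illustration of how rates might follow from the per-node counts, here is a sketch of one possible estimator: a piecewise-constant hazard, computed as the weight of pairs coalescing in a time window divided by the total time that still-uncoalesced pairs spend in that window. This estimator, the function name, and the argument names are assumptions made for illustration, not necessarily what the proposed method would do.

```python
import math


def pair_coalescence_rates(node_times, pair_weights, time_breaks):
    """Estimate a coalescence rate in each window [a, b) of time_breaks.

    Hypothetical sketch: rate = (pairs coalescing in the window) /
    (time spent in the window by pairs that had not yet coalesced at a).
    """
    rates = []
    for a, b in zip(time_breaks[:-1], time_breaks[1:]):
        events = 0.0
        exposure = 0.0
        for t, w in zip(node_times, pair_weights):
            if t >= a:
                # Pair is uncoalesced from a until it coalesces or the window ends.
                exposure += w * (min(t, b) - a)
            if a <= t < b:
                events += w
        rates.append(events / exposure if exposure > 0 else math.nan)
    return rates


# Two nodes at times 1 and 3, each with one coalescing pair.
rates = pair_coalescence_rates([1.0, 3.0], [1.0, 1.0], [0.0, 2.0, 4.0])
```

In this toy case the first window sees one event over three units of exposure (rate 1/3) and the second sees one event over one unit (rate 1).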

jeromekelleher commented 7 months ago

I'm 100% behind this idea, and the algorithm looks very simple and elegant. I think CoalescenceTimeDistribution is sufficiently meaty for its own package (which is then free to use numba, etc.).

petrelharp commented 7 months ago

Sounds like a good plan to me! But, I guess the downside to making a new package is that that's additional infrastructure we need to keep updated (e.g., another .github/actions, another documentation website, etcetera).

jeromekelleher commented 7 months ago

Fair point

nspope commented 7 months ago

Here's a generalization of the algorithm above to genomic windows and sample sets, which supports cross-coalescence between sample sets:


def pair_coalescence_counts(ts, sample_sets, indexes, windows):
    """
    Calculate span-weighted pair coalescence counts per node, for each pair of
    sample sets in `indexes` and each window, using incremental updates across trees
    """

    num_windows = windows.size - 1
    num_sample_sets = len(sample_sets)
    num_indexes = len(indexes)

    window_left, window_right = windows[:-1], windows[1:]

    nodes_parent = np.full(ts.num_nodes, tskit.NULL)
    sample_counts = np.zeros((ts.num_nodes, num_sample_sets))
    coalescing_pairs = np.zeros((ts.num_nodes, num_indexes, num_windows))

    edges_child = ts.edges_child
    edges_parent = ts.edges_parent

    for i, s in enumerate(sample_sets): 
        sample_counts[s, i] = 1

    for ed in ts.edge_diffs():

        edges_in = [e.id for e in ed.edges_in]
        edges_out = [e.id for e in ed.edges_out]

        # TODO: "right_span" calculation is O(windows) per tree, but could be made sparse
        right_span = np.minimum(window_right, ts.sequence_length) - np.maximum(window_left, ed.interval.left)
        right_span = np.maximum(right_span, 0)

        for e in edges_out:
            c, p = edges_child[e], edges_parent[e]
            nodes_parent[c] = tskit.NULL
            within = sample_counts[c]
            while p != tskit.NULL:
                without = sample_counts[p] - sample_counts[c]
                for i, (j, k) in enumerate(indexes):
                    weight = within[j] * without[k] + within[k] * without[j]
                    coalescing_pairs[p, i] -= weight * right_span
                c, p = p, nodes_parent[p]
            p = edges_parent[e]
            while p != tskit.NULL:
                sample_counts[p] -= within
                p = nodes_parent[p]

        for e in edges_in:
            c, p = edges_child[e], edges_parent[e]
            nodes_parent[c] = p
            within = sample_counts[c]
            while p != tskit.NULL:
                sample_counts[p] += within
                p = nodes_parent[p]
            p = edges_parent[e]
            while p != tskit.NULL:
                without = sample_counts[p] - sample_counts[c]
                for i, (j, k) in enumerate(indexes):
                    weight = within[j] * without[k] + within[k] * without[j]
                    coalescing_pairs[p, i] += weight * right_span
                c, p = p, nodes_parent[p]

    for i, (j, k) in enumerate(indexes):
        coalescing_pairs[:, i, :] /= 1 + int(j == k)

    return coalescing_pairs
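
The final division by `1 + int(j == k)` can be seen from a small worked case of the `weight` expression in the function above. This is a standalone sketch with made-up counts; `cross_pair_weight` is a hypothetical helper that just restates that expression.

```python
def cross_pair_weight(within, without, j, k):
    """Pairs formed between an entering subtree and the rest of its parent's
    subtree, with one member from sample set j and the other from set k."""
    return within[j] * without[k] + within[k] * without[j]


# The entering subtree holds 2 samples from set 0 and 1 from set 1; the rest
# of the parent's subtree holds 3 from set 0 and 4 from set 1.
within = [2, 1]
without = [3, 4]

# Between sets: 2 * 4 + 1 * 3 = 11 distinct cross pairs.
between = cross_pair_weight(within, without, 0, 1)

# Within set 0 the two terms are identical, so each pair is counted twice:
# 2 * 3 + 2 * 3 = 12, and halving gives the 6 distinct pairs.
same_raw = cross_pair_weight(within, without, 0, 0)
same = same_raw / (1 + int(0 == 0))
```
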

nspope commented 7 months ago

A reality check against branch diversity/divergence:

ts = msprime.sim_ancestry(samples=1000, ploidy=1, population_size=1e4, recombination_rate=1e-8, sequence_length=1e6, random_seed=1)
sample_sets = [np.arange(0, 10), np.arange(50, 100), np.arange(750, 862)]
indexes = [(0, 0), (2, 1), (1, 0)]
windows = np.array([0, 0.25e6, 0.75e6, 1e6])
w = pair_coalescence_counts(ts, sample_sets, indexes, windows)
for i, (j, k) in enumerate(indexes):
    div_coal = 2 * (w[:, i, :].T @ ts.nodes_time) / (w[:, i, :].T @ np.ones(ts.num_nodes))
    if j == k:
        div_stat = ts.diversity(sample_sets=[sample_sets[j]], mode='branch', windows=windows).flatten()
    else:
        div_stat = ts.divergence(sample_sets=[sample_sets[j], sample_sets[k]], mode='branch', windows=windows).flatten()
    print(f"coalrate: {div_coal}, truth: {div_stat}")

# coalrate: [23520.97296844 22287.16965288 18511.34622941], truth: [23520.97296844 22287.16965288 18511.34622941]
# coalrate: [25218.71328526 20906.13747972 17729.12254241], truth: [25218.71328526 20906.13747972 17729.12254241]
# coalrate: [24217.31835446 21827.22271182 17769.38565504], truth: [24217.31835446 21827.22271182 17769.38565504]