tskit-dev / tsinfer

Infer a tree sequence from genetic variation data.
GNU General Public License v3.0

Cut up root node (not just ultimate ancestor) #850

Open hyanwong opened 1 year ago

hyanwong commented 1 year ago

On the basis that the ultimate ancestor is not biologically very plausible, recent versions of tsinfer now cut up the edges that lead directly to the ultimate ancestor, by running the new post_process routine.

However, I suspect (and tests show) that we still create root ancestors that are too long. We could therefore cut up not just the ultimate ancestor but any root node, at every breakpoint where one of its child edges is added or removed (i.e. where the root appears as a parent in the edges-in or edges-out).
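The proposed criterion can be stated compactly. Below is a minimal sketch in plain Python, working on a hypothetical per-tree summary rather than on a tskit object: for each tree we assume we know its left coordinate, its root node (None for an empty tree), and the set of parent IDs of the edges entering or leaving at its left boundary, as `edge_diffs()` would report. A root is cut wherever it stays the same across a breakpoint but one of its own child edges changes. `root_cut_positions` is an illustrative name, not part of tsinfer.

```python
def root_cut_positions(lefts, roots, changed_parents):
    """
    Return the genomic positions at which a persisting root node would be cut.

    lefts[i]: left coordinate of tree i
    roots[i]: root node of tree i (None for an empty tree)
    changed_parents[i]: set of parent IDs of edges entering or leaving at lefts[i]
    """
    cuts = []
    prev_root = None
    for left, root, parents in zip(lefts, roots, changed_parents):
        if root is None:
            # Empty trees (e.g. flanks) are skipped, as in the code below
            continue
        if root == prev_root and root in parents:
            # Same root as in the previous tree, but one of its edges changed here
            cuts.append(left)
        prev_root = root
    return cuts


# Toy example: node 10 is the root of the first three trees. At position 20
# only an edge below the root changes; at position 35 a child edge of node 10 changes.
lefts = [0, 20, 35, 50]
roots = [10, 10, 10, 11]
changed_parents = [set(), {7}, {10, 7}, {11, 10}]
print(root_cut_positions(lefts, roots, changed_parents))  # [35]
```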

Here's some example code, with histograms of the (log) genomic span over which a single node remains the root, for the true and inferred tree sequences. Note that this code may result in nodes that are not ordered strictly by time.

import collections
import itertools

import numpy as np
from matplotlib import pyplot as plt
import msprime
import tsinfer
import tskit

ts = msprime.sim_ancestry(100, population_size=1e4, recombination_rate=1e-8, sequence_length=1e7, random_seed=1)
print("Simulation has", ts.num_trees, "trees")
mts = msprime.sim_mutations(ts, rate=1e-8, random_seed=1)
print("Simulation has", mts.num_mutations, "mutations")

sd = tsinfer.SampleData.from_tree_sequence(mts)
its = tsinfer.infer(sd, progress_monitor=True)
ists = its.simplify()  # remove unary nodes

def unsquash(edge_table, positions, edges=None):
    """
    For a set of positions and a given set of edges (or all edges if ``edges`` is None),
    create a set of new edges which cover the same span but which are chopped up into
    separate edges at every specified position. This is essentially the opposite of
    EdgeTable.squash()
    """
    new_edges = tskit.EdgeTable()
    positions = np.unique(positions) # sort and uniquify
    skip = []
    if edges is not None:
        skip = np.ones(edge_table.num_rows, dtype=bool)
        skip[edges] = False
    for edge, do_skip in itertools.zip_longest(edge_table, skip, fillvalue=False):
        if do_skip:
            new_edges.append(edge)
            continue
        for l, r in itertools.pairwise(itertools.chain(
            [edge.left],
            positions[np.logical_and(positions > edge.left, positions < edge.right)],
            [edge.right]
        )):
            new_edges.append(edge.replace(left=l, right=r))
    edge_table.replace_with(new_edges)

def break_root_nodes(ts):
    tables = ts.dump_tables()
    edges_to_break = set()
    # break up the edges to the root
    for tree in ts.trees():
        if tree.num_edges == 0:
            continue
        for u in tree.children(tree.root):
            edges_to_break.add(tree.edge(u))
    unsquash(
        tables.edges,
        ts.breakpoints(as_array=True),
        edges=np.array(sorted(edges_to_break), dtype=int),  # int dtype so an empty set still indexes
    )

    ts_split = tables.tree_sequence()
    tables.edges.clear()
    tables.mutations.clear()
    prev_root = None
    nd_map = {u: u for u in range(ts.num_nodes)}
    for ed, ed_split, tree in zip(
        ts.edge_diffs(), ts_split.edge_diffs(), ts_split.trees()
    ):
        if tree.num_edges == 0:
            continue
        if tree.root == prev_root:
            parents = {e.parent for e in ed.edges_out} | {e.parent for e in ed.edges_in}
            if tree.root in parents:
                nd_map[tree.root] = tables.nodes.append(ts.node(tree.root))
        for m in tree.mutations():
            tables.mutations.append(m.replace(node=nd_map[m.node]))
        prev_root = tree.root
        for e in ed_split.edges_in:
            tables.edges.add_row(
                left=e.left, right=e.right, parent=nd_map[e.parent], child=nd_map[e.child])
    tables.sort()
    tables.edges.squash()
    tables.sort()
    return tables.tree_sequence()

iists = break_root_nodes(ists)
print("Created", iists.num_nodes - ists.num_nodes, "new roots")

## Do some histograms

prev_root = None
root_breaks = [0]
for tree in mts.trees():
    if prev_root != tree.root:
        if prev_root is not None:
            root_breaks.append(tree.interval.left)
    prev_root = tree.root
root_breaks.append(mts.sequence_length)
plt.hist(np.log(np.diff(root_breaks)), bins=40, density=True, label="True")

r2 = [0]
prev_root = None
for tree in ists.trees():
    if tree.num_edges == 0:
        continue
    if prev_root != tree.root:
        r2.append(tree.interval.left)
    prev_root = tree.root
r2.append(ists.sequence_length)

plt.hist(np.log(np.diff(r2)), alpha=0.5, bins=40, density=True, label="split ultimate")

r3 = [0]
prev_root = None
for tree in iists.trees():
    if tree.num_edges == 0:
        continue
    if prev_root != tree.root:
        r3.append(tree.interval.left)
    prev_root = tree.root
r3.append(iists.sequence_length)

plt.hist(np.log(np.diff(r3)), alpha=0.5, bins=40, density=True, label="additionally split root")

plt.legend();

Output:

Simulation has 21109 trees
Simulation has 23596 mutations
Created 1651 new roots

[Figure: density histograms of log root span for "True", "split ultimate" and "additionally split root"]
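The three root-span loops above repeat the same bookkeeping, with one small difference: the `mts` loop only records a breakpoint once a previous root exists, whereas the `ists`/`iists` loops also record the left edge of the first non-empty tree. A minimal sketch of a single helper, taking hypothetical `(left, root)` pairs rather than a tskit object (with tskit these would come from `tree.interval.left` and `tree.root` over `ts.trees()`, using `None` when `tree.num_edges == 0`). It applies the `mts` convention to every input and skips empty trees, so an empty flank is absorbed into the neighbouring root span rather than counted as a span of its own; that is a design choice, not what the original loops do.

```python
import numpy as np


def root_breakpoints(trees, sequence_length):
    """
    Return the positions at which the root changes, bracketed by 0 and
    sequence_length, so that np.diff() of the result gives the root spans.

    trees: iterable of (left, root) pairs in genome order; root is None
           for an empty tree, which is skipped
    """
    breaks = [0.0]
    prev_root = None
    for left, root in trees:
        if root is None:
            continue
        if prev_root is not None and root != prev_root:
            breaks.append(left)
        prev_root = root
    breaks.append(sequence_length)
    return np.array(breaks, dtype=float)


# Toy example: an empty left flank, then roots 5, 5, 8, then an empty right flank
trees = [(0, None), (10, 5), (30, 5), (60, 8), (90, None)]
breaks = root_breakpoints(trees, 100)
print(breaks.tolist())           # [0.0, 60.0, 100.0]
print(np.diff(breaks).tolist())  # [60.0, 40.0]
```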

hyanwong commented 1 year ago

And here is the correlation between the known root spans and the inferred ones (it's a pretty poor correlation, though!)

rb = np.array(root_breaks)
mid_root_pos = rb[:-1] + np.diff(rb)/2
ss = np.searchsorted(rb, mid_root_pos)
plt.scatter(np.diff(root_breaks), rb[ss] - rb[ss-1])

rb = np.array(r2)
ss = np.searchsorted(rb, mid_root_pos)
plt.scatter(np.diff(root_breaks), rb[ss] - rb[ss-1], alpha=0.1)
print(
    "corr coeff: known root lengths vs lengths with split ultimate:\n ",
    np.corrcoef(np.diff(root_breaks), rb[ss] - rb[ss-1])[0, 1])

rb = np.array(r3)
ss = np.searchsorted(rb, mid_root_pos)
plt.scatter(np.diff(root_breaks), rb[ss] - rb[ss-1], alpha=0.1)
print(
    "corr coeff: known root lengths vs lengths with extra split root:\n ",
    np.corrcoef(np.diff(root_breaks), rb[ss] - rb[ss-1])[0, 1])

plt.xscale('log')
plt.yscale('log')

Output:

corr coeff: known root lengths vs lengths with split ultimate:
  0.06516027384592456
corr coeff: known root lengths vs lengths with extra split root:
  0.13137918764309806

[Figure: log-log scatter of known root span against the inferred root span containing its midpoint]
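The midpoint matching used above can be wrapped in a small function: for each true root span, take its midpoint and use `np.searchsorted` to find the inferred span that contains it. A minimal numpy sketch; `matched_spans` is a hypothetical name, and both breakpoint arrays are assumed to start at 0 and end at the sequence length, like `root_breaks`, `r2` and `r3`.

```python
import numpy as np


def matched_spans(true_breaks, inferred_breaks):
    """
    For each true root span, return its length and the length of the
    inferred root span that contains its midpoint.
    """
    true_breaks = np.asarray(true_breaks, dtype=float)
    inferred_breaks = np.asarray(inferred_breaks, dtype=float)
    true_len = np.diff(true_breaks)
    mid = true_breaks[:-1] + true_len / 2
    # Index of the first inferred breakpoint >= each midpoint, i.e. the right
    # end of the inferred span containing that midpoint
    idx = np.searchsorted(inferred_breaks, mid)
    inferred_len = inferred_breaks[idx] - inferred_breaks[idx - 1]
    return true_len, inferred_len


true_len, inferred_len = matched_spans([0, 40, 100], [0, 30, 80, 100])
print(true_len.tolist(), inferred_len.tolist())  # [40.0, 60.0] [30.0, 50.0]
```

Applied to the variables above, `np.corrcoef(*matched_spans(root_breaks, r3))[0, 1]` should reproduce the second correlation coefficient.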

hyanwong commented 1 year ago

Extra splitting of the root certainly improves the n=10 plot from @a-ignatieva's preprint, especially when combined with @nspope's variational gamma method:

[Figure: the n=10 plot]

hyanwong commented 1 year ago

And here it is for 100 samples. Since these use exactly the same topology, the improvement can't be due to, e.g., better polytomy breaking.

[Figure: the same plot for 100 samples]

hyanwong commented 1 year ago

@jeromekelleher and I decided this should be implemented, at a minimum, as part of post_process, and then probably rolled out as the default. However, it would be good to think of a more efficient method than the one coded above, and also one that keeps the nodes in time order (though this might have to be done with a sort at the end).
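One possible direction, sketched under stated assumptions: replace the per-edge Python loop in `unsquash` with numpy operations on the `left`/`right` columns, and restore time order with a single renumbering of nodes at the end. The sketch works on plain arrays rather than tskit tables; `split_edges` and `time_order_nodes` are hypothetical helpers, `selected` is a boolean mask (not the list of edge IDs that `unsquash` takes), and "time order" is assumed to mean node IDs sorted by increasing time.

```python
import numpy as np


def split_edges(left, right, positions, selected=None):
    """
    Vectorised 'unsquash': split each selected edge [left, right) at every
    position strictly inside it. Returns (new_left, new_right, source), where
    source[j] is the original edge that piece j came from, so other columns
    can be carried over as parent[source], child[source].
    """
    left = np.asarray(left, dtype=float)
    right = np.asarray(right, dtype=float)
    n = len(left)
    positions = np.unique(positions)  # sort and de-duplicate
    if positions.size == 0:
        return left.copy(), right.copy(), np.arange(n)
    if selected is None:
        selected = np.ones(n, dtype=bool)
    # positions[lo:hi] are exactly the positions p with left < p < right
    lo = np.searchsorted(positions, left, side="right")
    hi = np.searchsorted(positions, right, side="left")
    pieces = np.where(selected, hi - lo + 1, 1)
    source = np.repeat(np.arange(n), pieces)
    # k: index of each piece within its original edge (0, 1, ..., pieces - 1)
    first = np.cumsum(pieces) - pieces
    k = np.arange(len(source)) - first[source]
    # Clipping only affects entries that np.where below replaces anyway
    last = len(positions) - 1
    inner_left = positions[np.clip(lo[source] + k - 1, 0, last)]
    inner_right = positions[np.clip(lo[source] + k, 0, last)]
    new_left = np.where(k == 0, left[source], inner_left)
    new_right = np.where(k == pieces[source] - 1, right[source], inner_right)
    return new_left, new_right, source


def time_order_nodes(node_time, parent, child):
    """
    Renumber nodes so that IDs follow increasing time (stable, so ties keep
    their relative order) and remap edge parent/child IDs to match. Returns
    the sorted times, remapped parent and child, and the old-to-new ID map.
    """
    node_time = np.asarray(node_time, dtype=float)
    order = np.argsort(node_time, kind="stable")  # old IDs, listed in new order
    new_id = np.empty_like(order)
    new_id[order] = np.arange(len(order))  # invert the permutation
    return node_time[order], new_id[np.asarray(parent)], new_id[np.asarray(child)], new_id


new_left, new_right, source = split_edges(
    [0, 0, 50], [100, 100, 100], [25, 50, 75], selected=[True, False, True])
print(new_left.tolist())   # [0.0, 25.0, 50.0, 75.0, 0.0, 50.0, 75.0]
print(new_right.tolist())  # [25.0, 50.0, 75.0, 100.0, 100.0, 75.0, 100.0]
print(source.tolist())     # [0, 0, 0, 0, 1, 2, 2]

times, parent, child, new_id = time_order_nodes([0.0, 0.0, 3.0, 1.0], [2, 3], [3, 0])
print(times.tolist(), parent.tolist(), child.tolist())  # [0.0, 0.0, 1.0, 3.0] [3, 2] [2, 0]
```

With tskit tables, one would pass `tables.edges.left` and `tables.edges.right`, rebuild the edge table with `tables.edges.set_columns(...)` using `parent[source]` and `child[source]`, and apply `new_id` to every node reference (edges and mutations) after permuting the node table rows into time order, followed by `tables.sort()`. `TableCollection.subset()`, which keeps nodes in the order given, may be a simpler route for the renumbering. None of this has been checked against tsinfer output.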

hyanwong commented 10 months ago

A better-justified, model-based way to cut up the root nodes would be to implement the PSMC-on-the-tree idea for the root. If that is implemented, we might use it to cut up the root nodes instead, so there's an argument for making the version above available only as a non-default post-process option.