tskit-dev / tsinfer

Infer a tree sequence from genetic variation data.
GNU General Public License v3.0

Plot "edge-node time consistency" #953

Open hyanwong opened 2 months ago

hyanwong commented 2 months ago

For improving inference accuracy, including decent ancestor reconstruction, we don't really care about the absolute times of the nodes under each mutation. Rather, we want to know that the "local" node order is correct. In fact, the only thing we really want is for the parent node and child node of each edge in the true tree sequence to be in the right order in the inferred tree sequence.

Therefore, to measure the accuracy of our inference (in order to improve it), we can match nodes between the true and inferred tree sequences (perhaps on the basis of the node under each mutation?), then take each edge in the true tree sequence and ask whether the inferred (and dated) tree sequence has the equivalent nodes in the right order (i.e. parent time > child time). This should give us a target to aim for when improving inference.
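As a dependency-free illustration of the statistic described above, here is a minimal sketch in plain Python. The function name `edge_order_consistency` and the toy inputs are hypothetical; it assumes the node matching has already been done and is supplied as a dict from true node ID to inferred node ID, with unmatched nodes simply absent.

```python
def edge_order_consistency(true_edges, node_map, inferred_times):
    """Count true (child, parent) pairs whose matched nodes are correctly ordered.

    true_edges: iterable of (child, parent) node IDs from the true tree sequence
    node_map: hypothetical dict {true node ID: inferred node ID}; unmatched nodes are absent
    inferred_times: sequence of node times in the inferred tree sequence
    Returns (consistent, compared).
    """
    consistent = compared = 0
    # The same (child, parent) pair can appear in many edges along the genome;
    # count each distinct pair once.
    for child, parent in set(true_edges):
        if child in node_map and parent in node_map:
            compared += 1
            # Consistent if the inferred parent is strictly older than the inferred child
            if inferred_times[node_map[child]] < inferred_times[node_map[parent]]:
                consistent += 1
    return consistent, compared


# Toy example: true nodes 5 has no matching mutation, so edges (2, 5) and (5, 6) are skipped
true_edges = [(0, 4), (1, 4), (4, 6), (4, 6), (2, 5), (5, 6)]
node_map = {0: 0, 1: 1, 2: 2, 4: 3, 6: 4}
inferred_times = [0.0, 0.0, 0.0, 2.0, 1.5]  # inferred node 4 (true 6) is younger than node 3 (true 4)
consistent, compared = edge_order_consistency(true_edges, node_map, inferred_times)
print(f"{consistent}/{compared} comparable true edges are consistent")
```

Here two of the three comparable pairs are in the right order; the pair (4, 6) is not, because the inferred node matched to true node 6 was dated younger than the one matched to true node 4.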

I'll figure out some plots to show improvement in this stat, but meanwhile I think this is a reasonable way to test:

```python
import tsinfer
import tskit
import tsdate
import numpy as np
import msprime

# Simulate
sim_ts = msprime.sim_ancestry(50, sequence_length=1e6, population_size=1e4, recombination_rate=1e-8, random_seed=1)
sim_ts = msprime.sim_mutations(sim_ts, rate=1e-8, random_seed=1)
print("Simulated", sim_ts.num_sites, "sites", sim_ts.num_trees, "trees")

# Infer (set use_sites_time=True to use the true, simulated site times)
use_sites_time = False
info = "With true site times " if use_sites_time else ""
ts = tsinfer.infer(tsinfer.SampleData.from_tree_sequence(
    sim_ts,
    use_sites_time=use_sites_time,
))

# Date
pts = tsdate.preprocess_ts(ts)
dts = tsdate.date(pts, mutation_rate=1e-8)

def edge_node_compat(orig_ts, new_ts):
    # Map each node in the original ts to a node in the new one, if possible (-1 = unmapped)
    corresponding_nodes = np.full(orig_ts.num_nodes, -1, dtype=orig_ts.edges_child.dtype)
    # Match nodes via the mutations above them: this requires identical sites in both
    assert new_ts.num_sites == orig_ts.num_sites
    for new_site, orig_site in zip(new_ts.sites(), orig_ts.sites()):
        if len(new_site.mutations) and len(orig_site.mutations):
            # Take the first mutation at each site as the oldest
            # (tskit lists parent mutations before their children)
            corresponding_nodes[orig_site.mutations[0].node] = new_site.mutations[0].node

    # Distinct (child, parent) node-ID pairs in the original ts, one row per pair
    unique_child_parent = np.unique(np.array([orig_ts.edges_child, orig_ts.edges_parent]).T, axis=0)
    # Append a sentinel time of -1, so that unmapped nodes (index -1) get a negative time
    nodes_time = np.concatenate((new_ts.nodes_time, [-1]))
    # The columns already hold node IDs, so map them directly (not via the edge arrays)
    child_times_in_new = nodes_time[corresponding_nodes[unique_child_parent[:, 0]]]
    parent_times_in_new = nodes_time[corresponding_nodes[unique_child_parent[:, 1]]]
    # Only compare pairs where both nodes have a counterpart in the new ts
    used = np.logical_and(child_times_in_new >= 0, parent_times_in_new >= 0)
    compat = child_times_in_new[used] < parent_times_in_new[used]
    return compat, used

# Note: `use` has one entry per unique true (child, parent) pair, so the second line
# of each report is the fraction of such pairs whose parent or child could not be matched
good, use = edge_node_compat(sim_ts, ts)
print(f"{info}{sum(good) / len(good) * 100:.2f}% ({sum(good)}) of true edges have inferred parent time older than inferred child time")
print(f"(but {sum(use==0) / len(use) * 100:.2f}% of nodes in the true ts have no associated mutation for comparison)")

good, use = edge_node_compat(sim_ts, pts)
print(f"After preprocessing, {sum(good) / len(good) * 100:.2f}%  ({sum(good)}) of true edges have inferred parent time older than inferred child time")
print(f"(but {sum(use==0) / len(use) * 100:.2f}% of nodes in the true ts have no associated mutation)")

good, use = edge_node_compat(sim_ts, dts)
print(f"After dating, {sum(good) / len(good) * 100:.2f}%  ({sum(good)}) of true edges have inferred parent time older than inferred child time")
print(f"(but {sum(use==0) / len(use) * 100:.2f}% of nodes in the true ts have no associated mutation)")
```

Giving (in this simplest simulation example):

```
Simulated 2047 sites 1758 trees
86.53% (2236) of true edges have inferred parent time older than inferred child time
(but 54.88% of nodes in the true ts have no associated mutation for comparison)
After preprocessing, 87.85%  (2270) of true edges have inferred parent time older than inferred child time
(but 54.88% of nodes in the true ts have no associated mutation)
After dating, 95.82%  (2476) of true edges have inferred parent time older than inferred child time
(but 54.88% of nodes in the true ts have no associated mutation)
```

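Weirdly, setting use_sites_time=True still gives somewhat less than 100% (98.92% before preprocessing), and I'm not sure why. Also, inferring and then dating (95.82% without true times, 96.56% with them) does worse than simply using the true site times (99.69% after preprocessing):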

```
Simulated 2047 sites 1758 trees
With true site times 98.92% (2556) of true edges have inferred parent time older than inferred child time
(but 54.88% of nodes in the true ts have no associated mutation for comparison)
After preprocessing, 99.69%  (2576) of true edges have inferred parent time older than inferred child time
(but 54.88% of nodes in the true ts have no associated mutation)
After dating, 96.56%  (2495) of true edges have inferred parent time older than inferred child time
(but 54.88% of nodes in the true ts have no associated mutation)
```
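A quick back-of-the-envelope check on the outputs above: dividing each count by its percentage recovers the number of true edges that were actually compared (`len(good)`). This snippet only re-uses the printed numbers and assumes, as in `edge_node_compat`, that `use` has one entry per unique true (child, parent) pair; the results are approximate, up to rounding in the printed percentages.

```python
# (count of consistent edges, percentage) as printed in the two runs above
reported = {
    "inferred": (2236, 86.53),
    "inferred, preprocessed": (2270, 87.85),
    "inferred, dated": (2476, 95.82),
    "true site times": (2556, 98.92),
    "true site times, preprocessed": (2576, 99.69),
    "true site times, dated": (2495, 96.56),
}

# count / fraction = number of edges actually compared, i.e. len(good)
compared = {label: round(n / (pct / 100)) for label, (n, pct) in reported.items()}
for label, n_compared in compared.items():
    print(f"{label}: ~{n_compared} true edges compared")

# 54.88% of unique true pairs were unusable, so the compared ones are 45.12% of the total
total_unique_edges = round(compared["inferred"] / (1 - 0.5488))
print(f"~{total_unique_edges} unique true (child, parent) pairs in total")
```

Every run works out to about 2584 compared edges (out of roughly 5727 unique true pairs), so all six percentages share the same denominator and can be compared directly.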