Open hyanwong opened 2 years ago
@petrelharp wrote, in reference to a similar example from tsinferred trees:
If the unconnected, unary-only segments of a coalescent node are also inferrable, then we should also include them! But I am not sure that they are. Do you have a good reason to think that they might be? So, one question is: are these unary-only spans real or artifactual? Like, when tsinfer infers a unary-only bit, how often is that unary-only bit "correct"?
I suspect that some of these are reasonable inferences, although I'm not sure how to prove that. A "detached segment" can occur when recombination in more recent time breaks up a previously contiguous node. If we have correctly constructed the ancestral haplotype (which we might be able to do even if there has been a little recombination since then), then we can probably infer its presence in unconnected regions.
It depends on how well you've inferred the ancestral haplotypes. If you have correctly inferred the haplotypes corresponding to all coalescences, then certainly these disconnected unary regions can and will be correctly identified.
Whether it's possible to infer haplotypes like this is a separate (and interesting) question.
Combining the function above with the one in #1, here is some code to plot this out (the logic is not entirely tested, though, and the output doesn't quite look right; I will check later). Corrected for the "bug" in #1.
import collections
import numpy as np
import msprime
import tskit
import tsinfer
num_samples = 8
ts = msprime.sim_ancestry(
    num_samples,
    ploidy=1,
    sequence_length=1e10,
    recombination_rate=1e-8,
    record_full_arg=True,
    random_seed=123,
)
ts_orig = msprime.sim_mutations(ts, rate=5e-8, random_seed=123)
print(ts_orig.num_sites, "sites")
ts_inferred = tsinfer.infer(tsinfer.SampleData.from_tree_sequence(ts_orig))
print("Inferred", ts_inferred.num_trees, "trees")
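def node_spans_max_children(ts):
    # node_id => [[left, right, max_num_children], ...], one entry per contiguous span
    node_spans = collections.defaultdict(list)
    # node_id => IDs of the edges currently in the tree with this node as parent
    curr_parents = collections.defaultdict(set)
    for tree, diffs in zip(ts.trees(), ts.edge_diffs()):
        # Add incoming edges before removing outgoing ones, so that a node which
        # swaps one child edge for another at a breakpoint stays contiguous
        for e_in in diffs.edges_in:
            u = e_in.parent
            if len(curr_parents[u]) == 0:
                # node starts
                node_spans[u].append(
                    [diffs.interval.left, diffs.interval.right, tree.num_children(u)]
                )
            else:
                node_spans[u][-1][2] = max(node_spans[u][-1][2], tree.num_children(u))
            curr_parents[u].add(e_in.id)
        for e_out in diffs.edges_out:
            u = e_out.parent
            curr_parents[u].remove(e_out.id)
            if len(curr_parents[u]) == 0:
                # node ends where its last edges end, i.e. at the left of this tree
                node_spans[u][-1][1] = diffs.interval.left
    # nodes still present in the last tree extend to the end of the sequence
    for u, edge_ids in curr_parents.items():
        if len(edge_ids) > 0:
            node_spans[u][-1][1] = ts.sequence_length
    return node_spans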
def compare_ancestors(ts_orig, ts_cmp):
    assert ts_orig.num_sites == ts_cmp.num_sites
    # Simplify so that the nodes for comparison are never unary
    ts_orig, orig_node_map = ts_orig.simplify(map_nodes=True)
    ts_cmp, cmp_node_map = ts_cmp.simplify(map_nodes=True)
    # we need to map new IDs -> old ones, so reverse the map
    orig_node_map = {j: i for i, j in enumerate(orig_node_map) if j >= 0}
    cmp_node_map = {j: i for i, j in enumerate(cmp_node_map) if j >= 0}
    comparable_nodes = []
    site_id = 0
    for interval, t_orig, t_cmp in ts_orig.coiterate(ts_cmp):
        while ts_orig.site(site_id).position < interval.right:
            site_orig = ts_orig.site(site_id)
            site_cmp = ts_cmp.site(site_id)
            assert site_orig.position == site_cmp.position
            assert alleles(site_orig) == alleles(site_cmp)
            # Only compare cases where there is one mutation which is not a singleton
            if len(site_orig.mutations) == 1 and t_orig.num_samples(site_orig.mutations[0].node) > 1:
                oldest_mutation_orig = oldest_mutation_node(site_orig, ts_orig)
                oldest_mutation_cmp = oldest_mutation_node(site_cmp, ts_cmp)
                node_from_orig = oldest_mutation_orig.node
                node_from_cmp = oldest_mutation_cmp.node
                assert t_orig.num_children(node_from_orig) > 1
                assert t_cmp.num_children(node_from_cmp) > 1
                comparable_nodes.append([
                    orig_node_map[node_from_orig],
                    cmp_node_map[node_from_cmp],
                    site_orig.position,
                ])
            site_id += 1
            if site_id >= ts_orig.num_sites:
                return comparable_nodes
    raise ValueError("did not inspect all sites")
def alleles(site):
    alleles = set((m.derived_state for m in site.mutations))
    alleles.add(site.ancestral_state)
    return alleles
def oldest_mutation_node(site, ts_in):
    return max(site.mutations, key=lambda m: ts_in.node(m.node).time)
orig_spans = node_spans_max_children(ts_orig)
inferred_spans = node_spans_max_children(ts_inferred)
corresponding_nodes = compare_ancestors(ts_orig, ts_inferred)
results = []
processed = set()
for node_orig, node_inferred, pos in corresponding_nodes:
    if (node_orig, node_inferred) in processed:
        continue
    processed.add((node_orig, node_inferred))
    max_children_orig = max(contiguous[2] for contiguous in orig_spans[node_orig])
    min_children_orig = min(contiguous[2] for contiguous in orig_spans[node_orig])
    max_children_inferred = max(contiguous[2] for contiguous in inferred_spans[node_inferred])
    min_children_inferred = min(contiguous[2] for contiguous in inferred_spans[node_inferred])
    if max_children_orig > 1 and min_children_orig == 1 and max_children_inferred > 1 and min_children_inferred == 1:
        results.append({
            "orig": [node_orig, orig_spans[node_orig]],
            "inferred": [node_inferred, inferred_spans[node_inferred]]
        })
# Now plot
import matplotlib.pyplot as plt
fig, axes = plt.subplots(nrows=len(results), figsize=(20, 30))
axes[0].set_title("Node spans (red=contiguous span is unary throughout)")
for ax, result in zip(axes, results):
    ax.set_ylim(0, 3)
    orig_node_id, spans = result["orig"]
    ax.hlines(
        [2] * len(spans),
        [s[0] for s in spans],
        [s[1] for s in spans],
        colors=["red" if s[2] < 2 else "blue" for s in spans],
    )
    inferred_node_id, spans = result["inferred"]
    ax.hlines(
        [1] * len(spans),
        [s[0] for s in spans],
        [s[1] for s in spans],
        colors=["red" if s[2] < 2 else "blue" for s in spans],
    )
    ax.set_yticks([1, 2])
    ax.set_yticklabels([f"Inferred node {inferred_node_id}", f"Original node {orig_node_id}"])
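For checking the span logic by hand, it may help to compute the same kind of per-node spans from a tiny hand-written edge list, independently of tskit. The following is only a sketch under stated assumptions: `spans_from_edges` is a hypothetical helper (not a tskit function), edges are plain `(left, right, parent, child)` tuples, and the toy genealogy is made up so that node 4 has one coalescent span and one "detached" unary span.

```python
import collections


def spans_from_edges(edges):
    """Hypothetical helper (not part of tskit): contiguous spans per parent node.

    edges is a list of (left, right, parent, child) tuples. Returns
    {parent: [[left, right, max_num_children], ...]}.
    """
    breakpoints = sorted({x for left, right, _, _ in edges for x in (left, right)})
    spans = collections.defaultdict(list)
    for left, right in zip(breakpoints[:-1], breakpoints[1:]):
        # Number of child edges each parent has over the interval [left, right)
        n_children = collections.Counter(
            parent
            for (e_left, e_right, parent, _) in edges
            if e_left <= left and right <= e_right
        )
        for parent, n in n_children.items():
            if spans[parent] and spans[parent][-1][1] == left:
                # Node was present in the previous interval too: extend its span
                spans[parent][-1][1] = right
                spans[parent][-1][2] = max(spans[parent][-1][2], n)
            else:
                # Node appears (or reappears after a gap): start a new span
                spans[parent].append([left, right, n])
    return dict(spans)


# Made-up toy genealogy with samples 0, 1, 2:
#   [0, 10):  4 = (0, 1), 5 = (4, 2)    node 4 is coalescent
#   [10, 20): 5 = (0, 1, 2)             node 4 is absent
#   [20, 30): 4 = (0), 5 = (4, 1, 2)    node 4 is unary
edges = [
    (0, 10, 4, 0), (0, 10, 4, 1), (20, 30, 4, 0),
    (0, 30, 5, 2), (0, 10, 5, 4), (10, 30, 5, 1), (10, 20, 5, 0), (20, 30, 5, 4),
]
print(spans_from_edges(edges))
# {4: [[0, 10, 2], [20, 30, 1]], 5: [[0, 30, 3]]}
```

Here node 4 gets two spans, the second of which has a maximum of one child, which is the situation plotted in red above.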
Hm - tsinfer doesn't seem to be doing a great job by this metric - perhaps we should turn up the mutation rate so it's an easier problem?
And, to make sure I've got this right - is what your code is doing (a) identifying comparable pairs of nodes as the nodes on which each mutation occurs (taking the oldest mutation in the case of multiple hits), and (b) then, for each such pair, plotting where along the genome they are in the tree at all, colored by whether they are unary or not? (Edit: comparable nodes are the first coalescent node below each mutation.)
> Hm - tsinfer doesn't seem to be doing a great job by this metric - perhaps we should turn up the mutation rate so it's an easier problem?
True, turning the mutation rate up might be useful. One problem here is that there is not a 1-1 mapping: e.g. many of the inferred nodes correspond to multiple different actual nodes (because we are going on topology, and often in the simulation the topology stays the same but the node heights and identities change).
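One way to see how far the mapping is from 1-1 would be to count, for each inferred node, how many distinct original nodes it was paired with. This is a minimal sketch, assuming a list in the same `[orig_node, inferred_node, position]` format that `compare_ancestors` returns; `originals_per_inferred` is a hypothetical helper and the pairs below are made-up values, not output from the simulation.

```python
import collections


def originals_per_inferred(corresponding_nodes):
    """Hypothetical helper: map each inferred node to the original nodes it matched."""
    matches = collections.defaultdict(set)
    for node_orig, node_inferred, _pos in corresponding_nodes:
        matches[node_inferred].add(node_orig)
    return {node: sorted(origs) for node, origs in matches.items()}


# Made-up [orig_node, inferred_node, position] triples for illustration only
pairs = [[10, 7, 100.0], [11, 7, 250.0], [10, 7, 300.0], [12, 8, 400.0]]
many_to_one = {
    node: origs
    for node, origs in originals_per_inferred(pairs).items()
    if len(origs) > 1
}
print(many_to_one)  # {7: [10, 11]}
```

In this toy input, inferred node 7 corresponds to two different original nodes, while inferred node 8 has a unique match.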
> And, to make sure I've got this right - is what your code is doing (a) identifying comparable pairs of nodes as the nodes on which each mutation occurs (taking the oldest mutation in the case of multiple hits),
Actually, the first non-unary node below each mutation (i.e. the node that groups all the samples with the derived mutation). I think this is the right way to do it, but I'm not 100% sure.
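The idea of "the first non-unary node below each mutation" can be sketched on a single local tree held as a plain dictionary. This is an illustration only: `first_coalescent_below` is a hypothetical helper, and the `children` mapping is a made-up stand-in for what `tree.children(u)` would give in tskit. (The code above gets the same effect by simplifying first, so that unary nodes are removed.)

```python
def first_coalescent_below(children, node):
    """Hypothetical helper: walk down from `node` until it is no longer unary.

    children maps each node in one local tree to a list of its child nodes;
    nodes missing from the mapping are treated as leaves.
    """
    while len(children.get(node, [])) == 1:
        node = children[node][0]
    return node


# Made-up local tree: root 6 has children 5 and 2; node 5 is unary above
# node 4, which groups samples 0 and 1
children = {6: [5, 2], 5: [4], 4: [0, 1]}

# A mutation above the unary node 5 is inherited by samples 0 and 1, so the
# comparable node is the coalescent node 4
print(first_coalescent_below(children, 5))  # 4
```

A mutation directly above a sample (a singleton) returns the sample itself, which is why singletons are excluded from the comparison.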
> and (b) then, for each such pair, plotting where along the genome they are in the tree at all, colored by whether they are unary or not? (Edit: comparable nodes are the first coalescent node below each mutation.)
Yes (I just read the edit!). Also, I haven't actually checked the code for correctness; it's just a demo.
In https://github.com/tskit-dev/tskit/pull/2381 @petrelharp pointed out that there are potentially two different ways of simplifying while keeping nodes that are partially coalescent and partially unary. The first way is to keep the node everywhere that it is present in the genealogy. The second is to keep the node only in regions where it is contiguous with a coalescent section. There are arguments for both versions. Here is a simulated example: