petrelharp / num_edges


Keeping detached unary segments on simplify #2

Open hyanwong opened 2 years ago

hyanwong commented 2 years ago

In https://github.com/tskit-dev/tskit/pull/2381 @petrelharp pointed out that there are potentially two different ways of simplifying while keeping nodes that are partially coalescent and partially unary. The first is to keep the node everywhere it is present in the genealogy. The second is to keep the node only in regions where it is contiguous with a coalescent section. There are arguments for both versions. Here is a simulated example:

import collections

import numpy as np
import msprime

num_samples = 8
ts = msprime.sim_ancestry(
    num_samples,
    ploidy=1,
    sequence_length=1e8,
    recombination_rate=1e-8,
    record_full_arg=True,
    random_seed=54321
)

def node_spans_max_children(ts):
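    # node_id => [[left, right, max_num_children], ...], one entry per
    # contiguous stretch of genome over which the node has children
    node_spans = collections.defaultdict(list)
    # node_id => IDs of the edges currently hanging below that node
    curr_parents = collections.defaultdict(set)

    for tree, diffs in zip(ts.trees(), ts.edge_diffs()):
        # Edges in are handled before edges out, so a node whose children are
        # all swapped at the same breakpoint still counts as one contiguous span
        for e_in in diffs.edges_in:
            u = e_in.parent
            if len(curr_parents[u]) == 0:
                # node starts; it runs to the end of the sequence unless an edge-out ends it
                node_spans[u].append(
                    [diffs.interval.left, ts.sequence_length, tree.num_children(u)]
                )
            else:
                node_spans[u][-1][2] = max(node_spans[u][-1][2], tree.num_children(u))
            curr_parents[u].add(e_in.id)
        for e_out in diffs.edges_out:
            u = e_out.parent
            curr_parents[u].remove(e_out.id)
            if len(curr_parents[u]) == 0:
                # node ends where the current tree (from which it is absent) begins
                node_spans[u][-1][1] = diffs.interval.left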

    for u, data in node_spans.items():
        max_children = max(contiguous[2] for contiguous in data)
        for contiguous in data:
            if max_children > 1 and contiguous[2] < 2:
                print("Node", u, "is contiguously unary from", contiguous)

node_spans_max_children(ts)
7 sites
Node 29 is contiguously unary from [0.0, 24318676.0, 1]
Node 29 is contiguously unary from [57049008.0, 57769219.0, 1]
Node 45 is contiguously unary from [80875054.0, 88999138.0, 1]
Node 46 is contiguously unary from [62435555.0, 80875054.0, 1]
Node 46 is contiguously unary from [80875054.0, 88999138.0, 1]
Node 51 is contiguously unary from [62435555.0, 80875054.0, 1]
# plot out some examples
ts.draw_svg(time_scale="rank", style="""
    .node .sym, .node .lab {display:none}
    .n29 > .sym, .n29 > .lab, .n45 > .sym, .n45 > .lab,
    .n46 > .sym, .n46 > .lab,.n51 > .sym, .n51 > .lab {display: initial}
    .n29 > .sym {fill: red}
    .n45 > .sym {fill: yellow}
    .n46 > .sym {fill: magenta}
    .n51 > .sym {fill: cyan}
    """)


hyanwong commented 2 years ago

@petrelharp wrote, in reference to a similar example from tsinferred trees:

If the unconnected, unary-only segments of coalescent nodes are also inferrable, then we should also include them! But I am not sure that they are. Do you have a good reason that they might be? So, one question is: are these unary-only spans real or artifactual? Like, when tsinfer infers a unary-only bit, how often is that unary-only bit "correct"?

I suspect that some of these are reasonable inferences, although I'm not sure how to prove that. A "detached segment" can occur when more recent recombination breaks up a previously contiguous node. If we have correctly constructed the ancestral haplotype (which we might be able to do even if there has been a little recombination since then), then we can probably infer its presence in the unconnected regions.
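A detached segment shows up as a separate contiguous span, with a single child, cut off from the node's coalescent span by a gap where the node is absent. The sketch below finds such spans from a hand-written list of `(left, right, parent, child)` edges. It is pure Python and brute force (it re-scans a node's edges between every pair of breakpoints) rather than using tskit's `edge_diffs()`; the edge list and function names are hypothetical.

```python
import collections


def contiguous_spans(edges):
    """Return {parent: [[left, right, max_num_children], ...]} for a list of
    (left, right, parent, child) edges, merging touching intervals."""
    by_parent = collections.defaultdict(list)
    for left, right, parent, child in edges:
        by_parent[parent].append((left, right))
    spans = {}
    for parent, intervals in by_parent.items():
        # Breakpoints at which this node's set of children can change
        points = sorted({x for interval in intervals for x in interval})
        result = []
        for a, b in zip(points[:-1], points[1:]):
            # Child edges covering all of [a, b) (assumes one edge per child here)
            n = sum(1 for left, right in intervals if left <= a and b <= right)
            if n == 0:
                continue  # the node is absent on [a, b)
            if result and result[-1][1] == a:
                # Touches the previous span: extend it
                result[-1][1] = b
                result[-1][2] = max(result[-1][2], n)
            else:
                result.append([a, b, n])
        spans[parent] = result
    return spans


def detached_unary_spans(spans):
    """Unary-only spans of a node that is coalescent somewhere else."""
    if max(s[2] for s in spans) < 2:
        return []
    return [s for s in spans if s[2] < 2]


# Node 10 has children 1 and 2 on [0, 40), only child 1 on [40, 50),
# is absent on [50, 70), and has only child 3 on [70, 100).
edges = [(0, 50, 10, 1), (0, 40, 10, 2), (70, 100, 10, 3)]
spans = contiguous_spans(edges)
print(spans)                            # {10: [[0, 50, 2], [70, 100, 1]]}
print(detached_unary_spans(spans[10]))  # [[70, 100, 1]]
```

The unary stretch [40, 50) stays inside the coalescent span because it touches it; only [70, 100) is detached.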

jeromekelleher commented 2 years ago

It depends on how well you've inferred the ancestral haplotypes. If you have correctly inferred the haplotypes corresponding to all coalescences, then certainly these disconnected unary regions can and will be correctly identified.

Whether it's possible to infer haplotypes like this is a separate (and interesting) question.

hyanwong commented 2 years ago

Combining the function above with the one in #1, here is some code to plot this out (corrected for the "bug" in #1). The logic is not entirely tested, though, and the result doesn't quite look right; I will check later.

import collections

import numpy as np
import msprime
import tskit
import tsinfer

num_samples = 8
ts = msprime.sim_ancestry(
    num_samples,
    ploidy=1,
    sequence_length=1e10,
    recombination_rate=1e-8,
    record_full_arg=True,
    random_seed=123
)

ts_orig = msprime.sim_mutations(ts, rate=5e-8, random_seed=123)
print(ts_orig.num_sites, "sites")
ts_inferred = tsinfer.infer(tsinfer.SampleData.from_tree_sequence(ts_orig))
print("Inferred", ts_inferred.num_trees, "trees")

def node_spans_max_children(ts):
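    # node_id => [[left, right, max_num_children], ...], one entry per
    # contiguous stretch of genome over which the node has children
    node_spans = collections.defaultdict(list)
    # node_id => IDs of the edges currently hanging below that node
    curr_parents = collections.defaultdict(set)

    for tree, diffs in zip(ts.trees(), ts.edge_diffs()):
        # Edges in are handled before edges out, so a node whose children are
        # all swapped at the same breakpoint still counts as one contiguous span
        for e_in in diffs.edges_in:
            u = e_in.parent
            if len(curr_parents[u]) == 0:
                # node starts; it runs to the end of the sequence unless an edge-out ends it
                node_spans[u].append(
                    [diffs.interval.left, ts.sequence_length, tree.num_children(u)]
                )
            else:
                node_spans[u][-1][2] = max(node_spans[u][-1][2], tree.num_children(u))
            curr_parents[u].add(e_in.id)
        for e_out in diffs.edges_out:
            u = e_out.parent
            curr_parents[u].remove(e_out.id)
            if len(curr_parents[u]) == 0:
                # node ends where the current tree (from which it is absent) begins
                node_spans[u][-1][1] = diffs.interval.left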

    return node_spans

def compare_ancestors(ts_orig, ts_cmp):
    assert ts_orig.num_sites == ts_cmp.num_sites
    # Simplify so that the nodes for comparison are never unary
    ts_orig, orig_node_map = ts_orig.simplify(map_nodes=True)
    ts_cmp, cmp_node_map = ts_cmp.simplify(map_nodes=True)
    # we need to map new IDs -> old ones, so reverse the map
    orig_node_map = {j: i for i, j in enumerate(orig_node_map) if j >= 0}
    cmp_node_map = {j: i for i, j in enumerate(cmp_node_map) if j >= 0}

    comparable_nodes = []
    site_id = 0
    for interval, t_orig, t_cmp in ts_orig.coiterate(ts_cmp):
        while ts_orig.site(site_id).position < interval.right:
            site_orig = ts_orig.site(site_id)
            site_cmp = ts_cmp.site(site_id)
            assert site_orig.position == site_cmp.position
            assert alleles(site_orig) == alleles(site_cmp)
            # Only compare cases where there is one mutation which is not a singleton
            if len(site_orig.mutations) == 1 and t_orig.num_samples(site_orig.mutations[0].node) > 1:
                oldest_mutation_orig = oldest_mutation_node(site_orig, ts_orig)
                oldest_mutation_cmp = oldest_mutation_node(site_cmp, ts_cmp)
                node_from_orig = oldest_mutation_orig.node
                node_from_cmp = oldest_mutation_cmp.node
                assert t_orig.num_children(node_from_orig) > 1
                assert t_cmp.num_children(node_from_cmp) > 1
                comparable_nodes.append([
                    orig_node_map[node_from_orig],
                    cmp_node_map[node_from_cmp],
                    site_orig.position,
                ])
            site_id += 1
            if site_id >= ts_orig.num_sites:
                return comparable_nodes
    raise ValueError("did not inspect all sites") 

def alleles(site):
    alleles = set((m.derived_state for m in site.mutations))
    alleles.add(site.ancestral_state)
    return alleles

def oldest_mutation_node(site, ts_in):
    return max(site.mutations, key=lambda m: ts_in.node(m.node).time)

orig_spans = node_spans_max_children(ts_orig)
inferred_spans = node_spans_max_children(ts_inferred)
corresponding_nodes = compare_ancestors(ts_orig, ts_inferred)

results = []
processed = set()
for node_orig, node_inferred, pos in corresponding_nodes:
    if (node_orig, node_inferred) in processed:
        continue
    processed.add((node_orig, node_inferred))
    max_children_orig = max(contiguous[2] for contiguous in orig_spans[node_orig])
    min_children_orig = min(contiguous[2] for contiguous in orig_spans[node_orig])

    max_children_inferred = max(contiguous[2] for contiguous in inferred_spans[node_inferred])
    min_children_inferred = min(contiguous[2] for contiguous in inferred_spans[node_inferred])

    if max_children_orig > 1 and min_children_orig == 1 and max_children_inferred > 1 and min_children_inferred == 1:
        results.append({
            "orig": [node_orig, orig_spans[node_orig]],
            "inferred": [node_inferred, inferred_spans[node_inferred]]
        })

# Now plot

import matplotlib.pyplot as plt

fig, axes = plt.subplots(nrows=len(results), figsize=(20, 30), squeeze=False)
axes = axes[:, 0]  # 1-D array of Axes, even when there is only one result
axes[0].set_title("Node spans (red = contiguous span is unary throughout)")
for ax, result in zip(axes, results):
    ax.set_ylim(0, 3)
    orig_node_id, spans = result["orig"]
    ax.hlines(
        [2]*len(spans),
        [s[0] for s in spans],
        [s[1] for s in spans],
        colors=["red" if s[2]<2 else "blue" for s in spans])
    inferred_node_id, spans = result["inferred"]
    ax.hlines(
        [1]*len(spans),
        [s[0] for s in spans],
        [s[1] for s in spans],
        colors=["red" if s[2]<2 else "blue" for s in spans])
    ax.set_yticks([1, 2])
    ax.set_yticklabels([f"Inferred node {inferred_node_id}", f"Original node {orig_node_id}"])


petrelharp commented 2 years ago

Hm - tsinfer doesn't seem to be doing a great job by this metric - perhaps we should turn up the mutation rate so it's an easier problem?
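As a rough guide to how much turning up the mutation rate helps, the expected number of segregating sites under the standard coalescent with infinite sites is mu * L * 2 * T * (1 + 1/2 + ... + 1/(n - 1)), where T is the pairwise coalescence time scale. This sketch assumes msprime's convention that T is `ploidy * population_size` generations; the scripts above do not pass `population_size`, so the value of 1 used here is an assumption, and the function name is hypothetical.

```python
def expected_segregating_sites(num_genomes, mutation_rate, sequence_length,
                               population_size, ploidy):
    """Expected number of segregating sites (infinite-sites approximation).

    The expected total branch length of a coalescent tree is 2 * T * a_n, where
    T = ploidy * population_size is the pairwise coalescence time scale in
    generations and a_n = 1 + 1/2 + ... + 1/(n - 1).
    """
    a_n = sum(1 / i for i in range(1, num_genomes))
    time_scale = ploidy * population_size
    return mutation_rate * sequence_length * 2 * time_scale * a_n


# Parameters from the second script; population_size=1 is an assumption.
for rate in (5e-8, 5e-7):
    print(rate, round(expected_segregating_sites(8, rate, 1e10, 1, 1)))
```

The count scales linearly with the mutation rate, so a tenfold increase gives roughly ten times as many sites for tsinfer to work with.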

And, to make sure I've got this right - is this what your code is doing: (a) identifying comparable pairs of nodes as the nodes on which each mutation occurs (taking the oldest mutation in the case of multiple hits), and (b) then for each such pair plotting where along the genome they are in the tree at all, colored by whether they are unary or not? (Edit: comparable nodes are the first coalescent node in each tree sequence below each mutation.)

hyanwong commented 2 years ago

Hm - tsinfer doesn't seem to be doing a great job by this metric - perhaps we should turn up the mutation rate so it's an easier problem?

True, turning the mutation rate up might be useful. One problem here is that there is not a one-to-one mapping: e.g. many of the inferred nodes correspond to multiple different actual nodes (because we are going on topology, and often in the simulation the topology stays the same but the node heights and identities change).
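The many-to-one mapping can be checked directly from the `[node_orig, node_inferred, position]` triples that `compare_ancestors()` returns, by grouping the original nodes under each inferred node. A minimal sketch; the function name is hypothetical and the triples below are made up rather than taken from the simulation.

```python
import collections


def original_nodes_per_inferred(comparable_nodes):
    """Group original node IDs by the inferred node they were paired with."""
    groups = collections.defaultdict(set)
    for node_orig, node_inferred, _position in comparable_nodes:
        groups[node_inferred].add(node_orig)
    return dict(groups)


# Made-up triples in the [node_orig, node_inferred, position] format
comparable_nodes = [[30, 12, 100.0], [31, 12, 250.0], [30, 12, 300.0], [45, 17, 400.0]]
groups = original_nodes_per_inferred(comparable_nodes)
many_to_one = {inf: orig for inf, orig in groups.items() if len(orig) > 1}
print(many_to_one)  # {12: {30, 31}}
```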

And, to make sure I've got this right - is this what your code is doing: (a) identifying comparable pairs of nodes as the nodes on which each mutation occurs (taking the oldest mutation in the case of multiple hits),

Actually, the first non-unary node below each mutation (i.e. the node that groups all the samples carrying the derived allele). I think this is the right way to do it, but I'm not 100% sure.
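The rule "first non-unary node below each mutation" amounts to walking down through single-child nodes. In `compare_ancestors()` this happens implicitly, because both tree sequences are simplified first and simplify() reassigns mutations on removed unary nodes to the retained node below them. The sketch makes the rule explicit on a hand-written tree stored as a `{parent: [children]}` dict; the tree and the function name are hypothetical, not tskit API.

```python
def first_coalescent_below(node, children):
    """Walk down from `node` through unary nodes; stop at the first node with
    more than one child, or at a leaf."""
    while len(children.get(node, [])) == 1:
        node = children[node][0]
    return node


# Root 9 has children 7 and 8; node 8 is unary above node 6.
children = {9: [7, 8], 8: [6], 7: [0, 1], 6: [2, 3]}
print(first_coalescent_below(8, children))  # 6
print(first_coalescent_below(7, children))  # 7
```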

and (b) then for each such pair plotting where along the genome they are in the tree at all, colored by whether they are unary or not? (Edit: comparable nodes are the first coalescent node in each tree sequence below each mutation.)

Yes (I just read the edit!). Also, I haven't actually checked the code for correctness; it's just a demo.