tskit-dev / tskit

Population-scale genomics
MIT License

Randomly resolve polytomies #809

Closed hyanwong closed 3 years ago

hyanwong commented 4 years ago

I've finally got some working code to randomly resolve polytomies in a tree sequence. It seems to work for the admittedly small sample of inferred trees that I have tried. It tries to be clever by resolving per polytomy rather than per tree: if an identical polytomy spans several trees, it is resolved only once, creating fewer edges than a per-tree approach, which should mean it scales better to larger sample sizes. I'm posting it here so that:

a. I don't lose it (!)
b. We can decide if we want something like this in the base tskit library, as a method on a tree sequence.
c. @hyl317 can use it if he wants, for ARGweaver compatibility (although his current solution may be faster)

Obviously it needs a fair bit of tidying up. It's also slow, but I think there's considerable scope for optimization. It requires the PR I made at #787.

import collections
import itertools

import numpy as np
import tskit

def resolve_polytomy(parent_node_id, child_ids, new_nodes_by_time, rng):
    """
    For a polytomy and list of child node ids, return a list of (child, parent) tuples,
    describing a bifurcating tree, rooted at parent_node_id, where the new_nodes_by_time
    have been used to break polytomies. All possible topologies should be equiprobable.
    """
    assert len(child_ids) == len(new_nodes_by_time) + 2
    edges = [[child_ids[0], None], ]  # Introduce a single edge that will be deleted later
    edge_choice = rng.integers(0, np.arange(1, len(child_ids) * 2 - 1, 2))
    tmp_new_node_lab = [parent_node_id] + new_nodes_by_time
    assert len(edge_choice) == len(child_ids) - 1
    for node_lab, child_id, target_edge_id in zip(tmp_new_node_lab, child_ids[1:], edge_choice):
        target_edge = edges[target_edge_id]
        # print("target", target_edge)
        # Insert to keep edges in time order of parent
        edges.insert(target_edge_id, [child_id, node_lab])
        edges.insert(target_edge_id, [target_edge[0], node_lab])
        target_edge[0] = node_lab
    # We need to re-map the internal nodes so that they are in time order
    real_node = iter(new_nodes_by_time)
    edges.pop() # remove the unary node at the top
    node_map = {c: c for c in child_ids}
    # print("orig_edges", edges)
    # print("parent IDs to allocate", new_nodes_by_time)
    # last edge should have the highest node
    node_map[edges[-1][1]] = parent_node_id
    for e in reversed(edges):
        # edges should be in time order - the oldest one can be given the parent_node_id
        if e[1] not in node_map:
            node_map[e[1]] = next(real_node)
        if e[0] not in node_map:
            node_map[e[0]] = next(real_node)
        e[0] = node_map[e[0]]
        e[1] = node_map[e[1]]
    # print("mapped edges", edges)
    assert len(node_map) == len(new_nodes_by_time) + len(child_ids) + 1
    return edges

def resolve_polytomies(ts, *, epsilon=1e-10, random_seed=None):
    """
    For a given parent node, an edge in or an edge out signifies a change in children.
    Each time such a change happens, we cut all existing edges with that parent,
    and add the previous portion to the new edge table. If, previously, there were
    3 or more children for this node, we break the polytomy at random.
    """
    rng = np.random.default_rng(seed=random_seed)

    tables = ts.dump_tables()
    edges_table = tables.edges
    nodes_table = tables.nodes
    # Store the left of the existing edges, as we will need to change it if the edge is split
    existing_edges_left = edges_table.left
    # Keep these arrays for handy reading later
    existing_edges_right = edges_table.right
    existing_edges_parent = edges_table.parent
    existing_edges_child = edges_table.child
    existing_node_time = nodes_table.time

    edges_table.clear()

    edges_for_node = collections.defaultdict(set)  # The edge ids dangling from each active node
    nodes_changed = set()

    for interval, e_out, e_in in ts.edge_diffs(include_terminal=True):
        for edge in itertools.chain(e_out, e_in):
            if edge.parent != tskit.NULL:
                nodes_changed.add(edge.parent)

        pos = interval[0]
        for parent_node in nodes_changed:
            child_edge_ids = edges_for_node[parent_node]
            if len(child_edge_ids) >= 3:
                # We have a previous polytomy to break
                parent_time = existing_node_time[parent_node]
                new_nodes = []
                child_ids = existing_edges_child[list(child_edge_ids)]
                remaining_edges = child_edge_ids.copy()
                left = None
                max_time = 0
                for edge_id, child_id in zip(child_edge_ids, child_ids):
                    max_time = max(max_time, existing_node_time[child_id])
                    if left is None:
                        left = existing_edges_left[edge_id]
                    else:
                        assert left == existing_edges_left[edge_id]
                    if existing_edges_right[edge_id] > interval[0]:
                        # make sure we carry on the edge after this polytomy
                        existing_edges_left[edge_id] = pos

                # ADD THE PREVIOUS EDGE SEGMENTS
                dt = min((parent_time - max_time)/(len(child_ids)*2), epsilon)
                # Each broken polytomy of degree N introduces N-2 extra nodes, each at a time
                # slightly less than the parent_time. Create new nodes in order of decreasing time
                new_nodes = [nodes_table.add_row(time=parent_time - (i * dt))
                             for i in range(1, len(child_ids) - 1)]
                # print("new_nodes:", new_nodes, [tables.nodes[n].time for n in new_nodes])
                for new_edge in resolve_polytomy(parent_node, child_ids, new_nodes, rng):
                    edges_table.add_row(left=left, right=pos, child=new_edge[0], parent=new_edge[1])
                    # print("new_edge: left={}, right={}, child={}, parent={}".format(
                    #     left, pos, new_edge[0], new_edge[1]))
            else:
                # Previous node was not a polytomy - just add the edges_out, with modified left
                for edge_id in child_edge_ids:
                    if existing_edges_right[edge_id] == pos:  # this edge has just gone out
                        edges_table.add_row(
                            left=existing_edges_left[edge_id],
                            right=pos,
                            parent=parent_node,
                            child=existing_edges_child[edge_id],
                        )

        for edge in e_out: 
            if edge.parent != tskit.NULL:
                # print("REMOVE", edge.id)
                edges_for_node[edge.parent].remove(edge.id)
        for edge in e_in:
            if edge.parent != tskit.NULL:
                # print("ADD", edge.id)
                edges_for_node[edge.parent].add(edge.id)            

        # Chop if we have created a polytomy: the polytomy itself will be resolved
        # at a future iteration, when any of the edges move in or out of the polytomy
        while nodes_changed:
            node = nodes_changed.pop()
            edge_ids = edges_for_node[node]
            # print("Looking at", node)

            if len(edge_ids) == 0:
                del edges_for_node[node]
            # if this node has changed *to* a polytomy, we need to cut all of the
            # child edges that were previously present by adding the previous segment
            # and left-truncating
            elif len(edge_ids) >= 3:
                # print("Polytomy at", node, " breaking edges")
                for edge_id in edge_ids:
                    if existing_edges_left[edge_id] < interval[0]:
                        tables.edges.add_row(
                            left=existing_edges_left[edge_id],
                            right=interval[0],
                            parent=existing_edges_parent[edge_id],
                            child=existing_edges_child[edge_id],
                        )
                    existing_edges_left[edge_id] = interval[0]
    assert len(edges_for_node) == 0

    tables.edges.squash()
    tables.sort() # Shouldn't need to do this: https://github.com/tskit-dev/tskit/issues/808

    return tables.tree_sequence()

The code can be tested using something like this:

import io
import collections
import tqdm
import time

import msprime
import tsinfer

### Check equiprobable

nodes_polytomy_4 = """\
id      is_sample   population      time
0       1       0               0.00000000000000
1       1       0               0.00000000000000
2       1       0               0.00000000000000
3       1       0               0.00000000000000
4       0       0               1.00000000000000
"""
edges_polytomy_4 = """\
id      left            right           parent  child
0       0.00000000      1.00000000      4       0,1,2,3
"""

poly_ts = tskit.load_text(
    nodes=io.StringIO(nodes_polytomy_4),
    edges=io.StringIO(edges_polytomy_4),
    strict=False,
)

trees = collections.Counter()
for seed in tqdm.trange(1, 100000):
    ts2 = resolve_polytomies(poly_ts, random_seed=seed)
    trees.update([ts2.first().rank()])
print(trees)

### Time on 10000 tip inferred TS and check

ts_old = msprime.simulate(10000, recombination_rate=100, mutation_rate=10, random_seed=123)
sd = tsinfer.SampleData.from_tree_sequence(ts_old, use_times=False)
ts = tsinfer.infer(sd)
print(f"{ts.num_samples} tips ; {ts.num_trees} trees")
start = time.time()
ts2 = resolve_polytomies(ts, random_seed=1)
print("Time (s):", time.time()-start)
for tree in ts2.trees():
    for node in tree.nodes():
        assert tree.num_children(node) < 3
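
A note on reading the Counter from the equiprobability check above: a rooted binary tree with n labelled tips has (2n-3)!! possible topologies, which is also the product of the upper bounds (1, 3, ..., 2n-3) passed to rng.integers in resolve_polytomy. The sketch below computes that number so the Counter can be compared against it; the helper name is hypothetical and is not part of tskit.

```python
def num_rooted_binary_topologies(num_tips):
    """Number of rooted, leaf-labelled binary trees: (2n - 3)!! for n >= 2."""
    count = 1
    # Multiply the odd numbers 3, 5, ..., 2n - 3 (empty product for n <= 2)
    for k in range(3, 2 * num_tips - 2, 2):
        count *= k
    return count


# tqdm.trange(1, 100000) runs 99,999 seeds over a 4-tip polytomy
replicates = 99_999
num_topologies = num_rooted_binary_topologies(4)
print("distinct ranks expected:", num_topologies)
print("expected count per rank if uniform:", replicates / num_topologies)
```

If the resolution is uniform, the Counter should hold that many distinct ranks, each with a count close to the expected value.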

benjeffery commented 4 years ago

Cool, @hyanwong! I assume we will want a method like this in tskit at some point.

One thing I wanted to flag: adding new nodes may invalidate the mutation time requirements if the tree sequence has them. For example, if a mutation is just above one of the polytomies to be broken, it would need to be moved above the oldest node in the new set that replaces the polytomy. Clearly an edge case, but I thought I should mention it!

hyanwong commented 4 years ago

That's a good point @benjeffery, thanks. Actually, I break the polytomy below the focal node, so that the polytomy node does not change time: it's mutations on any nodes below the polytomy which will need checking. I think I can see how to do this: at the moment I base the delta time (dt) value on the smallest difference between the child node times and the parent time (capped at epsilon). I also need to check the mutation times above each child.

What do we do if the time differences are so small that they start to encounter floating point accuracy errors? Or is this so unlikely that we don't care?
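
A rough sketch of how dt could also take mutation times into account, using the same formula as in resolve_polytomies but measuring the gap from the oldest of the child nodes and any mutations above them. The function name and arguments are hypothetical, and treating unknown mutation times as NaN values to be skipped is an assumption.

```python
import math


def choose_dt(parent_time, child_times, mutation_times=(), epsilon=1e-10):
    """
    Pick a time step for the new nodes below a polytomy (hypothetical helper).

    child_times: times of the child nodes of the polytomy
    mutation_times: times of mutations above those children (NaN = unknown)
    """
    num_children = len(child_times)
    if num_children < 3:
        raise ValueError("Not a polytomy: fewer than 3 children")
    # Ignore mutations whose time is unknown (assumed to be stored as NaN)
    known = [t for t in mutation_times if not math.isnan(t)]
    # New nodes must stay older than every child and every known mutation
    floor_time = max(list(child_times) + known)
    return min((parent_time - floor_time) / (num_children * 2), epsilon)
```

With no mutations this reduces to the existing calculation; a mutation close to the parent time simply shrinks dt.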

benjeffery commented 4 years ago

What do we do if the time differences are so small that they start to encounter floating point accuracy errors? Or is this so unlikely that we don't care?

I think we detect the situation and error out. Only other option is some horrible recursive shuffling right?
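
One way to "detect the situation and error out" might look like the sketch below. It assumes the new node times are generated as parent_time - i * dt, as in the code earlier in this thread, and checks that each one is strictly younger than the previous node and strictly older than the oldest child. The function name and error message are hypothetical.

```python
def new_node_times(parent_time, floor_time, dt, num_new):
    """
    Return times for num_new nodes below parent_time, raising a specific
    error if floating point precision cannot separate them.
    floor_time is the time of the oldest child (or mutation) below the parent.
    """
    times = []
    previous = parent_time
    for i in range(1, num_new + 1):
        t = parent_time - i * dt
        # Each new node must sit strictly between the floor and the node above
        if not (floor_time < t < previous):
            raise ValueError(
                f"Cannot place {num_new} new nodes below time {parent_time}: "
                "time differences are too small for floating point precision"
            )
        times.append(t)
        previous = t
    return times
```

For example, with a parent time of 1e10 and dt of 1e-10, subtracting dt leaves the parent time unchanged in double precision, so the check raises rather than silently creating nodes with equal times.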

hyanwong commented 4 years ago

Yes, I guess so. The question is whether we actively look for this and raise a specific error, or just expect it to bomb out (with e.g. time[parent] <= time[child]) when we try to convert to a tree sequence.

Shall I actually work this up into a PR, if you think it's a useful extra method on a tree_sequence? I guess, following other examples, I could make it an in-place method on a TableCollection, and then create an idempotent (is that the right word?) version for a tree sequence.

hyl317 commented 4 years ago

Hi Yan Wong, This looks great! Personally, I would prefer to raise a specific error rather than time[parent]<=time[child] which is less informative from a user's perspective.

Haven't got a chance to try it out myself but I think this feature would be useful for many others. Wonder what other people in the tskit group think about this feature.

btw, how do I get the tskit version in which you can supply the include_terminal param to edge_diffs? It's not in the latest release, right?

Best, Yilei

hyanwong commented 4 years ago

Hi @hyl317 - to try it out, until #787 is merged you'll need to install my branch directly, e.g.

python3 -m pip install git+https://github.com/hyanwong@edge_diff_include_terminal#subdirectory=python

Then try out the code.

I see what you mean about the specific error, it's just a bit more work to do properly!

hyanwong commented 4 years ago

Hi @hyl317 and @awohns - you can now test this using a single install via the PR I just made. The name has changed (for the time being) to randomly_split_polytomies:

python3 -m pip install git+https://github.com/hyanwong@random-split-polytomy#subdirectory=python

Simply call it like this:

ts_binary = ts.randomly_split_polytomies(random_seed=1)

brianzhang01 commented 4 years ago

I'd be happy to help review code for this.

hyanwong commented 4 years ago

Great, thanks @brianzhang01 - that would be really useful. The PR is at https://github.com/tskit-dev/tskit/pull/815 but I don't know if we want to implement our own PRNG, so that we can have an equivalent C function.

hyanwong commented 2 months ago

Just to note that the opposite operation (collapsing edges into polytomies, only retaining those supported by mutations) is at https://github.com/tskit-dev/tskit/discussions/2926