jeromekelleher / sc2ts

Infer a succinct tree sequence from SARS-COV-2 variation data
MIT License
4 stars 3 forks source link

Replacing UPGMA #311

Closed hyanwong closed 3 days ago

hyanwong commented 4 days ago

Here's some code for @szhan to look at. I reworked the re-rooting algorithm to use a stack, rather than recursion, but I'm not great at stack-based approaches, so I could have something wrong. We can probably take out some of the checks e.g. that the variation is identical.

Should we test this for speed on some of the larger inserted trees?

import tszip
import sc2ts
import logging
from biotite.sequence.phylo import neighbor_joining
import tskit
import numpy as np
import scipy

# Load the compressed tree sequence and wrap it in a TreeInfo object; `ti` is
# what the per-group loop uses to fetch each sample group's information.
ts = tszip.load("../data/long_arg_v5_clustloc-rw_10-mgs_20-2021-03-10.ts.tsz")
ti = sc2ts.TreeInfo(ts)

# Identifiers of the sample groups to process; each one is passed to
# ti.get_sample_group_info().
groups = [ 
    "26eb72952849579d42142d845bbccd03", 
    "96538232ff6bcaca2aef10bc3f31a80e", 
# ... add more groups here
] 

def attached_nodes(biotite_node):
    """Return every node directly connected to *biotite_node*.

    The result is the node's children (an empty tuple when ``children`` is
    falsy, e.g. ``None``) followed by its parent, if it has one. This treats
    the tree as an undirected graph, so it can be walked from any node.
    """
    neighbours = biotite_node.children
    if not neighbours:
        neighbours = tuple()
    up = biotite_node.parent
    return neighbours if up is None else neighbours + (up, )

# For each sample group: build a neighbour-joining tree from the pairwise
# Hamming distances between the samples (plus the original root), re-root it
# on the original root node, write it out as a new tree sequence and re-map
# the mutations onto the new topology with parsimony.

# Time gap placed between each newly created internal node and its parent.
epsilon = 1e-6
for g in groups:
    sgg = ti.get_sample_group_info(g)
    ts = sgg.ts
    if ts.num_trees != 1:
        roots = {tree.root for tree in ts.trees()}
        # NOTE(review): the message suggests any second distinct root should
        # be rejected (i.e. `> 1`); confirm whether `> 2` is intentional.
        if len(roots) > 2:
            raise ValueError("More than one attachment point")
        logging.warning(f"More than one tree ({ts.num_trees}) in {g}")
    # can only use simplify later to match the samples if the originals are at the start
    assert set(ts.samples()) == set(np.arange(ts.num_samples))
    # The root of the first tree is appended after the samples, so it becomes
    # the last leaf of the neighbour-joining tree and can be used to re-root it.
    sample_indexes = np.concatenate((ts.samples(), [ts.first().root]))
    G = ts.genotype_matrix(samples=sample_indexes, isolated_as_missing=False)
    Y = scipy.spatial.distance.pdist(G.T, "hamming")
    tree = neighbor_joining(scipy.spatial.distance.squareform(Y))

    tables = ts.dump_tables()
    nodes_time = tables.nodes.time  # only used to get time of sample / leaf nodes
    tables.edges.clear()
    tables.nodes.clear()
    tables.sites.clear()
    tables.mutations.clear()

    root = tree.leaves[-1]  # root is the last one
    time = nodes_time[sample_indexes[root.index]]
    parent = tables.nodes.add_row(time=time)
    L = tables.sequence_length
    # Each stack entry is (biotite node to expand, the neighbour we arrived
    # from, id of the new-table node its neighbours attach to, that node's time).
    stack = [(root, None, parent, time)]
    node_map = {}  # original sample id -> node id in the rebuilt tables
    while stack:
        node, prev_node, parent, time = stack.pop()
        for new_node in attached_nodes(node):
            assert new_node is not None
            if new_node is prev_node:
                continue  # do not walk back along the edge we came in on
            is_leaf = new_node.is_leaf()
            if is_leaf:
                old_id = sample_indexes[new_node.index]
                ts_node = ts.node(old_id)
                if time <= ts_node.time:
                    # Report the leaf's own time: `new_time` is not yet set
                    # for this node at this point.
                    raise ValueError(
                        f"Child leaf {old_id} has time {ts_node.time} but parent {parent} is at time {time}")
                u = tables.nodes.append(ts_node)
                new_time = ts_node.time
            else:
                new_time = time - epsilon
                u = tables.nodes.add_row(time=new_time)
            assert new_time < tables.nodes[parent].time
            tables.edges.add_row(parent=parent, child=u, left=0, right=L)
            if is_leaf:
                node_map[old_id] = u
            else:
                # Internal node: expand its remaining neighbours later.
                stack.append((new_node, node, u, new_time))
    # Make sure that the nodes in the new TS map to those in the old
    tables.sort()
    tables.simplify([node_map[u] for u in np.arange(ts.num_samples)])
    new_ts = tables.tree_sequence()
    assert new_ts.num_samples == ts.num_samples
    assert new_ts.num_trees == 1
    # Now add on mutations
    new_tree = new_ts.first()  # single tree, reused for every site
    for v in ts.variants():
        anc, muts = new_tree.map_mutations(
            v.genotypes, v.alleles, ancestral_state=v.site.ancestral_state
        )
        site = tables.sites.add_row(v.site.position, anc)
        # Parents returned by map_mutations index into `muts`, so shift them
        # by the number of rows already in the mutation table.
        parent_offset = tables.mutations.num_rows
        for mut in muts:
            mut_parent = mut.parent
            if mut_parent != tskit.NULL:
                mut_parent += parent_offset
            tables.mutations.add_row(
                site=site,
                node=mut.node,
                parent=mut_parent,
                derived_state=mut.derived_state,
            )
    new_ts = tables.tree_sequence()
    # Sanity check: the rebuilt tree sequence must encode identical variation.
    for v1, v2 in zip(ts.variants(), new_ts.variants()):
        assert np.all(np.array(v1.alleles)[v1.genotypes] == np.array(v2.alleles)[v2.genotypes])
    print("Previous # muts", ts.num_mutations, "- new:", new_ts.num_mutations)
jeromekelleher commented 4 days ago

Nothing wrong with recursion here if it makes things simpler - this bit isn't going to be perf sensitive so clarity wins.

jeromekelleher commented 4 days ago

Feel free to plug this in to infer_binary in a PR and I'll try it out?

hyanwong commented 4 days ago

Nothing wrong with recursion here if it makes things simpler - this bit isn't going to be perf sensitive so clarity wins.

If we have just one huge tree to add, we could hit the recursion depth limit in Python. Anyway, the stack approach works now, I think (there's some decent checking in there)

jeromekelleher commented 3 days ago

We've switched over to this - thanks for a great idea @hyanwong!