tskit-dev / tutorials

A set of tutorials for msprime and tskit.
Creative Commons Attribution 4.0 International

Tutorial: tree algorithms with numba #63

jeromekelleher opened this issue 3 years ago (status: open)

jeromekelleher commented 3 years ago

Once we have direct numpy access to the tree arrays in Python (https://github.com/tskit-dev/tskit/issues/1299) I think we should be able to do quite performant traversal algorithms in python using numba. We can illustrate this with a couple of examples:

  1. Implement the get_mrca function, using the algorithm given here: https://github.com/tskit-dev/tskit/discussions/1306 (although we might get hit by the creation of the node_time array if we create it for each function call - we could work around this for the moment by using a copy of the node time array that we keep lying around). This is easy because we only go up the tree.
  2. Implement one of the phylogenetics parsimony algorithms in a simple way. This would illustrate how to do traversals down from the root efficiently. Probably the Sankoff score would be a good one.
jeromekelleher commented 3 years ago

I've been working on this: it's absolutely awesome what numba can do, and it works beautifully with the array based tree representation. I'll post some comments here on what I've done.

I'm using a tree sequence with 1 million samples here as the test case:

import msprime
ts1m = msprime.sim_ancestry(1_000_000, ploidy=1, random_seed=42)

MRCAs

First up, compute MRCAs (using the nice new algorithm from https://github.com/tskit-dev/tskit/pull/1313):

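import numba
import numpy as np
import tskit

ts = ts1m

tree = ts.first()
# Copy the tree's parent pointers into a numpy array that numba can read.
parent = np.zeros(ts.num_nodes, dtype=np.int32)
time = ts.tables.nodes.time
for u in range(ts.num_nodes):
    parent[u] = tree.parent(u)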

@numba.jit(nopython=True)
def get_mrca_numba(u, v):
    tu = time[u]
    tv = time[v]
    while u != v:
        if tu < tv:
            u = parent[u]
            if u == tskit.NULL:
                return tskit.NULL
            tu = time[u]
        else:
            v = parent[v]
            if v == tskit.NULL:
                return tskit.NULL
            tv = time[v]
    return u

I'm putting the parent array in the global notebook context here because we don't have efficient access to it yet. (I've also taken care to warm up the JIT before running any timings.)
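Because the jitted function reads `parent` and `time` as globals, numba freezes their values when it compiles the function, so it won't see later changes to those arrays. A sketch of the alternative, passing the arrays in as arguments. The toy five-node tree and the `NULL = -1` constant (standing in for `tskit.NULL`) are assumptions for illustration, and numba is treated as optional so the sketch also runs as plain Python:

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # fall back to plain Python if numba isn't installed
    def njit(func):
        return func

NULL = -1  # same value as tskit.NULL


@njit
def get_mrca(parent, time, u, v):
    """Return the MRCA of u and v, or NULL if they are in different trees.

    parent and time are per-node arrays passed in explicitly, not read from globals.
    """
    tu = time[u]
    tv = time[v]
    while u != v:
        # Always step up from the younger of the two nodes.
        if tu < tv:
            u = parent[u]
            if u == NULL:
                return NULL
            tu = time[u]
        else:
            v = parent[v]
            if v == NULL:
                return NULL
            tv = time[v]
    return u


# Toy tree: samples 0, 1, 2; node 3 joins 0 and 1; node 4 (the root) joins 3 and 2.
parent = np.array([3, 3, 4, 4, NULL], dtype=np.int32)
time = np.array([0.0, 0.0, 0.0, 1.0, 2.0])
print(get_mrca(parent, time, 0, 1), get_mrca(parent, time, 0, 2))
```

Passing arrays as arguments means the same compiled function can be reused for every tree in a sequence.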

Timings:

%%timeit
get_mrca_numba(u, v)
186 ns ± 0.714 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)
%%timeit
tree.get_mrca(u, v)
227 ns ± 2.21 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

Whoa! The numba jit version is a little bit faster than the library version (which is the updated, non-malloc version)! This is a fast function though, so the overhead of the Python C interface is probably what's creating the difference. But still - I didn't see that coming.
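The timings in this thread come from IPython's `%%timeit`. Outside a notebook, something along these lines gives comparable mean ± std. dev. numbers. `time_call` and the `sum_of_squares` kernel are hypothetical stand-ins, and the explicit first call is the JIT warm-up, so compilation time isn't counted:

```python
import statistics
import timeit

import numpy as np

try:
    from numba import njit
except ImportError:  # fall back to plain Python if numba isn't installed
    def njit(func):
        return func


@njit
def sum_of_squares(x):
    # Stand-in kernel; any jitted function would be timed the same way.
    total = 0.0
    for i in range(x.shape[0]):
        total += x[i] * x[i]
    return total


def time_call(func, *args, number=1000, repeat=7):
    """Return (mean, std. dev.) of the per-call time in seconds."""
    func(*args)  # warm-up call: triggers numba compilation so it isn't timed
    runs = timeit.repeat(lambda: func(*args), number=number, repeat=repeat)
    per_call = [run / number for run in runs]
    return statistics.mean(per_call), statistics.stdev(per_call)


x = np.arange(100.0)
mean, std = time_call(sum_of_squares, x, number=100)
print(f"{mean * 1e6:.2f} µs ± {std * 1e6:.2f} µs per call")
```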

jeromekelleher commented 3 years ago

Total branch length

Let's see how we do with a longer running function that takes a bit more computation. We'll compute the total branch length by doing a simple top-down traversal. First we get the left_child and right_sib arrays ready, so we can use them.

ts = ts1m

tree = ts.first()
parent = np.zeros(ts.num_nodes, dtype=np.int32)
left_child = np.zeros(ts.num_nodes, dtype=np.int32)
right_sib = np.zeros(ts.num_nodes, dtype=np.int32)
time = ts.tables.nodes.time
for u in range(ts.num_nodes):
    parent[u] = tree.parent(u)
    left_child[u] = tree.left_child(u)
    right_sib[u] = tree.right_sib(u)

@numba.njit()
def total_branch_length_numba(root):
    tbl = 0
    stack = [root]
    while len(stack) > 0:
        u = stack.pop()
        v = left_child[u]
        while v != tskit.NULL:
            tbl += time[u] - time[v]
            stack.append(v)
            v = right_sib[v]
    return tbl

@numba.njit()
def total_branch_length_numba_recursive(u):
    tbl = 0
    v = left_child[u]
    while v != tskit.NULL:
        tbl += (time[u] - time[v]) + total_branch_length_numba_recursive(v)
        v = right_sib[v]
    return tbl
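The setup loop above makes several Python-level calls per node. If only a `parent` array is to hand, the child and sibling arrays can be derived in one jitted pass. A sketch with a hypothetical `child_sib_arrays` helper and the same toy tree as before; note that it orders siblings by node ID, which need not match the order tskit uses:

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # fall back to plain Python if numba isn't installed
    def njit(func):
        return func

NULL = -1  # same value as tskit.NULL


@njit
def child_sib_arrays(parent):
    """Build left_child, right_child, left_sib and right_sib from a parent array."""
    n = parent.shape[0]
    left_child = np.full(n, NULL, dtype=np.int32)
    right_child = np.full(n, NULL, dtype=np.int32)
    left_sib = np.full(n, NULL, dtype=np.int32)
    right_sib = np.full(n, NULL, dtype=np.int32)
    for u in range(n):
        p = parent[u]
        if p == NULL:
            continue  # u is a root (or not in the tree)
        last = right_child[p]
        if last == NULL:
            left_child[p] = u  # first child seen for p
        else:
            right_sib[last] = u  # append u after p's current last child
            left_sib[u] = last
        right_child[p] = u
    return left_child, right_child, left_sib, right_sib


# Toy tree: samples 0, 1, 2; node 3 joins 0 and 1; node 4 (the root) joins 2 and 3.
parent = np.array([3, 3, 4, 4, NULL], dtype=np.int32)
left_child, right_child, left_sib, right_sib = child_sib_arrays(parent)
print(left_child, right_sib)
```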

Timings:

%%timeit
tree.get_total_branch_length(tree.root)
50.8 ms ± 987 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%%timeit
total_branch_length_numba(tree.root)
158 ms ± 1.85 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
%%timeit
total_branch_length_numba_recursive(tree.root)
122 ms ± 4.77 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

So, the C version in tree.get_total_branch_length is still faster, but not by much: about 2.4 times faster than the recursive numba version (50.8 ms vs 122 ms). Surprisingly, the recursive numba version is also faster than the one with a manual stack (122 ms vs 158 ms).

For the record, here's what the C code looks like:

int
tsk_tree_get_total_branch_length(
    const tsk_tree_t *self, tsk_id_t root, double *total_branch_length)
{
    int ret = 0;
    tsk_id_t u, v;
    int stack_top;
    double tbl = 0;
    const double *restrict time = self->tree_sequence->tables->nodes.time;
    const tsk_id_t *restrict right_child = self->right_child;
    const tsk_id_t *restrict left_sib = self->left_sib;
    tsk_id_t *stack = malloc(self->num_nodes * sizeof(*stack));

    if (stack == NULL) {
        ret = TSK_ERR_NO_MEMORY;
        goto out;
    }
    ret = tsk_tree_check_node(self, root);
    if (ret != 0) {
        goto out;
    }
    stack_top = 0;
    stack[stack_top] = root;
    while (stack_top >= 0) {
        u = stack[stack_top];
        stack_top--;
        for (v = right_child[u]; v != TSK_NULL; v = left_sib[v]) {
            tbl += time[u] - time[v];
            stack_top++;
            stack[stack_top] = v;
        }
    }
    *total_branch_length = tbl;
out:
    tsk_safe_free(stack);

    return ret;
}

This isn't in the library currently - I might add the function, if we think it's worthwhile.
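The C code's preallocated stack ports directly to numba: a fixed-size `np.empty` array replaces the Python list used in `total_branch_length_numba`. This is a sketch (the function name and toy tree are illustrative) and it hasn't been timed here, so whether it narrows the gap to the C version is untested:

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # fall back to plain Python if numba isn't installed
    def njit(func):
        return func

NULL = -1  # same value as tskit.NULL


@njit
def total_branch_length_array_stack(root, time, right_child, left_sib):
    # Each node is pushed at most once, so num_nodes slots are always enough.
    stack = np.empty(time.shape[0], dtype=np.int32)
    stack_top = 0
    stack[0] = root
    tbl = 0.0
    while stack_top >= 0:
        u = stack[stack_top]
        stack_top -= 1
        # Visit children right to left, as the C code does.
        v = right_child[u]
        while v != NULL:
            tbl += time[u] - time[v]
            stack_top += 1
            stack[stack_top] = v
            v = left_sib[v]
    return tbl


# Toy tree: samples 0, 1, 2; node 3 joins 0 and 1; node 4 (the root) joins 3 and 2.
time = np.array([0.0, 0.0, 0.0, 1.0, 2.0])
right_child = np.array([NULL, NULL, NULL, 1, 2], dtype=np.int32)
left_sib = np.array([NULL, 0, 3, NULL, NULL], dtype=np.int32)
print(total_branch_length_array_stack(4, time, right_child, left_sib))
```

Branch lengths in the toy tree are 1 + 1 (samples 0 and 1 to node 3), 1 (node 3 to the root) and 2 (sample 2 to the root), so the total is 5.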

jeromekelleher commented 3 years ago

Postorder sum

Since preorder via recursion was fast, let's try postorder (which is much easier to do via recursion). Propagating a sum up the tree is a fundamental operation.

@numba.njit()
def postorder_sum_numba(u, x):   
    v = left_child[u]
    while v != tskit.NULL:
        postorder_sum_numba(v, x)
        x[u] += x[v]
        v = right_sib[v]    

def count_nodes():
    a = np.zeros(ts.num_nodes)
    a[ts.samples()] = 1
    postorder_sum_numba(tree.root, a)
    return a

Here, we just count the number of samples below each node (each sample starts with a count of 1) and return this array, but it could be anything. Note in particular that we're using numpy indexing here, so this should also work for n-dimensional arrays.
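To try the n-dimensional case without recursion, here is a sketch that sums a 2-D array up the tree. It relies on the fact that reversing a preorder visits every node after all of its descendants, which is enough for summation even though it isn't the same sequence as the recursive postorder. The `preorder` and `postorder_sum_2d` names and the toy tree are illustrative:

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # fall back to plain Python if numba isn't installed
    def njit(func):
        return func

NULL = -1  # same value as tskit.NULL


@njit
def preorder(root, left_child, right_sib):
    """Return the nodes below root (inclusive) in preorder, using an array stack."""
    n = left_child.shape[0]
    order = np.empty(n, dtype=np.int32)
    stack = np.empty(n, dtype=np.int32)
    stack_top = 0
    stack[0] = root
    k = 0
    while stack_top >= 0:
        u = stack[stack_top]
        stack_top -= 1
        order[k] = u
        k += 1
        v = left_child[u]
        while v != NULL:
            stack_top += 1
            stack[stack_top] = v
            v = right_sib[v]
    return order[:k]


@njit
def postorder_sum_2d(root, parent, left_child, right_sib, x):
    """Add each row of x into its parent's row, bottom-up, in place."""
    order = preorder(root, left_child, right_sib)
    # Walk the preorder backwards, skipping order[0] (the root has no parent).
    for i in range(order.shape[0] - 1, 0, -1):
        u = order[i]
        p = parent[u]
        for j in range(x.shape[1]):
            x[p, j] += x[u, j]


# Toy tree: samples 0, 1, 2; node 3 joins 0 and 1; node 4 (the root) joins 3 and 2.
parent = np.array([3, 3, 4, 4, NULL], dtype=np.int32)
left_child = np.array([NULL, NULL, NULL, 0, 3], dtype=np.int32)
right_sib = np.array([1, NULL, NULL, 2, NULL], dtype=np.int32)
x = np.zeros((5, 2))
x[[0, 1, 2], 0] = 1.0  # column 0: sample counts
x[[0, 1, 2], 1] = [10.0, 20.0, 30.0]  # column 1: an arbitrary per-sample value
postorder_sum_2d(4, parent, left_child, right_sib, x)
print(x)
```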

Timings:

%%timeit
count_nodes()
107 ms ± 1.4 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

Holy jeebus, we did a recursive postorder traversal summing an array in less time than it took to do a top down sum of a simple value! Again, this is just over twice the time it took for the optimised C code to sum the total branch length.

jeromekelleher commented 3 years ago

OK, let's try something a bit more complicated. We can compute the Sankoff parsimony score of an assignment of genotypes to a particular tree.
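Written out, the recurrence that `_sankoff_score_numba` computes is as follows, with $C$ the `cost_matrix` and $S_u(j)$ the entry `S[u, j]`. The code indexes the cost as `cost_matrix[k, j]`. The usual convention charges $C_{jk}$ for a parent in state $j$ and a child in state $k$, and the two agree for a symmetric matrix like the one used below. The parsimony score itself is the minimum over the root's row, which `sankoff_score_numba` leaves to the caller since it returns the whole `S` array:

```latex
S_u(j) =
\begin{cases}
  0 & u \text{ a sample carrying allele } j, \\
  \infty & u \text{ a sample carrying a different allele}, \\
  \displaystyle\sum_{c \in \mathrm{children}(u)} \min_{k} \bigl( C_{kj} + S_c(k) \bigr) & \text{otherwise},
\end{cases}
\qquad
\text{score} = \min_j S_{\mathrm{root}}(j).
```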

@numba.njit()
def _sankoff_score_numba(parent, cost_matrix, S):
    num_alleles = cost_matrix.shape[0]
    child = left_child[parent]
    while child != tskit.NULL:
        _sankoff_score_numba(child, cost_matrix, S)
        for j in range(num_alleles):
            min_cost = np.inf
            for k in range(num_alleles):
                min_cost = min(min_cost, cost_matrix[k, j] + S[child, k])
            S[parent, j] += min_cost
        child = right_sib[child]

def sankoff_score_numba(genotypes, cost_matrix):
    num_alleles = cost_matrix.shape[0]
    S = np.zeros((tree.num_nodes, num_alleles))
    samples = tree.tree_sequence.samples()
    S[samples, :] = np.inf
    for allele in range(num_alleles):
        samples_with_allele = samples[genotypes == allele]
        S[samples_with_allele, allele] = 0
    _sankoff_score_numba(tree.root, cost_matrix, S)
    return S   

# Simple 2-allele cost matrix.   
cost_matrix = np.array([[0, 0.5], [0.5, 0]])
genotypes = np.zeros(ts.num_samples, dtype=np.int8)
# Assign something to the genotypes so we're not summing 0s
genotypes[::2] = 1

Timings:

%%timeit
sankoff_score_numba(genotypes, cost_matrix)
197 ms ± 7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Woo, 0.2 seconds! (This is a tree with 1 million samples, remember.) It takes about 30 seconds to run the same calculation in Biopython - but the algorithm is implemented in Python, so it's not a fair comparison.
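The same calculation can be done without recursion or globals by reusing the reversed-preorder idea: each child's row of `S` is complete before it is folded into its parent. This is a sketch under stated assumptions (toy tree, hypothetical `sankoff_fill`/`sankoff_score` names, genotypes given in the same order as `samples`), not a drop-in replacement for the code above:

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # fall back to plain Python if numba isn't installed
    def njit(func):
        return func

NULL = -1  # same value as tskit.NULL


@njit
def preorder(root, left_child, right_sib):
    """Return the nodes below root (inclusive) in preorder, using an array stack."""
    n = left_child.shape[0]
    order = np.empty(n, dtype=np.int32)
    stack = np.empty(n, dtype=np.int32)
    stack_top = 0
    stack[0] = root
    k = 0
    while stack_top >= 0:
        u = stack[stack_top]
        stack_top -= 1
        order[k] = u
        k += 1
        v = left_child[u]
        while v != NULL:
            stack_top += 1
            stack[stack_top] = v
            v = right_sib[v]
    return order[:k]


@njit
def sankoff_fill(root, parent, left_child, right_sib, cost_matrix, S):
    """Fill in S bottom-up; sample rows of S must already be initialised."""
    num_alleles = cost_matrix.shape[0]
    order = preorder(root, left_child, right_sib)
    for i in range(order.shape[0] - 1, 0, -1):  # children before parents
        child = order[i]
        p = parent[child]
        for j in range(num_alleles):
            min_cost = np.inf
            for k in range(num_alleles):
                min_cost = min(min_cost, cost_matrix[k, j] + S[child, k])
            S[p, j] += min_cost


def sankoff_score(root, parent, left_child, right_sib, samples, genotypes, cost_matrix):
    """Return the Sankoff parsimony score and the full S array."""
    num_alleles = cost_matrix.shape[0]
    S = np.zeros((parent.shape[0], num_alleles))
    S[samples, :] = np.inf
    S[samples, genotypes] = 0.0  # sample samples[i] carries allele genotypes[i]
    sankoff_fill(root, parent, left_child, right_sib, cost_matrix, S)
    return S[root].min(), S


# Toy tree: samples 0, 1, 2; node 3 joins 0 and 1; node 4 (the root) joins 3 and 2.
parent = np.array([3, 3, 4, 4, NULL], dtype=np.int32)
left_child = np.array([NULL, NULL, NULL, 0, 3], dtype=np.int32)
right_sib = np.array([1, NULL, NULL, 2, NULL], dtype=np.int32)
samples = np.array([0, 1, 2])
genotypes = np.array([0, 1, 1], dtype=np.int8)
cost_matrix = np.array([[0, 0.5], [0.5, 0]])
score, S = sankoff_score(4, parent, left_child, right_sib, samples, genotypes, cost_matrix)
print(score)
```

With these genotypes one change (cost 0.5) on the branch above sample 0 or sample 1 explains the data, which is the score the sketch reports.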

jeromekelleher commented 3 years ago

Parsimonious assignments

ts = ts1m

tree = ts.first()
parent = np.zeros(ts.num_nodes, dtype=np.int32)
right_child = np.zeros(ts.num_nodes, dtype=np.int32)
left_sib = np.zeros(ts.num_nodes, dtype=np.int32)
time = ts.tables.nodes.time
flags = ts.tables.nodes.flags
for u in range(ts.num_nodes):
    parent[u] = tree.parent(u)
    right_child[u] = tree.right_child(u)
    left_sib[u] = tree.left_sib(u)

@numba.njit()
def _hartigan_postorder(parent, optimal_set):
    num_alleles = optimal_set.shape[1]
    allele_count = np.zeros(num_alleles, dtype=np.int32)
    child = right_child[parent]
    while child != tskit.NULL:
        _hartigan_postorder(child, optimal_set)
        allele_count += optimal_set[child]
        child = left_sib[child]
    if (flags[parent] & tskit.NODE_IS_SAMPLE) == 0:  # Non-sample node: check only the sample bit.
        max_allele_count = np.max(allele_count)
        for j in range(num_alleles):
            if allele_count[j] == max_allele_count:
                optimal_set[parent, j] = 1

@numba.njit()
def _hartigan_preorder(node, state, optimal_set):
    mutations = []
    if optimal_set[node, state] == 0:
        state = np.argmax(optimal_set[node])
        mutations.append((node, state))
    v = right_child[node]
    while v != tskit.NULL:
        v_muts = _hartigan_preorder(v, state, optimal_set)
        mutations.extend(v_muts)
        v = left_sib[v]
    return mutations

def hartigan_map_mutations_numba(tree, genotypes, alleles):
    # Simple version assuming non missing data and one root
    num_alleles = np.max(genotypes) + 1
    num_nodes = tree.tree_sequence.num_nodes

    optimal_set = np.zeros((num_nodes + 1, num_alleles), dtype=np.int8)
    for allele, u in zip(genotypes, tree.tree_sequence.samples()):
        optimal_set[u, allele] = 1

    _hartigan_postorder(tree.root, optimal_set)
    ancestral_state = np.argmax(optimal_set[tree.root])
    ll_mutations = _hartigan_preorder(tree.root, ancestral_state, optimal_set)
    mutations = []
    for node, derived_state in ll_mutations:
        mutations.append(
            tskit.Mutation(
                node=node,
                derived_state=alleles[derived_state],
                # Note we're taking a short-cut here and not bothering with mutation parent. 
                # Could be done easily enough.
            )
        )
    return alleles[ancestral_state], mutations

genotypes = np.zeros(ts.num_samples, dtype=np.int8)
# This is an easy one, so we won't be allocating a lot of memory.
genotypes[1] = 1
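The comment in `hartigan_map_mutations_numba` says mutation parents could be done easily enough. One way, sketched with a hypothetical `mutation_parents` helper: since `_hartigan_preorder` places at most one mutation on each node, the parent of a mutation is the mutation on the nearest ancestor node that carries one, found by walking up the `parent` array. The resulting indices (positions in the mutation list) could then go into each mutation's `parent` field. The toy tree is illustrative:

```python
import numpy as np

NULL = -1  # same value as tskit.NULL


def mutation_parents(mutation_nodes, parent):
    """Return, for each mutation, the index of the nearest mutation above it, or NULL.

    Assumes at most one mutation per node, as produced by the Hartigan preorder.
    """
    mutation_at = {int(node): j for j, node in enumerate(mutation_nodes)}
    parents = []
    for node in mutation_nodes:
        u = int(parent[node])
        # Walk towards the root until we hit a node that carries a mutation.
        while u != NULL and u not in mutation_at:
            u = int(parent[u])
        parents.append(NULL if u == NULL else mutation_at[u])
    return parents


# Toy tree: samples 0, 1, 2; node 3 joins 0 and 1; node 4 (the root) joins 3 and 2.
parent = np.array([3, 3, 4, 4, NULL], dtype=np.int32)
print(mutation_parents([3, 0], parent))
```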

Timings:

The tskit version takes about 0.3 seconds on a million-leaf tree:

%%timeit 
tree.map_mutations(genotypes, ["0", "1"])
292 ms ± 6.95 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

The numba version takes about a second, which is pretty great!

%%timeit
hartigan_map_mutations_numba(tree, genotypes, ["0", "1"])
1.06 s ± 26.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

That was a very easy parsimony job, though: it only needed one mutation. Let's try something harder, so we're doing memory allocations:

genotypes = np.zeros(ts.num_samples, dtype=np.int8)
# Assign something to the genotypes so we're not summing 0s
genotypes[::2] = 1
genotypes

array([1, 0, 1, ..., 0, 1, 0], dtype=int8)
%%timeit 
tree.map_mutations(genotypes, ["0", "1"])
721 ms ± 7.98 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%%timeit
hartigan_map_mutations_numba(tree, genotypes, ["0", "1"])
1.56 s ± 23.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

So, still within a factor of 2-3 of the highly optimised tskit C code (1.56 s vs 721 ms here)! Wow, numba kicks ass!

jeromekelleher commented 3 years ago

An update here: https://github.com/tskit-dev/tskit/pull/1320 adds support for the tree arrays, which works very well. I'll need to do a bit of experimentation to find the best way of passing around these array references (i.e., to make sure they are considered "const" by numba), but it's all solid.
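On the "const" question: in later tskit releases these arrays are exposed as attributes such as `Tree.parent_array` and `Tree.left_child_array`. One option, not settled in this thread, is to mark an array read-only with numpy's `setflags(write=False)` before passing it in. numba then types the argument as a read-only array and should reject jitted code that tries to write to it. A sketch with an illustrative `depth` function and the toy parent array used earlier:

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # fall back to plain Python if numba isn't installed
    def njit(func):
        return func

NULL = -1  # same value as tskit.NULL


@njit
def depth(parent, u):
    """Number of edges from u up to its root; only reads from parent."""
    d = 0
    u = parent[u]
    while u != NULL:
        d += 1
        u = parent[u]
    return d


# Toy tree: samples 0, 1, 2; node 3 joins 0 and 1; node 4 (the root) joins 3 and 2.
parent = np.array([3, 3, 4, 4, NULL], dtype=np.int32)
parent.setflags(write=False)  # mark read-only before handing it to jitted code
print(depth(parent, 0))
```

Reads work as usual; a write to `parent`, from Python or (as far as I know) from inside a jitted function, is refused.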