jeromekelleher opened 3 years ago
I've been working on this, and it's absolutely awesome what numba can do: it works beautifully with the array-based tree representation. I'll post some comments here on what I've done.
I'm using a tree sequence with 1 million samples here as the test case:
```python
import msprime
import numba
import numpy as np
import tskit

ts1m = msprime.sim_ancestry(1_000_000, ploidy=1, random_seed=42)
```
First up, computing MRCAs, using the nice new algorithm from https://github.com/tskit-dev/tskit/pull/1313:
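```python
ts = ts1m
tree = ts.first()
# Copy the tree's parent pointers into a numpy array (one Python call per node).
parent = np.zeros(ts.num_nodes, dtype=np.int32)
time = ts.tables.nodes.time
for u in range(ts.num_nodes):
    parent[u] = tree.parent(u)

@numba.jit(nopython=True)
def get_mrca_numba(u, v):
    # parent and time are read as globals, which numba treats as compile-time constants.
    tu = time[u]
    tv = time[v]
    # Move the younger of the two nodes up until the lineages meet.
    while u != v:
        if tu < tv:
            u = parent[u]
            if u == tskit.NULL:
                return tskit.NULL
            tu = time[u]
        else:
            v = parent[v]
            if v == tskit.NULL:
                return tskit.NULL
            tv = time[v]
    return u
```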
I'm putting the `parent` array in the notebook context here because we don't have efficient access to it yet. (I've also taken care to warm up the JIT before running any timings.)
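As a minimal sketch of the alternative, assuming `parent` and `time` are ordinary numpy arrays, the same climbing algorithm can take the arrays as arguments instead of reading them from the notebook's globals. The toy five-node tree, the function name `mrca`, and the local `NULL = -1` constant (mirroring `tskit.NULL`) are illustrative only, and the `try`/`except` just lets the sketch run as plain Python when numba isn't installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # fall back to plain Python if numba is unavailable
    def njit(func):
        return func

NULL = -1  # same value as tskit.NULL


@njit
def mrca(parent, time, u, v):
    """Return the MRCA of nodes u and v, or NULL if they share no ancestor."""
    tu = time[u]
    tv = time[v]
    # Always move the younger node up until the two lineages meet.
    while u != v:
        if tu < tv:
            u = parent[u]
            if u == NULL:
                return NULL
            tu = time[u]
        else:
            v = parent[v]
            if v == NULL:
                return NULL
            tv = time[v]
    return u


# Toy tree: samples 0, 1, 2; node 3 joins 0 and 1; root 4 joins 3 and 2.
parent = np.array([3, 3, 4, 4, NULL], dtype=np.int32)
time = np.array([0.0, 0.0, 0.0, 1.0, 2.0])
print(mrca(parent, time, 0, 1))  # 3
print(mrca(parent, time, 0, 2))  # 4
```

Passing the arrays as arguments also means the compiled function always sees their current contents, whereas numba freezes global arrays into the function at compile time.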
Timings:

```
%%timeit
get_mrca_numba(u, v)
186 ns ± 0.714 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)
```

```
%%timeit
tree.get_mrca(u, v)
227 ns ± 2.21 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
```
Whoa! The numba JIT version is a little bit faster than the library version (which is the updated, non-malloc version)! This is a fast function, though, so the overhead of the Python-C interface is probably what's creating the difference. But still, I didn't see that coming.
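If per-call overhead is what separates the two versions for a function this cheap, one hedged way to amortise it is to answer many queries inside a single compiled call. The `mrca_many` helper below is hypothetical (not a tskit API); it runs the same climbing logic on a small toy tree, with arrays passed as arguments and a plain-Python fallback when numba is missing.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # fall back to plain Python if numba is unavailable
    def njit(func):
        return func

NULL = -1  # same value as tskit.NULL


@njit
def mrca(parent, time, u, v):
    # Move the younger node up until the two lineages meet.
    tu = time[u]
    tv = time[v]
    while u != v:
        if tu < tv:
            u = parent[u]
            if u == NULL:
                return NULL
            tu = time[u]
        else:
            v = parent[v]
            if v == NULL:
                return NULL
            tv = time[v]
    return u


@njit
def mrca_many(parent, time, us, vs):
    """Answer len(us) MRCA queries in one call, paying the call overhead once."""
    out = np.empty(us.shape[0], np.int32)
    for j in range(us.shape[0]):
        out[j] = mrca(parent, time, us[j], vs[j])
    return out


# Toy tree: samples 0, 1, 2; node 3 joins 0 and 1; root 4 joins 3 and 2.
parent = np.array([3, 3, 4, 4, NULL], dtype=np.int32)
time = np.array([0.0, 0.0, 0.0, 1.0, 2.0])
us = np.array([0, 0, 1])
vs = np.array([1, 2, 3])
print(mrca_many(parent, time, us, vs).tolist())  # [3, 4, 3]
```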
Let's see how we do with a longer-running function that takes a bit more computation. We'll compute the total branch length by doing a simple top-down traversal. First we get the `left_child` and `right_sib` arrays ready, so we can use them.
```python
ts = ts1m
tree = ts.first()
parent = np.zeros(ts.num_nodes, dtype=np.int32)
left_child = np.zeros(ts.num_nodes, dtype=np.int32)
right_sib = np.zeros(ts.num_nodes, dtype=np.int32)
time = ts.tables.nodes.time
for u in range(ts.num_nodes):
    parent[u] = tree.parent(u)
    left_child[u] = tree.left_child(u)
    right_sib[u] = tree.right_sib(u)

@numba.njit()
def total_branch_length_numba(root):
    # Preorder traversal, using a list as an explicit stack.
    tbl = 0
    stack = [root]
    while len(stack) > 0:
        u = stack.pop()
        v = left_child[u]
        while v != tskit.NULL:
            tbl += time[u] - time[v]
            stack.append(v)
            v = right_sib[v]
    return tbl

@numba.njit()
def total_branch_length_numba_recursive(u):
    # Add the branch to each child plus everything in that child's subtree.
    tbl = 0
    v = left_child[u]
    while v != tskit.NULL:
        tbl += (time[u] - time[v]) + total_branch_length_numba_recursive(v)
        v = right_sib[v]
    return tbl
```
Timings:

```
%%timeit
tree.get_total_branch_length(tree.root)
50.8 ms ± 987 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
```

```
%%timeit
total_branch_length_numba(tree.root)
158 ms ± 1.85 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
```

```
%%timeit
total_branch_length_numba_recursive(tree.root)
122 ms ± 4.77 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
```
So, the C version in `tree.get_total_branch_length` is still faster, but only by a factor of about 2.4 over the recursive numba version (50.8 ms vs 122 ms)! Surprisingly, the recursive numba version is also faster than the one with a manual stack.
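Filling `left_child` and `right_sib` with a Python method call per node is itself slow on a million-node tree. As a sketch, assuming only the `parent` array is available, both arrays can be derived in one compiled pass. Siblings are linked in increasing node ID order here, which may not match tskit's own sibling order; that doesn't matter for order-independent sums like the total branch length. `children_arrays` is a hypothetical helper name.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # fall back to plain Python if numba is unavailable
    def njit(func):
        return func

NULL = -1  # same value as tskit.NULL


@njit
def children_arrays(parent):
    """Derive left_child and right_sib linked lists from a parent array."""
    n = parent.shape[0]
    left_child = np.full(n, NULL, np.int32)
    right_sib = np.full(n, NULL, np.int32)
    # Visit nodes in decreasing ID order and push each one onto the front of
    # its parent's child list, so the smallest child ID ends up as left_child.
    for u in range(n - 1, -1, -1):
        p = parent[u]
        if p != NULL:
            right_sib[u] = left_child[p]
            left_child[p] = u
    return left_child, right_sib


# Toy tree: samples 0, 1, 2; node 3 joins 0 and 1; root 4 joins 3 and 2.
parent = np.array([3, 3, 4, 4, NULL], dtype=np.int32)
left_child, right_sib = children_arrays(parent)
print(left_child.tolist())  # [-1, -1, -1, 0, 2]
print(right_sib.tolist())   # [1, -1, 3, -1, -1]
```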
For the record, here's what the C code looks like:
```c
int
tsk_tree_get_total_branch_length(
    const tsk_tree_t *self, tsk_id_t root, double *total_branch_length)
{
    int ret = 0;
    tsk_id_t u, v;
    int stack_top;
    double tbl = 0;
    const double *restrict time = self->tree_sequence->tables->nodes.time;
    const tsk_id_t *restrict right_child = self->right_child;
    const tsk_id_t *restrict left_sib = self->left_sib;
    tsk_id_t *stack = malloc(self->num_nodes * sizeof(*stack));

    if (stack == NULL) {
        ret = TSK_ERR_NO_MEMORY;
        goto out;
    }
    ret = tsk_tree_check_node(self, root);
    if (ret != 0) {
        goto out;
    }
    stack_top = 0;
    stack[stack_top] = root;
    while (stack_top >= 0) {
        u = stack[stack_top];
        stack_top--;
        for (v = right_child[u]; v != TSK_NULL; v = left_sib[v]) {
            tbl += time[u] - time[v];
            stack_top++;
            stack[stack_top] = v;
        }
    }
    *total_branch_length = tbl;
out:
    tsk_safe_free(stack);
    return ret;
}
```
This isn't in the library currently; I might add the function, if we think it's worthwhile.
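For a closer comparison with the C routine, here is a sketch that mirrors its structure: a preallocated array stack of `num_nodes` slots instead of a Python list, with the tree arrays passed as arguments. Whether the list-based stack is what slows `total_branch_length_numba` down is only a guess; the timings above don't isolate it. `total_branch_length_array_stack` is a hypothetical name, and it walks `left_child`/`right_sib` rather than the C code's `right_child`/`left_sib`, which only changes the visiting order.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # fall back to plain Python if numba is unavailable
    def njit(func):
        return func

NULL = -1  # same value as tskit.NULL


@njit
def total_branch_length_array_stack(left_child, right_sib, time, root):
    """Preorder sum of branch lengths below root, using an array as the stack."""
    # Each node is pushed at most once, so num_nodes slots are always enough.
    stack = np.empty(time.shape[0], np.int32)
    stack_top = 0
    stack[stack_top] = root
    tbl = 0.0
    while stack_top >= 0:
        u = stack[stack_top]
        stack_top -= 1
        v = left_child[u]
        while v != NULL:
            tbl += time[u] - time[v]
            stack_top += 1
            stack[stack_top] = v
            v = right_sib[v]
    return tbl


# Toy tree: samples 0, 1, 2; node 3 joins 0 and 1; root 4 joins 3 and 2.
left_child = np.array([NULL, NULL, NULL, 0, 2], dtype=np.int32)
right_sib = np.array([1, NULL, 3, NULL, NULL], dtype=np.int32)
time = np.array([0.0, 0.0, 0.0, 1.0, 2.0])
print(total_branch_length_array_stack(left_child, right_sib, time, 4))  # 5.0
```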
Since preorder via recursion was fast, let's try postorder (which is much easier to do via recursion). Propagating a sum up the tree is a fundamental operation.
```python
@numba.njit()
def postorder_sum_numba(u, x):
    # Finish each child's subtree first, then add the child's total into u.
    v = left_child[u]
    while v != tskit.NULL:
        postorder_sum_numba(v, x)
        x[u] += x[v]
        v = right_sib[v]

def count_nodes():
    a = np.zeros(ts.num_nodes)
    a[ts.samples()] = 1
    postorder_sum_numba(tree.root, a)
    return a
```
Here, we just count the number of samples below each node (each sample starts with a value of 1) and return this array, but it could be anything. Note in particular that we're using numpy indexing here, so this should also work for n-dimensional arrays.
Timings:

```
%%timeit
count_nodes()
107 ms ± 1.4 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
```
Holy jeebus, we did a recursive postorder traversal summing an array in less time than it took to do a top-down sum of a simple value! Again, this is just over twice the time it took for the optimised C code to sum the total branch length.
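The recursive version's call depth grows with the height of the tree, so very deep trees could in principle exhaust the native call stack (nothing above suggests this happened in the benchmark). A hedged non-recursive alternative relies on the fact that every node comes after its parent in a preorder, so walking a preorder backwards handles children before parents. This sketch assumes `parent`, `left_child` and `right_sib` are passed in as arrays; `postorder_sum` is a hypothetical name.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # fall back to plain Python if numba is unavailable
    def njit(func):
        return func

NULL = -1  # same value as tskit.NULL


@njit
def postorder_sum(parent, left_child, right_sib, root, x):
    """Add each node's value into its parent, bottom-up, within root's subtree."""
    n = parent.shape[0]
    order = np.empty(n, np.int32)
    stack = np.empty(n, np.int32)
    stack_top = 0
    stack[0] = root
    num_visited = 0
    # Build a preorder of the subtree with an explicit stack.
    while stack_top >= 0:
        u = stack[stack_top]
        stack_top -= 1
        order[num_visited] = u
        num_visited += 1
        v = left_child[u]
        while v != NULL:
            stack_top += 1
            stack[stack_top] = v
            v = right_sib[v]
    # Reverse preorder: every child is handled before its parent. Index 0 is
    # root itself, which must not be added into its own parent.
    for j in range(num_visited - 1, 0, -1):
        u = order[j]
        x[parent[u]] += x[u]


# Toy tree: samples 0, 1, 2; node 3 joins 0 and 1; root 4 joins 3 and 2.
parent = np.array([3, 3, 4, 4, NULL], dtype=np.int32)
left_child = np.array([NULL, NULL, NULL, 0, 2], dtype=np.int32)
right_sib = np.array([1, NULL, 3, NULL, NULL], dtype=np.int32)
x = np.zeros(5)
x[[0, 1, 2]] = 1  # one per sample
postorder_sum(parent, left_child, right_sib, 4, x)
print(x.tolist())  # [1.0, 1.0, 1.0, 2.0, 3.0]
```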
OK, let's try something a bit more complicated. We can compute the Sankoff parsimony score of an assignment of genotypes to a particular tree.
```python
@numba.njit()
def _sankoff_score_numba(parent, cost_matrix, S):
    # Note: here `parent` is a node ID, not the global parent array.
    # S[u, j] is the minimum cost of the subtree below u if u has allele j.
    num_alleles = cost_matrix.shape[0]
    child = left_child[parent]
    while child != tskit.NULL:
        _sankoff_score_numba(child, cost_matrix, S)
        for j in range(num_alleles):
            min_cost = np.inf
            for k in range(num_alleles):
                min_cost = min(min_cost, cost_matrix[k, j] + S[child, k])
            S[parent, j] += min_cost
        child = right_sib[child]

def sankoff_score_numba(genotypes, cost_matrix):
    num_alleles = cost_matrix.shape[0]
    S = np.zeros((tree.num_nodes, num_alleles))
    samples = tree.tree_sequence.samples()
    # Samples cost nothing for their observed allele and are impossible otherwise.
    S[samples, :] = np.inf
    for allele in range(num_alleles):
        samples_with_allele = samples[genotypes == allele]
        S[samples_with_allele, allele] = 0
    _sankoff_score_numba(tree.root, cost_matrix, S)
    return S

# Simple 2-allele cost matrix.
cost_matrix = np.array([[0, 0.5], [0.5, 0]])
genotypes = np.zeros(ts.num_samples, dtype=np.int8)
# Assign something to the genotypes so we're not summing 0s
genotypes[::2] = 1
```
Timings:

```
%%timeit
sankoff_score_numba(genotypes, cost_matrix)
197 ms ± 7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```
Woo, 0.2 seconds! (This is a tree with 1 million samples, remember.) It takes about 30 seconds to run the same calculation in Biopython, but there the algorithm is implemented in Python, so it's not a fair comparison.
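The function above returns the whole table `S`; the parsimony score itself is the minimum over the root's row. The sketch below does that on a toy tree, assuming arrays are passed explicitly and that a postorder array of node IDs covering the whole tree is available (recent tskit versions offer `tree.postorder()`; here one is written out by hand). It also swaps the recursion for a loop over that postorder. `sankoff_fill` and `sankoff_parsimony_score` are hypothetical names, and the unit cost matrix is just for easy hand-checking.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # fall back to plain Python if numba is unavailable
    def njit(func):
        return func

NULL = -1  # same value as tskit.NULL


@njit
def sankoff_fill(postorder, parent, cost_matrix, S):
    """Fill S bottom-up: S[u, j] is the minimum cost below u if u has allele j."""
    num_alleles = cost_matrix.shape[0]
    for u in postorder:
        p = parent[u]
        if p == NULL:
            continue  # the root has no parent to pass costs up to
        for j in range(num_alleles):
            min_cost = np.inf
            for k in range(num_alleles):
                min_cost = min(min_cost, cost_matrix[k, j] + S[u, k])
            S[p, j] += min_cost


def sankoff_parsimony_score(postorder, parent, samples, genotypes, cost_matrix):
    num_alleles = cost_matrix.shape[0]
    S = np.zeros((parent.shape[0], num_alleles))
    # Samples cost nothing for their observed allele and are impossible otherwise.
    S[samples, :] = np.inf
    S[samples, genotypes] = 0
    sankoff_fill(postorder, parent, cost_matrix, S)
    root = postorder[-1]  # the root is the last node in a postorder
    return S[root].min()


# Toy tree: samples 0, 1, 2; node 3 joins 0 and 1; root 4 joins 3 and 2.
parent = np.array([3, 3, 4, 4, NULL], dtype=np.int32)
postorder = np.array([0, 1, 3, 2, 4], dtype=np.int32)
samples = np.array([0, 1, 2])
genotypes = np.array([0, 1, 1])
cost_matrix = np.array([[0.0, 1.0], [1.0, 0.0]])
print(sankoff_parsimony_score(postorder, parent, samples, genotypes, cost_matrix))  # 1.0
```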
Next, Hartigan parsimony, which is what `tree.map_mutations` uses, so we can compare directly against the library. This time we set up the `right_child` and `left_sib` arrays:

```python
ts = ts1m
tree = ts.first()
parent = np.zeros(ts.num_nodes, dtype=np.int32)
right_child = np.zeros(ts.num_nodes, dtype=np.int32)
left_sib = np.zeros(ts.num_nodes, dtype=np.int32)
time = ts.tables.nodes.time
flags = ts.tables.nodes.flags
for u in range(ts.num_nodes):
    parent[u] = tree.parent(u)
    right_child[u] = tree.right_child(u)
    left_sib[u] = tree.left_sib(u)

@numba.njit()
def _hartigan_postorder(parent, optimal_set):
    # Here `parent` is a node ID. Each internal node gets the set of alleles
    # that occur most often among its children's optimal sets.
    num_alleles = optimal_set.shape[1]
    allele_count = np.zeros(num_alleles, dtype=np.int32)
    child = right_child[parent]
    while child != tskit.NULL:
        _hartigan_postorder(child, optimal_set)
        allele_count += optimal_set[child]
        child = left_sib[child]
    if (flags[parent] & tskit.NODE_IS_SAMPLE) == 0:  # Only check the sample bit.
        max_allele_count = np.max(allele_count)
        for j in range(num_alleles):
            if allele_count[j] == max_allele_count:
                optimal_set[parent, j] = 1

@numba.njit()
def _hartigan_preorder(node, state, optimal_set):
    # Keep the parent's state if it is optimal here; otherwise switch to an
    # optimal allele and record a mutation above this node.
    mutations = []
    if optimal_set[node, state] == 0:
        state = np.argmax(optimal_set[node])
        mutations.append((node, state))
    v = right_child[node]
    while v != tskit.NULL:
        v_muts = _hartigan_preorder(v, state, optimal_set)
        mutations.extend(v_muts)
        v = left_sib[v]
    return mutations

def hartigan_map_mutations_numba(tree, genotypes, alleles):
    # Simple version assuming non-missing data and one root
    num_alleles = np.max(genotypes) + 1
    num_nodes = tree.tree_sequence.num_nodes
    optimal_set = np.zeros((num_nodes + 1, num_alleles), dtype=np.int8)
    for allele, u in zip(genotypes, tree.tree_sequence.samples()):
        optimal_set[u, allele] = 1
    _hartigan_postorder(tree.root, optimal_set)
    ancestral_state = np.argmax(optimal_set[tree.root])
    ll_mutations = _hartigan_preorder(tree.root, ancestral_state, optimal_set)
    mutations = []
    for node, derived_state in ll_mutations:
        mutations.append(
            tskit.Mutation(
                node=node,
                derived_state=alleles[derived_state],
                # Note we're taking a short-cut here and not bothering with mutation parent.
                # Could be done easily enough.
            )
        )
    return alleles[ancestral_state], mutations

genotypes = np.zeros(ts.num_samples, dtype=np.int8)
# This is an easy one so we won't be allocating a lot of memory.
genotypes[1] = 1
```
Timings:

The tskit version takes about 1/3 of a second on a million-leaf tree:

```
%%timeit
tree.map_mutations(genotypes, ["0", "1"])
292 ms ± 6.95 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```

The numba version takes about a second, which is pretty great!

```
%%timeit
hartigan_map_mutations_numba(tree, genotypes, ["0", "1"])
1.06 s ± 26.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```
That was a very easy parsimony job, though: it only needed one mutation. Let's try something harder, so that we're doing memory allocations:
```python
genotypes = np.zeros(ts.num_samples, dtype=np.int8)
# Assign something to the genotypes so we're not summing 0s
genotypes[::2] = 1
genotypes
```

```
array([1, 0, 1, ..., 0, 1, 0], dtype=int8)
```

```
%%timeit
tree.map_mutations(genotypes, ["0", "1"])
721 ms ± 7.98 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```

```
%%timeit
hartigan_map_mutations_numba(tree, genotypes, ["0", "1"])
1.56 s ± 23.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```
So, still within a factor of 2-3 of the highly optimised tskit C code! Wow, numba kicks ass!
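To check the two Hartigan passes by hand, here is a sketch on a toy tree using the same simplified two-pass scheme as the code above. It assumes the arrays are passed explicitly, that a postorder array is available, and that sample status comes from a boolean `is_sample` array (i.e. testing the sample flag rather than `flags == 0`). The function names are hypothetical. Only the postorder pass is compiled; the preorder pass is plain Python with a list as its stack, so it can return tuples without worrying about numba's list typing.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # fall back to plain Python if numba is unavailable
    def njit(func):
        return func

NULL = -1  # same value as tskit.NULL


@njit
def hartigan_optimal_sets(postorder, left_child, right_sib, is_sample, optimal_set):
    """Postorder pass: give each internal node its children's most common alleles."""
    num_alleles = optimal_set.shape[1]
    count = np.zeros(num_alleles, np.int32)
    for u in postorder:
        if is_sample[u]:
            continue  # samples keep their observed allele
        count[:] = 0
        v = left_child[u]
        while v != NULL:
            for j in range(num_alleles):
                count[j] += optimal_set[v, j]
            v = right_sib[v]
        max_count = count.max()
        for j in range(num_alleles):
            if count[j] == max_count:
                optimal_set[u, j] = 1


def hartigan_mutations(root, left_child, right_sib, optimal_set):
    """Preorder pass: inherit the parent's state if optimal, else mutate."""
    ancestral_state = int(np.argmax(optimal_set[root]))
    mutations = []
    stack = [(root, ancestral_state)]
    while stack:
        u, state = stack.pop()
        if optimal_set[u, state] == 0:
            state = int(np.argmax(optimal_set[u]))
            mutations.append((u, state))
        v = left_child[u]
        while v != NULL:
            stack.append((int(v), state))
            v = right_sib[v]
    return ancestral_state, mutations


# Toy tree: samples 0, 1, 2; node 3 joins 0 and 1; root 4 joins 3 and 2.
left_child = np.array([NULL, NULL, NULL, 0, 2], dtype=np.int32)
right_sib = np.array([1, NULL, 3, NULL, NULL], dtype=np.int32)
postorder = np.array([0, 1, 3, 2, 4], dtype=np.int32)
is_sample = np.array([True, True, True, False, False])
genotypes = np.array([0, 1, 1])
optimal_set = np.zeros((5, 2), dtype=np.int8)
optimal_set[[0, 1, 2], genotypes] = 1
hartigan_optimal_sets(postorder, left_child, right_sib, is_sample, optimal_set)
print(hartigan_mutations(4, left_child, right_sib, optimal_set))  # (1, [(0, 0)])
```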
An update here: https://github.com/tskit-dev/tskit/pull/1320 adds support for the tree arrays, which works very well. I'll need to do a bit of experimentation to see what's the best way of passing around these array references (i.e., to make sure they are considered "const" by numba), but it's all solid.
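On the "const" question, one hedged option is to pass the tree arrays as arguments and mark them read-only on the NumPy side with `ndarray.setflags(write=False)`, so an accidental write from Python raises an error instead of silently corrupting the tree. Whether numba can exploit read-only arrays for optimisation isn't something the experiments above establish. `node_depths` is a hypothetical example function.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # fall back to plain Python if numba is unavailable
    def njit(func):
        return func

NULL = -1  # same value as tskit.NULL


@njit
def node_depths(parent):
    """Number of edges from each node up to the root of its tree."""
    n = parent.shape[0]
    depth = np.zeros(n, np.int32)
    for u in range(n):
        v = parent[u]
        while v != NULL:
            depth[u] += 1
            v = parent[v]
    return depth


# Toy tree: samples 0, 1, 2; node 3 joins 0 and 1; root 4 joins 3 and 2.
parent = np.array([3, 3, 4, 4, NULL], dtype=np.int32)
parent.setflags(write=False)  # reads still work; NumPy now raises ValueError on writes
print(node_depths(parent).tolist())  # [2, 2, 1, 1, 0]
```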
Once we have direct numpy access to the tree arrays in Python (https://github.com/tskit-dev/tskit/issues/1299) I think we should be able to do quite performant traversal algorithms in python using numba. We can illustrate this with a couple of examples:
`node_time` array if we create it for each function call - we could work around this for the moment by using a copy of the node time array that we keep lying around). This is easy because we only go up the tree.