Open GertjanBisschop opened 1 year ago
We can use bitwise operations to keep track of the different branch types. To do efficient updates using `ts.edge_diffs()`, we should keep track of the type of each branch by storing that value in an array such that `branchtype_array[edge.child] == branchtype`.
The following code initialises this array and the counts for a single tree.
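import tskit


def init_bt_array(tree, bt_array, bt_counts):
    # Postorder guarantees that children are processed before their parent.
    for node in tree.nodes(order="postorder"):
        if tree.is_sample(node):
            # Each sample gets its own bit (bit 0 is left unused).
            bt_array[node] = 1 << (node + 1)
        else:
            # An internal node is the bitwise OR of all of its children.
            child = tree.left_child_array[node]
            while child != tskit.NULL:
                bt_array[node] |= bt_array[child]
                child = tree.right_sib_array[child]
        bt_counts[bt_array[node]] += 1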
>>> import numpy as np
>>> t = tskit.Tree.generate_balanced(4)
>>> num_nodes = t.tree_sequence.num_nodes
>>> node_branchtype_array = np.zeros(num_nodes + 1, dtype=np.uint32)
>>> num_branchtypes = sum(2**i for i in range(t.num_samples(), 0, -1))
>>> branchtype_counts = np.zeros(num_branchtypes + 1, dtype=np.uint32)
>>> init_bt_array(t, node_branchtype_array, branchtype_counts)
>>> node_branchtype_array
array([ 2,  4,  8, 16,  6, 24, 30,  0], dtype=uint32)
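To read these values back, here is a small sketch of a decoder. `samples_in_branchtype` is a hypothetical helper (not part of tskit), and it assumes the `1 << (node + 1)` encoding used above, with samples numbered from 0.

```python
def samples_in_branchtype(branchtype):
    """Return the sample IDs whose bits are set in a branch type value."""
    value = int(branchtype) >> 1  # bit 0 is unused by this encoding
    samples = []
    sample = 0
    while value:
        if value & 1:
            samples.append(sample)
        value >>= 1
        sample += 1
    return samples


for bt in [2, 4, 8, 16, 6, 24, 30]:
    print(bt, samples_in_branchtype(bt))
```

For the balanced tree above, 6 decodes to samples 0 and 1, 24 to samples 2 and 3, and 30 (the root) to all four samples.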
pinging @dreq.
`ts.edge_diffs()` can be used to update these arrays efficiently when iterating across trees.