tskit-dev / tsinfer

Infer a tree sequence from genetic variation data.
GNU General Public License v3.0

Extra parallelization possibilities for match_ancestors? #147

Closed hyanwong closed 1 year ago

hyanwong commented 5 years ago

At the moment we are finding that the slow stage of inference is match_ancestors, since it is only parallelized within an epoch (for ancestors at exactly the same age). This is because we need to make sure that ancestors can copy off all older ancestors if necessary.

As mentioned to @awohns, I think we could, however, parallelize ancestor matching as long as the ancestors we are processing in parallel do not cover the same region of genome. In other words, if we keep a list, L, of the start and end values of the ancestors we are currently matching, I think we can process the next-ancestor-in-time as long as its start and end values don't intersect with the segments in L.

Is my thinking correct? If so, I don't know how much speedup we would get, but I suspect that a long chromosome (when ancestors are short) might be substantially parallelizable.
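The interval check proposed above can be sketched as follows. This is a minimal illustration, not tsinfer code: `overlaps` and `next_batch` are hypothetical helper names, intervals are assumed to be half-open `(start, end)` pairs, and `pending` is assumed to be sorted oldest first. It takes the conservative reading of the idea and stops at the first conflict, since any later ancestor might need to copy from the blocked one.

```python
def overlaps(a, b):
    """Half-open intervals (start, end) overlap if each starts before the other ends."""
    return a[0] < b[1] and b[0] < a[1]


def next_batch(pending, in_flight):
    """Return the leading run of `pending` intervals (oldest first) that can start now.

    An ancestor can start only if it intersects nothing currently being
    matched (`in_flight`) and nothing already chosen for this batch.
    """
    batch = []
    busy = list(in_flight)  # plays the role of the list L in the comment above
    for interval in pending:
        if any(overlaps(interval, other) for other in busy):
            break  # conservative: do not skip past a blocked ancestor
        batch.append(interval)
        busy.append(interval)
    return batch


if __name__ == "__main__":
    pending = [(0, 10), (20, 30), (5, 25), (40, 50)]
    print(next_batch(pending, in_flight=[]))
```

On a long chromosome with short ancestors, the leading run would tend to be longer, which is the intuition behind the expected speed-up.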

jeromekelleher commented 5 years ago

This should work, but will require a bit of replumbing. While we're at it, it might be better to grasp the nettle and build a full interval tree of the ancestors and figure out what could possibly overlap with what. It should be possible to build up a work-list based on that. I'm sure there'll be tricky technical details along the way though, so it's not an easy win.

hyanwong commented 3 years ago

@jeromekelleher and I just chatted about some possibilities for this issue, now that we can have smaller old ancestors.

The problem can be thought of as a dependency graph, where ancestors are dependent on older ancestors that (at least partially) overlap, and those ancestors are dependent on older ancestors, etc.

Maybe this is easier than we think. The dependency graph is, I think, trivial to extract:

ts = msprime.sim_mutations(
    msprime.sim_ancestry(
        1000, population_size=1e4,
        recombination_rate=1e-8,
        sequence_length=1e5, random_seed=1),
    rate=1e-8, random_seed=1)

sd = tsinfer.SampleData.from_tree_sequence(ts, use_sites_time=False)
ancestors = tsinfer.generate_ancestors(sd)

start = ancestors.ancestors_start[:]
end = ancestors.ancestors_end[:]
atime = ancestors.ancestors_time[:]

for a in ancestors.ancestors():
    older_dependencies = np.where(np.logical_and.reduce((start < a.end, end > a.start, atime > a.time)))[0]
    print(f"{a.id} depends on {older_dependencies}")

And it turns out that you can pass such a dependency graph to Dask quite easily. It might be that we need to be more sophisticated when updating the ancestors TS, if that's time consuming. But we can test the performance of Dask on this task pretty easily. Here I test it out assuming that each match takes 0.5 seconds, regardless of the number of haplotypes against which we are matching and the length, but of course we can change that to be more realistic:

import tsinfer
import numpy as np
import dask
import dask.multiprocessing  # make sure the submodule used below is loaded
from dask.diagnostics import ProgressBar
import time
import msprime

def run(ancestor_id, deps):
    time.sleep(0.5)
    # print(f"Done {ancestor_id}: {len(deps)} deps")
    return 1

if __name__ == '__main__':
    ts = msprime.sim_mutations(
        msprime.sim_ancestry(
            1000, population_size=1e4,
            recombination_rate=1e-8,
            sequence_length=1e5, random_seed=1),
        rate=1e-8, random_seed=1)

    sd = tsinfer.SampleData.from_tree_sequence(ts, use_sites_time=False)
    ancestors = tsinfer.generate_ancestors(sd)

    start = ancestors.ancestors_start[:]
    end = ancestors.ancestors_end[:]
    atime = ancestors.ancestors_time[:]

    dsk = {}
    for a in ancestors.ancestors():
        dependencies = np.where(np.logical_and.reduce((start < a.end, end > a.start, atime > a.time)))[0]
        dsk[a.id] = (run, f"{a.id}", list(dependencies))

    start = time.time()
    with ProgressBar():
        dask.multiprocessing.get(dsk, list(range(ancestors.num_ancestors)))
    print(
        f"Done with dask + span information in {time.time() - start} seconds, "
        f"using {dask.multiprocessing.CPU_COUNT} cores")

    dsk = {}
    for a in ancestors.ancestors():
        dependencies = np.where(atime > a.time)[0]
        dsk[a.id] = (run, f"{a.id}", list(dependencies))

    start = time.time()
    with ProgressBar():
        dask.multiprocessing.get(dsk, list(range(ancestors.num_ancestors)))
    print(
        f"Done with dask without span information in {time.time() - start} seconds, "
        f"using {dask.multiprocessing.CPU_COUNT} cores")

    start = time.time()
    for i in list(range(ancestors.num_ancestors)):
        run(i, [])
    print(f"Done without dask parallelization in {time.time() - start} seconds")

On my laptop, I get (edit - I posted this with a bug previously)

Done with dask + span information in 84.68570303916931 seconds, using 8 cores
Done with dask without span information in 87.21654891967773 seconds, using 8 cores
Done without dask parallelization in 119.71293210983276 seconds

But I suspect that for a larger tree sequence, especially if we cut the older ancestors down, there will be more of a difference between the parallel versions with and without span information.

hyanwong commented 3 years ago

Some results from a similar script to the above. I ran it on the oldest 10,000 ancestors of Wilder's HGDP file, with a dummy function as above but one which slept for 0.01 seconds per ancestor (rather than 0.5 s)

There's good news and bad. The bad is that simply running without any parallelisation was fastest (presumably because 0.01 sec per loop is negligible compared to the cost of setting up the Dask parallelisation).

Done without dask parallelization in 101.39712810516357 seconds
Done with dask + length information in 372.82159185409546 seconds, using 40 cores

But for the Dask run without span information, I'm at 15 minutes and still at 0%. This implies that we could get very decent speed-ups with the right approach.

I don't know if relying on the Dask task scheduler to solve the dependency graph is the right way forward, but it's certainly the easiest. It will probably balk at the huge graph we have for all the ancestors, but there's no reason why we can't take (say) the oldest 50,000 ancestors, do all them, then take the next 50,000, and so on.
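The chunking idea could look something like the sketch below. It assumes `start`, `end` and `atime` are the usual numpy arrays; `time_ordered_chunks` and `chunk_dependencies` are hypothetical names, not part of tsinfer or Dask. Because each chunk is only begun after all older chunks have finished, the dependency graph for a chunk only needs edges between members of that chunk.

```python
import numpy as np


def time_ordered_chunks(atime, chunk_size):
    """Yield arrays of ancestor indices, oldest first, at most chunk_size long."""
    order = np.argsort(-np.asarray(atime, dtype=float), kind="stable")
    for lo in range(0, len(order), chunk_size):
        yield order[lo:lo + chunk_size]


def chunk_dependencies(chunk, start, end, atime):
    """Map each ancestor in `chunk` to the older, overlapping ancestors in the same chunk."""
    start, end, atime = (np.asarray(x) for x in (start, end, atime))
    c_start, c_end, c_time = start[chunk], end[chunk], atime[chunk]
    deps = {}
    for j, a in enumerate(chunk):
        mask = (c_start < c_end[j]) & (c_end > c_start[j]) & (c_time > c_time[j])
        deps[int(a)] = [int(x) for x in chunk[mask]]
    return deps


if __name__ == "__main__":
    start = np.array([0, 0, 5, 20, 0])
    end = np.array([10, 10, 15, 30, 30])
    atime = np.array([3, 1, 2, 3, 0])
    for chunk in time_ordered_chunks(atime, chunk_size=2):
        print(chunk_dependencies(chunk, start, end, atime))
```

Each per-chunk dictionary is small enough to hand to a scheduler on its own, at the cost of a synchronisation point at every chunk boundary.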

jeromekelleher commented 3 years ago

Very interesting, thanks @hyanwong! I think in practice we would chunk the work up like you say all right. Keeping the tree sequence that we match against up to date will be the most difficult part of this (particularly if we are distributing across multiple nodes, so that N copies of the tree sequence need to be maintained and the results of each match must be broadcast to each node), so I doubt that using Dask directly like this will work. But great to get some insights on what the limits of parallelisation are.

Dask has some nice diagnostics; maybe we could get some more insights out of them?

hyanwong commented 3 years ago

Here's some code to plot the number of potentially parallelizable matches per "epoch" (i.e. parallelizable slice) using each strategy:

import numpy as np
import tsinfer
import msprime
import tqdm
import networkx as nx
import matplotlib.pyplot as plt

ts = msprime.sim_mutations(
    msprime.sim_ancestry(
        1000, population_size=1e4,
        recombination_rate=1e-7,
        sequence_length=1e5, random_seed=1),
    rate=1e-7, random_seed=1)

sd = tsinfer.SampleData.from_tree_sequence(ts, use_sites_time=False)
ancestors = tsinfer.generate_ancestors(sd)

print(f"Num ancestors = {ancestors.num_ancestors}; num_sites = {ancestors.num_sites}")

start = ancestors.ancestors_start[:]
end = ancestors.ancestors_end[:]
atime = ancestors.ancestors_time[:]

G = nx.DiGraph()
for a in tqdm.tqdm(ancestors.ancestors(), total=ancestors.num_ancestors):
    #dependencies = np.where(np.logical_and.reduce((start < a.end, end > a.start, atime > a.time)))[0]
    dependencies = np.where(atime > a.time)[0]
    G.add_node(a.id)
    for d in dependencies:
        G.add_edge(d, a.id)

num_parallel = []
roots = set([0])
level = 0
with tqdm.tqdm(total=G.number_of_edges()) as pbar:
    while len(roots):
        level += 1
        new_roots = set()
        for n in roots:
            for child in list(G.successors(n)):
                pbar.update()
                G.remove_edge(n, child)
                if G.in_degree(child) == 0:
                    new_roots.add(child)

        roots = new_roots
        num_parallel.append(len(roots))

with open("nopos.txt", "wt") as file:
    for v in num_parallel:
        print(v, file=file)
print(f"Non-position-aware parallelization: {len(num_parallel)} epochs")

plt.plot(num_parallel, label="ignore position")

G = nx.DiGraph()
for a in tqdm.tqdm(ancestors.ancestors(), total=ancestors.num_ancestors):
    dependencies = np.where(np.logical_and.reduce((start < a.end, end > a.start, atime > a.time)))[0]
    #dependencies = np.where(atime > a.time)[0]
    G.add_node(a.id)
    for d in dependencies:
        G.add_edge(d, a.id)
num_parallel = []
roots = set([0])
level = 0

with tqdm.tqdm(total=G.number_of_edges()) as pbar:
    while len(roots):
        level += 1
        new_roots = set()
        for n in roots:
            for child in list(G.successors(n)):
                pbar.update(1)
                G.remove_edge(n, child)
                if G.in_degree(child) == 0:
                    new_roots.add(child)
        roots = new_roots
        num_parallel.append(len(roots))
with open("pos.txt", "wt") as file:
    for v in num_parallel:
        print(v, file=file)

print(f"Position-aware parallelization: {len(num_parallel)} epochs")

plt.plot(num_parallel, label="position-aware")
plt.legend()
plt.xlabel("Parallelizable slice (index)")
plt.ylabel("Number of parallelizable LS match processes")

Giving:

Num ancestors = 2057; num_sites = 2733
Non-position-aware parallelization: 754 epochs
Position-aware parallelization: 414 epochs

image

Worth trying out with the real ancestors data file I think.

hyanwong commented 3 years ago

Attached is the code for the version to run on the first 20,000 ancestors of hgdp_1kg_sgdp_chr20_p.missing_binned.truncated.ancestors, which is Wilder's file for the ancestors, which have already been binned into timeslices. Here are the freqs, # haplotypes, & mean % seq length of the first 10 epochs, showing that the average haplotype length is about 10% of the total seq length, and there are lots of haplotypes being binned into the high freq time slices.

freq, # haplotypes, mean % seq length
0.998668, 177, 10.36
0.998801, 238, 10.36
0.998934, 544, 10.35
0.999068, 131, 10.36
0.999201, 296, 10.36
0.999334, 156, 10.35
0.999467, 1654, 10.36
0.999600, 751, 10.36
1.999600, 1, 99.19
2.999600, 1, 99.19

Here's the text results:

100%|██████████| 20000/20000 [06:59<00:00, 47.62it/s]
100%|██████████| 197849476/197849476 [14:12<00:00, 231993.51it/s]
Non-position-aware parallelization: 2250 epochs
100%|██████████| 20000/20000 [01:13<00:00, 270.38it/s]
100%|██████████| 32687087/32687087 [02:11<00:00, 248334.35it/s]
Position-aware parallelization: 1245 epochs

And the plot (with a log y axis) image

As you can see, this case differs a lot from the simulation. We still have about half the max path length through the dependency graph, implying that with infinite CPUs we might be able to halve the inference time. But we don't gain extra parallelisation at the start, in the oldest time slices (whereas we do in the simulated case). I'm pretty sure that's because we have already binned the oldest haplotypes together into timeslices, but I'll go back to the originals and see.

The gain in this example comes from being able to parallelize the intermediate freq matches, which are often of the order of 10-20 haplotypes in a time slice.

Together with the previous plot, this says to me that if we implement position-aware parallelisation we (a) are likely to be able to get rid of the time-binning code and (b) could possibly speed up the ancestor matching by a factor of (?) 2.
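As a rough check on that factor, the epoch counts reported so far can be turned into an idealised speed-up. This assumes unlimited CPUs and that every epoch takes about the same wall-clock time, neither of which holds exactly in practice; `ideal_speedup` is just an illustrative helper name.

```python
def ideal_speedup(time_only_epochs, position_aware_epochs):
    """Ratio of sequential steps needed without and with position information."""
    return time_only_epochs / position_aware_epochs


# Epoch counts taken from the outputs posted above
reported = {
    "simulation (2057 ancestors)": (754, 414),
    "real data (first 20,000 ancestors)": (2250, 1245),
}

if __name__ == "__main__":
    for label, (time_only, position_aware) in reported.items():
        print(f"{label}: {ideal_speedup(time_only, position_aware):.2f}x")
```

Both cases come out at roughly 1.8, in line with the "about half the max path length" observation.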

test_task.py.txt

jeromekelleher commented 3 years ago

This is excellent, thanks @hyanwong. Can we see this on a non-log scale as well please? The absolute values are important here too: if the maximum is rarely > 100, then there's no point in us trying to scale out across multiple nodes (which would be a relief, in a way).

hyanwong commented 3 years ago

Here's non-logged.

I'm also trying this out with ancestors 20000-40000, assuming all the previous ones are available. I suspect this will give much better parallelisation results, as it seems that the number of haplotypes in a timeslice is between 5 and 10. I'll post here when I have something.

image

jeromekelleher commented 3 years ago

Can you cap the ylim to ~250 please?

hyanwong commented 3 years ago

Still stymied by the resolution, I think. The results for haplotypes 20,000-40,000 should be more interesting, I hope.

image

hyanwong commented 3 years ago

The results are more impressive for ancestors 20,000-40,000 from the same datafile. We get from 40 to 80 parallel matches possible for much of the space, versus an average of 5-10 matches when only using time epochs:

100%|██████████| 20001/20001 [07:03<00:00, 47.22it/s]
100%|██████████| 199936078/199936078 [13:23<00:00, 248941.78it/s]
Non-position-aware parallelization: 2909 epochs
100%|██████████| 20001/20001 [00:07<00:00, 2778.57it/s]
100%|██████████| 2610053/2610053 [00:08<00:00, 299116.47it/s]
Position-aware parallelization: 608 epochs

Screenshot 2021-03-15 at 20 17 16

hyanwong commented 3 years ago

And here's the result from the first 20,000 ancestors of HGP+TGP+HGDP chr20 without binning timeslices (horizontal line drawn at 100 parallel threads)

image

If you zoom in to the first 500 you can see how the parallelisation is working - the peaks (which I assume are the integer frequencies) persist longer, and are more squashed together, separated by slightly higher troughs of "substantially non-parallel" matching.

image

It's not quite as impressive as I had hoped, so perhaps there is still a role for binning times? We should get more impressive speed ups for longer chromosomes, though.

hyanwong commented 3 years ago

There's an interesting pattern as we get to intermediate frequencies. If I try this for 100,000 ancestors (about 10% of the total number of ancestors in Wilder's chr20 file) I get a peak, using about 100 processors at a freq of about 0.43. The fact that the blue extends over the x axis for 7x longer than the orange implies to me that for the first 10% of the ancestors, taking account of position should give a ~ seven times speed-up, if we don't time-bin. This is (from my calculations) a comparable speed-up to that gained by time-binning. I suspect it would be more for larger chromosomes.

image

Note that if we continue the blue line by simply plotting the numbers in all the time slices, it goes up right at the end:

image

hyanwong commented 3 years ago

Here's the results for the first 1/2 million of the 1 million ancestors on chr20:

image

And on a non-log scale with points not lines (weird pattern on the right there) image

hyanwong commented 3 years ago

I've been thinking about the best algorithm here @jeromekelleher. Perhaps you have an idea already in place, but it seems to me that the first thing to establish is what is involved in writing the results into the growing tree sequence, and whether this needs to be locked. In particular, if there is a thread in the middle of performing a LS matching algorithm, and all the other threads have finished, is there a way of updating the TS without interfering with the matching process on the remaining thread? I suspect not, and that we would have to wait for that thread to finish before flushing the results to the TS. Is that right? Perhaps I could chat with you about this some time, or maybe Ben would be a good person to ask too?

jeromekelleher commented 3 years ago

The low-level details get quite tricky @hyanwong, but I think we have enough information here to make some decisions:

  1. Looks like it's worth doing this position-aware matching dependency graph, as we do get better parallelisation (assuming we can implement it well)
  2. There's no point in going distributed, given we can rarely use more than 100 threads

One more piece of information would be useful actually. How parallelisable is match_ancestors when we have a large number of samples? Here we're looking at the small samples, long chromosome case. We'd like to know what happens when we have 100K samples. Could we get some insight into this from simulations? We don't want to put a lot of effort in here and later realise we're hitting a wall with the target dataset.

hyanwong commented 3 years ago

Yes, I agree with (1) and (2). Re large numbers of samples, I'll try with a simulation now, but given the difference between simulations and real data shown above, I wonder if we should also see what it looks like for the UKB ancestors, which I presume we still have lying around somewhere?

One issue is that the dependency graph itself takes up a huge amount of space, if done for many ancestors beforehand. I wonder if there's a way to calculate the new ancestors which "become calculable" on the fly, as we complete the matching of older ancestors. I feel there should be a simple method to do this efficiently.

The simplest strategy for parallelisation would be to do what is shown in the plots above: calculate which group of ancestors are currently calculable, parallelise all those, wait until they are done, flush all the results, then repeat. But there will be more sophisticated approaches which involve flushing and recalculating as the results trickle in from the parallelised group. That's going to do my head in, I suspect, although you might find it easier to think about.
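The simple "match a group, wait, flush, repeat" strategy might be sketched like this. It is only an outline under stated assumptions: `match_one` and `flush` are hypothetical stand-ins for a single LS match and for writing a batch of results into the growing tree sequence, and the groups of ancestor ids are assumed to have been computed already, oldest group first.

```python
from concurrent.futures import ThreadPoolExecutor


def match_in_groups(groups, match_one, flush, max_workers=4):
    """Match each group in parallel, with a barrier and a flush between groups.

    groups:    list of lists of ancestor ids, oldest group first
    match_one: callable(ancestor_id) -> result (hypothetical stand-in)
    flush:     callable(list_of_results)      (hypothetical stand-in)
    """
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        for group in groups:
            # list(pool.map(...)) blocks until every match in the group is done,
            # so nothing is written while any match is still running.
            results = list(pool.map(match_one, group))
            flush(results)


if __name__ == "__main__":
    flushed = []
    match_in_groups([[0], [1, 2], [3]], lambda i: i * 10, flushed.append)
    print(flushed)
```

The barrier means the slowest match in each group sets the pace, which is the cost that the more sophisticated "flush as results trickle in" approaches would try to avoid.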

jeromekelleher commented 3 years ago

> I wonder if we should also see what it looks like for the UKB ancestors, which I presume we still have lying around somewhere?

It'll be in the ukb directory in rescomp, wherever we did all the stuff for the tsinfer paper.

> The simplest strategy for parallelisation would be to do what is shown in the plots above: calculate which group of ancestors are currently calculable, parallelise all those, wait until they are done, flush all the results, then repeat. But there will be more sophisticated approaches which involve flushing and recalculating as the results trickle in from the parallelised group. That's going to do my head in, I suspect, although you might find it easier to think about.

Yes, let's not worry about that for now.

hyanwong commented 3 years ago

The UKBB parallelisation works much better than the small example (probably because most timeslices are unique). The maximum number of parallel processes is only about 80. Here's the plot from the truncated ancestors (using Wilder's truncate_ancestors(0.4 0.6) method); there are only 15,000 ancestors in total in this example.

Screenshot 2021-03-17 at 15 01 07

hyanwong commented 3 years ago

I think there's a trivial way to get this parallelisation working. All we need to do is to calculate the "dependency level" for each ancestor. The root is dependency level 0. All the immediate dependents of the root are level 1. All the dependents of level 1 nodes are level 2 (and if there is a level 1 node which is a dependent of another level 1 node, that also switches to become a level 2 node). And so on iteratively down the tree.

Then all we do is treat the "levels" as the "epochs" over which to parallelize (rather than using the time). Simples!

Or am I missing something @jeromekelleher ?
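To make the "dependency level" idea concrete, here is a toy example with five made-up ancestors. `dependency_levels` is a hypothetical helper, and the intervals and times are invented for illustration; an ancestor's level is one more than the highest level among the older ancestors it overlaps.

```python
import numpy as np


def dependency_levels(start, end, time):
    """Level 0 has no older overlapping ancestor; otherwise 1 + max level of its dependencies."""
    start = np.asarray(start)
    end = np.asarray(end)
    time = np.asarray(time, dtype=float)
    levels = np.zeros(len(time), dtype=int)
    # Visit oldest first so every dependency already has its final level
    for i in np.argsort(-time, kind="stable"):
        deps = (start < end[i]) & (end > start[i]) & (time > time[i])
        if deps.any():
            levels[i] = levels[deps].max() + 1
    return levels


if __name__ == "__main__":
    # id:     0    1    2    3    4
    start = [0, 0, 60, 0, 70]
    end = [100, 40, 100, 30, 100]
    time = [4, 3, 2, 1, 1]
    levels = dependency_levels(start, end, time)
    print("levels:", levels.tolist())
    print("time epochs:", len(set(time)), "vs levels:", len(set(levels.tolist())))
```

Here ancestors 1 and 2 do not overlap, so they share a level even though their times differ, which reduces four time epochs to three levels.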

jeromekelleher commented 3 years ago

That could work - I'll need to look closely at the details though.

Re UKB, I wouldn't read too much into those ancestors, we know they're pathologically short because of the array data.

hyanwong commented 3 years ago

Here's the result from the first 20,000 ancestors of a 100Mb tree sequence with 100,000 samples, human-like params, which peaks at 200 parallelizable matching instances. Seems like this is worth doing.

Screenshot 2021-03-18 at 09 21 58

jeromekelleher commented 3 years ago

Nice, excellent news!

hyanwong commented 3 years ago

Here's the (trivial) way to get "levels" of dependencies without having to create an enormous graph. I think the logic is right, and I've checked that it gives the same result as the graph-based method for a random test case.

I think one difficulty will be how to implement path compression.

import numpy as np
import tsinfer

anc = tsinfer.load("myfile.ancestors")
anc_start = anc.ancestors_start[:]
anc_end = anc.ancestors_end[:]
anc_time = anc.ancestors_time[:]

levels = np.zeros(anc.num_ancestors, dtype=int)
for i, (lft, rgt, t) in enumerate(zip(anc_start, anc_end, anc_time)):
    dependencies = np.where(
        np.logical_and.reduce((anc_start < rgt, anc_end > lft, anc_time > t)))[0]
    for dependant_ancestor in dependencies:
        levels[i] = max(levels[dependant_ancestor] + 1, levels[i])

_, level_count = np.unique(levels, return_counts=True)

benjeffery commented 1 year ago

I'm looking into a way to quickly generate the ancestor groupings as the code in #486 takes a long time on our large datasets. It seems a line-sweep is the best way, carefully vectorising where possible:

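from collections import defaultdict, namedtuple
from operator import attrgetter

import numpy as np
from tqdm import tqdm

# Assume start, end and time are the usual numpy arrays.
# First build a list of events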
events = []
Event = namedtuple('Event', ['time', 'pos', 'index', 'type'])
# starts are "1", ends are "0"
for i in range(len(time)):
    events.append(Event(time[i], start[i], i, 1))
    events.append(Event(time[i], end[i], i, 0))
# Sort events by position, then ends before starts
events.sort(key=attrgetter('pos', 'type'))

# Sweep line, keeping a count of incoming edges and a list of outgoing edges
active = np.zeros(len(time), dtype=np.int32)
incoming_edge_count = np.zeros(len(time), dtype=np.int32)
children = defaultdict(list)
# Record pairs of nodes that overlap at the same time, as we will need these
# to be in the same grouping.
overlapping_same_time_pairs = []

for event in tqdm(events):
    event_index = event.index
    if event.type == 1: #Ancestor starts
        active_times = time[active == 1]
        active_indices = np.where(active == 1)[0]

        indices_less_than_event_time = active_indices[active_times < event.time]
        incoming_edge_count[indices_less_than_event_time] += 1
        for index in indices_less_than_event_time:
            children[event_index].append(index)

        indices_greater_than_event_time = active_indices[active_times > event.time]
        incoming_edge_count[event_index] += len(indices_greater_than_event_time)
        for index in indices_greater_than_event_time:
            children[index].append(event_index)

        indices_equal_event_time = active_indices[active_times == event.time]
        for index in indices_equal_event_time:
            overlapping_same_time_pairs.append((index, event_index))
        active[event_index] = 1
    else: #Ancestor ends
        active[event_index] = 0

# Now find the groups
group_id = np.full(len(time), -1, dtype=np.int32)
current_group = 0
while True:
    #Find the nodes with no incoming edges
    no_incoming = np.where(incoming_edge_count == 0)[0]
    if len(no_incoming) == 0:
        break
    #Remove them from the graph
    for i in no_incoming:
        incoming_edge_count[i] = -1
        incoming_edge_count[children[i]] -= 1
    #Add them to the group
    group_id[no_incoming] = current_group
    current_group += 1

# Now we need to set the group_id of the overlapping pairs to the maximum of any connected
# ancestor using a DFS to find the connected ancestors

def dfs(node, graph, group, visited):
    visited.add(node)
    for neighbour in graph[node]:
        if neighbour not in visited:
            group.append(neighbour)
            dfs(neighbour, graph, group, visited)

graph = defaultdict(list)
for pair in overlapping_same_time_pairs:
    graph[pair[0]].append(pair[1])
    graph[pair[1]].append(pair[0])

visited = set()
same_time_overlapping_groups = []
for node in graph:
    if node not in visited:
        group = [node]
        dfs(node, graph, group, visited)
        same_time_overlapping_groups.append(group)

for group in same_time_overlapping_groups:
    group_id[group] = np.max(group_id[group])

On simulated data this runs in 20s for 100,000 ancestors with 100,000 sites.

Next I will get a PR drawn up so we can run on the GeL data.