popsim-consortium / demes-python

Tools for describing and manipulating demographic models.
https://popsim-consortium.github.io/demes-docs/
ISC License
18 stars 6 forks source link

graph utility methods #167

Open grahamgower opened 3 years ago

grahamgower commented 3 years ago

We should consider including methods for graph-theoretic properties that are useful downstream. E.g.

Are there other things that would be useful?

jeromekelleher commented 3 years ago

Traversal functions?

grahamgower commented 3 years ago

Here's a pruning operation (just barely tested, so beware bugs).

def pruned(graph, keep):
    """
    Return a copy of ``graph``, pruned to contain only ancestry for the
    specific deme/time combinations specified in the dictionary ``keep``.

    Starting from each deme/time in ``keep``, the graph is walked backwards
    through incoming pulses, incoming migrations and ancestor demes. For each
    deme reached, the time down to which it must be retained is recorded, and
    that time becomes the deme's ``end_time`` in the new graph. Times are
    compared on the assumption that larger values are further in the past
    (a deme's ``start_time`` is greater than its ``end_time``).

    The body uses the module-level names ``collections``, ``copy`` and
    ``demes``, which must be imported by the enclosing module.

    :param graph: The graph to be pruned.
    :type graph: demes.Graph
    :param keep: Dictionary mapping deme IDs to times.
    :type keep: typing.Dict[str, typing.Union[int, float]]
    :return: the new pruned graph
    :rtype: demes.Graph
    :raises ValueError: if a time in ``keep`` does not satisfy
        ``deme.start_time > time >= deme.end_time`` for its deme.
    """
    # Validate that every requested time lies within the lifetime of its deme.
    for id, t in keep.items():
        deme = graph[id]
        if not (deme.start_time > t >= deme.end_time):
            raise ValueError(f"time={t} out of bounds for deme {id}")

    # Index pulses and migrations by destination deme, so that the events
    # feeding ancestry *into* a given deme can be looked up directly.
    pulses_into = collections.defaultdict(list)
    migrations_into = collections.defaultdict(list)
    for pulse in graph.pulses:
        pulses_into[pulse.dest].append(pulse)
    for migration in graph.migrations:
        migrations_into[migration.dest].append(migration)

    pulses_keep = []
    migrations_keep = []

    # demes_keep maps deme ID -> the time down to which the deme is retained.
    # It is seeded from ``keep`` and grows as source/ancestor demes are found.
    demes_keep = copy.deepcopy(keep)
    # Work list of deme IDs still to be examined; pop() takes from the end,
    # so demes are processed in last-in-first-out order.
    queue = list(demes_keep.keys())
    while len(queue) > 0:
        id = queue.pop()
        deme = graph[id]
        time = demes_keep[id]

        for pulse in pulses_into.get(id, []):
            # Pulses at or more recent than the retained time are dropped.
            if pulse.time <= time:
                continue
            # Skip if the source is already retained down to ``time`` or
            # more recently.
            # NOTE(review): this also drops the pulse itself, and compares
            # against ``time`` rather than ``pulse.time`` -- confirm that
            # is the intended condition.
            if demes_keep.get(pulse.source, float("inf")) <= time:
                continue
            # NOTE(review): this overwrites any existing entry for the source,
            # which looks like it could replace a more recent retained time
            # with an older one -- confirm.
            demes_keep[pulse.source] = pulse.time
            pulses_keep.append(copy.deepcopy(pulse))
            queue.append(pulse.source)

        for migration in migrations_into.get(id, []):
            # Migrations that started at or after the retained time are dropped.
            if migration.start_time <= time:
                continue
            # Same "source already retained" test as for pulses above.
            if demes_keep.get(migration.source, float("inf")) <= time:
                continue
            # Clip the migration so it does not extend more recently than the
            # time down to which the destination deme is retained.
            end_time = max(migration.end_time, time)
            migrations_keep.append(copy.deepcopy(migration))
            migrations_keep[-1].end_time = end_time
            # NOTE(review): unconditional overwrite, as for pulses -- confirm.
            demes_keep[migration.source] = end_time
            queue.append(migration.source)

        for anc in deme.ancestors:
            # An ancestor is needed down to the time this deme starts, unless
            # it is already retained to that time or more recently.
            anc_time = demes_keep.get(anc, float("inf"))
            if anc_time <= deme.start_time:
                continue
            demes_keep[anc] = deme.start_time
            queue.append(anc)

    # Build the pruned graph with the same top-level metadata as the input.
    g = demes.Graph(
        description=graph.description,
        time_units=graph.time_units,
        generation_time=graph.generation_time
    )
    # Iterate in the original deme order; demes never reached are omitted.
    for deme in graph.demes:
        end_time = demes_keep.get(deme.id)
        if end_time is None:
            continue
        # Keep only epochs that begin before (older than) the retained time,
        # then truncate the most recent of them at that time.
        # NOTE(review): if no epoch qualifies, ``epochs[-1]`` raises
        # IndexError -- confirm this cannot happen for reachable demes.
        epochs = [copy.deepcopy(e) for e in deme.epochs if e.start_time > end_time]
        epochs[-1].end_time = end_time
        g.deme(
            id=deme.id,
            description=deme.description,
            ancestors=deme.ancestors,
            proportions=deme.proportions,
            epochs=epochs,
        )
    # The kept migrations/pulses are assigned directly as attributes of the
    # new graph rather than added through a graph method.
    g.migrations = migrations_keep
    g.pulses = pulses_keep

    return g
grahamgower commented 3 years ago

Needs testing.

def connected_subgraphs(graph: demes.Graph):
    """
    Return a list of connected subgraphs of the given graph.

    Two demes belong to the same subgraph if they are linked, directly or
    transitively, by an ancestor/descendant relationship, a pulse, or a
    migration. Each subgraph is built by filtering the dictionary form of
    ``graph`` down to one group of demes and resolving it again.

    Subgraphs are returned in the order in which their first deme appears
    in ``graph.demes``.

    :param graph: The graph.
    :return: A list of subgraphs.
    :rtype: list[demes.Graph]
    :raises AssertionError: if the internal consistency checks fail, i.e.
        the subgraphs do not partition the demes of ``graph``, or a
        single resulting subgraph is not close to ``graph``.
    """
    # Union-find over deme names. Every deme starts in its own group.
    parent = {deme.name: deme.name for deme in graph.demes}

    def find(name):
        # Locate the representative of the group containing ``name``.
        root = name
        while parent[root] != root:
            root = parent[root]
        # Path compression: point every node on the walk at the root.
        while parent[name] != root:
            parent[name], name = root, parent[name]
        return root

    def union(a, b):
        root_a, root_b = find(a), find(b)
        if root_a != root_b:
            parent[root_b] = root_a

    # Join demes that are directly connected.
    for deme in graph.demes:
        for ancestor in deme.ancestors:
            union(deme.name, ancestor)
    # NOTE(review): assumes each pulse exposes a single ``source``; confirm
    # against the demes version in use.
    for pulse in graph.pulses:
        union(pulse.source, pulse.dest)
    for migration in graph.migrations:
        union(migration.source, migration.dest)

    # Collect the groups. Dicts preserve insertion order, so groups come out
    # ordered by the first deme of each group in ``graph.demes``.
    groups = {}
    for deme in graph.demes:
        groups.setdefault(find(deme.name), set()).add(deme.name)
    connected = list(groups.values())

    data = graph.asdict()

    subgraphs = []
    for group in connected:
        b = demes.Builder.fromdict(data)
        b.data["demes"] = [deme for deme in b.data["demes"] if deme["name"] in group]
        b.data["migrations"] = [
            migration
            for migration in b.data.get("migrations", [])
            # We need only check the source deme, because source and dest
            # must both be in the same group.
            if migration["source"] in group
        ]
        b.data["pulses"] = [
            pulse for pulse in b.data.get("pulses", []) if pulse["source"] in group
        ]
        subgraphs.append(b.resolve())

    # Internal sanity checks: the subgraphs must partition the demes of the
    # input graph (every deme present exactly once across all subgraphs).
    names = [deme.name for g in subgraphs for deme in g.demes]
    assert len(names) == len(graph.demes)
    assert len(set(names)) == len(names)

    # A fully connected graph must round-trip to an equivalent graph.
    if len(subgraphs) == 1:
        graph.assert_close(subgraphs[0])

    return subgraphs