[Open] grahamgower opened this issue 3 years ago.
Traversal functions?
Here's a pruning operation (only lightly tested, so beware of bugs).
def pruned(graph, keep):
    """
    Return a copy of ``graph``, pruned to contain only ancestry for the
    specific deme/time combinations specified in the dictionary ``keep``.

    Each deme in ``keep`` is retained from its start time down to the given
    time. Every deme that contributes ancestry to a retained interval -- as
    an ancestor, a pulse source, or a migration source -- is itself retained
    for as long as needed, and this requirement is propagated transitively.
    Pulses and migrations into retained intervals are kept, with migration
    end times clamped to the destination's retained interval.

    :param graph: The graph to be pruned.
    :type graph: demes.Graph
    :param keep: Dictionary mapping deme IDs to times.
    :type keep: typing.Dict[str, typing.Union[int, float]]
    :return: the new pruned graph
    :rtype: demes.Graph
    :raises ValueError: if a time in ``keep`` is outside its deme's lifetime.
    """
    for deme_id, t in keep.items():
        deme = graph[deme_id]
        if not (deme.start_time > t >= deme.end_time):
            raise ValueError(f"time={t} out of bounds for deme {deme_id}")

    pulses_into = collections.defaultdict(list)
    migrations_into = collections.defaultdict(list)
    for pulse in graph.pulses:
        pulses_into[pulse.dest].append(pulse)
    for migration in graph.migrations:
        migrations_into[migration.dest].append(migration)

    # demes_keep[x] is the time down to which deme x must be retained.
    # Values only ever decrease (a longer retained interval covers a shorter one).
    demes_keep = dict(keep)

    def require(deme_id, t):
        # Ensure ``deme_id`` is retained down to time ``t``.
        # Returns True if this extended its retained interval (or added it).
        if demes_keep.get(deme_id, float("inf")) <= t:
            return False
        demes_keep[deme_id] = t
        return True

    # Propagate retention requirements to a fixed point. A deme is re-queued
    # whenever its retention time is lowered, so it is always (re)processed
    # with its final time.
    queue = list(demes_keep)
    while queue:
        deme_id = queue.pop()
        deme = graph[deme_id]
        deme_time = demes_keep[deme_id]
        for pulse in pulses_into.get(deme_id, []):
            # The source must exist at the time of the pulse.
            if pulse.time > deme_time and require(pulse.source, pulse.time):
                queue.append(pulse.source)
        for migration in migrations_into.get(deme_id, []):
            if migration.start_time > deme_time:
                # The source must exist for the retained part of the migration.
                end_time = max(migration.end_time, deme_time)
                if require(migration.source, end_time):
                    queue.append(migration.source)
        for anc in deme.ancestors:
            if require(anc, deme.start_time):
                queue.append(anc)

    # Select pulses and migrations from the final retention times. Iterating
    # over the graph's own lists avoids duplicates and keeps the original
    # order (so simultaneous pulses keep their relative order).
    pulses_keep = [
        copy.deepcopy(pulse)
        for pulse in graph.pulses
        if pulse.dest in demes_keep and pulse.time > demes_keep[pulse.dest]
    ]
    migrations_keep = []
    for migration in graph.migrations:
        dest_time = demes_keep.get(migration.dest)
        if dest_time is None or migration.start_time <= dest_time:
            continue
        migration = copy.deepcopy(migration)
        migration.end_time = max(migration.end_time, dest_time)
        migrations_keep.append(migration)

    g = demes.Graph(
        description=graph.description,
        time_units=graph.time_units,
        generation_time=graph.generation_time,
    )
    # Iterate in the original deme order, which lists ancestors before
    # their descendants.
    for deme in graph.demes:
        end_time = demes_keep.get(deme.id)
        if end_time is None:
            continue
        epochs = [copy.deepcopy(e) for e in deme.epochs if e.start_time > end_time]
        # Truncate the last retained epoch at the retention time.
        epochs[-1].end_time = end_time
        g.deme(
            id=deme.id,
            description=deme.description,
            ancestors=deme.ancestors,
            proportions=deme.proportions,
            epochs=epochs,
        )
    g.migrations = migrations_keep
    g.pulses = pulses_keep
    return g
Needs testing.
def connected_subgraphs(graph: demes.Graph):
    """
    Return a list of connected subgraphs of the given graph.

    Two demes are directly connected if one is an ancestor of the other, or
    if there is a pulse or a migration between them. Each returned subgraph
    is one connected component: its demes plus the pulses and migrations
    among them. Subgraphs are ordered by the position of their first deme
    in ``graph.demes``.

    :param graph: The graph.
    :return: A list of subgraphs.
    :rtype: list[demes.Graph]
    """
    # Union-find over deme names; every deme starts in its own component.
    parent = {deme.name: deme.name for deme in graph.demes}

    def find(name):
        # Return the representative of ``name``'s component, compressing
        # the path so that later lookups are short.
        root = name
        while parent[root] != root:
            root = parent[root]
        while parent[name] != root:
            next_name = parent[name]
            parent[name] = root
            name = next_name
        return root

    def union(a, b):
        # Merge the components containing demes ``a`` and ``b``.
        root_a = find(a)
        root_b = find(b)
        if root_a != root_b:
            parent[root_b] = root_a

    for deme in graph.demes:
        for ancestor in deme.ancestors:
            union(deme.name, ancestor)
    for pulse in graph.pulses:
        union(pulse.source, pulse.dest)
    for migration in graph.migrations:
        union(migration.source, migration.dest)

    # Group deme names by component, in order of first appearance.
    groups = {}
    for deme in graph.demes:
        groups.setdefault(find(deme.name), set()).add(deme.name)

    subgraphs = []
    for group in groups.values():
        # Build from a fresh dict for every group. The filtering below
        # reassigns entries of ``b.data``, and Builder.fromdict() may use the
        # given dict directly rather than a copy (NOTE(review): confirm),
        # so a dict shared across iterations could be filtered repeatedly.
        b = demes.Builder.fromdict(graph.asdict())
        b.data["demes"] = [deme for deme in b.data["demes"] if deme["name"] in group]
        # We need only check the source deme, because source and dest
        # must both be in the same group.
        b.data["migrations"] = [
            migration
            for migration in b.data.get("migrations", [])
            if migration["source"] in group
        ]
        b.data["pulses"] = [
            pulse for pulse in b.data.get("pulses", []) if pulse["source"] in group
        ]
        subgraphs.append(b.resolve())

    # Sanity checks: the subgraphs partition the original graph's demes.
    assert sum(len(g.demes) for g in subgraphs) == len(graph.demes)
    seen = set()
    for g in subgraphs:
        names = {deme.name for deme in g.demes}
        assert seen.isdisjoint(names)
        seen |= names
    if len(subgraphs) == 1:
        graph.assert_close(subgraphs[0])
    return subgraphs
We should consider including methods for graph-theoretic properties that are useful downstream, e.g.:

- `roots()` to return the IDs of the graph's root demes;
- `components()` to return a list of graphs, one for each collection of connected demes (https://en.wikipedia.org/wiki/Component_(graph_theory)).

Are there other things that would be useful?