Open hyanwong opened 3 years ago
A good place to start is the example in https://github.com/tskit-dev/tsinfer/issues/429#issuecomment-891818721, where we could have placed a single mutation higher up the tree, and reduced the number of mutations required without actually changing the topology.
Here's another odd example (in this code we also specify a non-zero mismatch in the match_samples phase). We get the topology more-or-less right, but also place a mutation and then a reversion back to the original state on the lone branch to sample 3, because the intermediate nodes against which sample 3 is matched (83 and 230) have the derived state at site 382.
import msprime
import tsinfer
import tskit
import json
import collections
import numpy as np
from IPython.display import HTML, SVG
def plot_nonmatching(
    ts,
    rec_rate,
    which_bad_site,
    mismatch_ratio=1,
    path_compression=True,
    generate_ancestors=tsinfer.generate_ancestors,
):
    """
    Run tsinfer on ``ts`` and plot one site lacking a focal-ancestor mutation.

    The tree sequence is re-inferred (generate ancestors, match ancestors,
    match samples). For every inference site we look up the ancestor for
    which that site is focal and check whether the inferred tree sequence
    places a mutation at that site above the node built from that ancestor.
    Summary counts are printed, and the inferred and original trees at the
    ``which_bad_site``-th site failing that check are displayed side by side.

    :param ts: The (simulated) tree sequence to re-infer.
    :param rec_rate: Passed as ``recombination_rate`` to both matching phases.
    :param which_bad_site: 1-based index into the sites that have no mutation
        above their focal ancestor, selecting which one to plot.
    :param mismatch_ratio: Passed as ``mismatch_ratio`` to both matching phases.
    :param path_compression: Passed as ``path_compression`` to both matching
        phases.
    :param generate_ancestors: Callable used to build the ancestors; called as
        ``generate_ancestors(sample_data, engine=tsinfer.PY_ENGINE)``.
    :return: A tuple ``(ancestors, inferred_ts, node_ancestor_map)`` where
        ``node_ancestor_map`` maps inferred node ids to ancestor ids.
    :raises ValueError: If ``which_bad_site`` is not between 1 and the number
        of sites without a focal-ancestor mutation.
    """
    # `display` is normally injected by the notebook; import it explicitly so
    # the function does not depend on that global being present.
    from IPython.display import display

    print("Using simulated TS with", ts.num_trees, "trees and", ts.num_sites, "sites")
    sample_data = tsinfer.SampleData.from_tree_sequence(ts, use_sites_time=False)
    ancestors = generate_ancestors(sample_data, engine=tsinfer.PY_ENGINE)
    # ancestors = ancestors.truncate_ancestors(0.4, 0.6)
    ancestors_ts = tsinfer.match_ancestors(
        sample_data,
        ancestors,
        recombination_rate=rec_rate,
        mismatch_ratio=mismatch_ratio,
        path_compression=path_compression,
    )
    inferred_ts = tsinfer.match_samples(
        sample_data,
        ancestors_ts,
        recombination_rate=rec_rate,
        mismatch_ratio=mismatch_ratio,
        path_compression=path_compression,
    )

    # Which focal ancestors have a mutation above them?
    # Build the two-way mapping between inferred node ids and ancestor ids,
    # using the 'ancestor_data_id' stored in the node metadata.
    node_ancestor_map = {}
    ancestor_node_map = {}
    for n in inferred_ts.nodes():
        if n.metadata:
            ancestor_id = json.loads(n.metadata)['ancestor_data_id']
            node_ancestor_map[n.id] = ancestor_id
            assert ancestor_id not in ancestor_node_map
            ancestor_node_map[ancestor_id] = n.id

    # For each site, the set of ancestors whose node carries a mutation there.
    site_ancestors = {
        s.id: {node_ancestor_map[m.node] for m in s.mutations if m.node in node_ancestor_map}
        for s in inferred_ts.sites()
    }

    # Map each inference site (as a sample-data site index) to its focal ancestor.
    site_focal_ancestors = {}
    site_map = np.where(np.isin(sample_data.sites_position[:], ancestors.sites_position[:]))[0]
    for ancestor in ancestors.ancestors():
        for true_site_id in site_map[ancestor.focal_sites]:
            assert true_site_id not in site_focal_ancestors  # check only 1 focal anc per site
            site_focal_ancestors[true_site_id] = ancestor.id
    # The original wrote `assert len(...), num_sites`, which only checked that
    # the dict was non-empty; an equality comparison was clearly intended.
    assert len(site_focal_ancestors) == ancestors_ts.num_sites, (
        len(site_focal_ancestors),
        ancestors_ts.num_sites,
    )

    num_inference_sites_with_focal = 0
    sites_without_focal = {}
    for site_id, focal_ancestor in site_focal_ancestors.items():
        if focal_ancestor in site_ancestors[site_id]:
            num_inference_sites_with_focal += 1
        else:
            sites_without_focal[site_id] = focal_ancestor
    print(
        inferred_ts.num_mutations,
        "mutations,",
        inferred_ts.num_trees,
        "trees.",
        f"{num_inference_sites_with_focal / len(site_focal_ancestors) * 100:.2f}",
        "% of inference sites have a mutation above the focal ancestor"
    )
    print(
        len(sites_without_focal),
        "sites have no mutation above the focal ancestor (which could be absent in the tree)"
    )

    # Pick the requested (1-based) site, failing with a clear message rather
    # than a NameError / StopIteration when the index is out of range.
    if not sites_without_focal:
        raise ValueError("No sites lack a mutation above their focal ancestor")
    if not 1 <= which_bad_site <= len(sites_without_focal):
        raise ValueError(
            f"which_bad_site must be between 1 and {len(sites_without_focal)}, "
            f"got {which_bad_site}"
        )
    bad_site_id, bad_anc_id = list(sites_without_focal.items())[which_bad_site - 1]

    bad_pos = inferred_ts.site(bad_site_id).position
    print(f"Plotting nth example (n={which_bad_site}) of a site (pos {bad_pos}) without a focal ancestor mutation.")
    print(" Magenta nodes in the inferred TS have the derived allele (yellow are PC nodes for which info is unknown)")
    tables = inferred_ts.dump_tables()
    # Set the 2 oldest nodes to sensible times
    # NOTE(review): this assumes the last two rows of the node table are the
    # two oldest nodes - confirm against the node ordering tsinfer produces.
    time = tables.nodes.time
    time[-2] = 1
    time[-1] = 1.001
    tables.nodes.time = time
    if bad_anc_id not in ancestor_node_map:
        print(
            f" (a node corresponding to focal ancestor {bad_anc_id} at site {bad_site_id} is not present in the inferred ts)")
        focal_node = None
    else:
        print(
            " (The magenta-labeled node in the first plot below is the ancestor " +
            f"on which the focal site {bad_site_id} exists)")
        focal_node = ancestor_node_map[bad_anc_id]
        # Mark the focal ancestor as a sample, so it is plotted
        flags = tables.nodes.flags
        flags[focal_node] |= tskit.NODE_IS_SAMPLE
        tables.nodes.flags = flags
    temp_inferred_ts = tables.tree_sequence()

    # Index of the bad site within the ancestors' (inference-site) coordinates.
    bad_anc_site_id = np.where(np.isin(ancestors.sites_position[:], bad_pos))[0]
    focal_time = ancestors.ancestors_time[bad_anc_id]
    print(f" (inferred site id {bad_anc_site_id}, focal ancestor id {bad_anc_id},"
          f" line drawn at time of focal ancestor = {focal_time})")
    assert len(bad_anc_site_id) == 1
    anc_site = int(bad_anc_site_id[0])

    # Nodes whose ancestor haplotype covers the site and is non-zero there.
    has_derived = []
    for ancestor in ancestors.ancestors():
        if ancestor.start <= anc_site < ancestor.end:
            if ancestor.haplotype[anc_site - ancestor.start] != 0:
                if ancestor.id in ancestor_node_map:
                    has_derived.append(ancestor_node_map[ancestor.id])

    style = ".inf svg .node > .sym {fill: grey} .inf svg .node > .lab {fill: grey} .background {display:none} .y-axis .tick .grid {stroke: lightgrey}"
    # Only emit a rule when it has a selector: an empty selector list (or a
    # ".nNone" class) would otherwise produce a meaningless CSS rule.
    if focal_node is not None:
        style += f".inf svg .n{focal_node} > .lab {{fill: magenta}}"
    if has_derived:
        style += ",".join([".inf svg .n{} > .sym".format(nd) for nd in has_derived]) + "{fill: magenta}"
    # also mark the path compressed nodes, as they won't have haplotypes
    pc_selectors = [
        ".inf svg .n{} > .sym".format(nd.id)
        for nd in inferred_ts.nodes()
        if (nd.flags & tsinfer.NODE_IS_PC_ANCESTOR)
    ]
    if pc_selectors:
        style += ",".join(pc_selectors) + "{fill: yellow}"

    # Smallest possible x-range starting at the site, so only one tree is drawn.
    lims = [bad_pos, np.nextafter(bad_pos, bad_pos+1)]
    display(HTML(
        f"<style>{style}</style>"
        + "<table><tr><td class='inf'>{}</td><td>{}</td></tr></table>".format(
            temp_inferred_ts.draw_svg(
                x_lim=lims, size=(400, 400), x_label="Inferred tree sequence", time_scale="time", y_axis=True,
                y_label="", y_ticks=[focal_time], y_gridlines=True),
            ts.draw_svg(
                x_lim=lims, size=(400, 400), x_label="Original (simulated) tree sequence", time_scale="rank")
        )))
    return ancestors, inferred_ts, node_ancestor_map
# Simulate ancestry plus mutations for 5 samples, then inspect the 3rd site
# that lacks a mutation above its focal ancestor.
r = 1e-6
ts = msprime.sim_mutations(
    msprime.sim_ancestry(
        samples=5, sequence_length=1e8, recombination_rate=r, random_seed=2
    ),
    rate=r,  # same mutation rate as rec rate
    random_seed=2,
)
ancestors, inferred_ts, node_ancestor_map = plot_nonmatching(
    ts=ts,
    rec_rate=r,
    which_bad_site=3,  # Change me to look at different sites
    mismatch_ratio=1,
)
Note that this code allows us to specify which_bad_site to plot and the mismatch ratio, so we can experiment, looking for unusual/unexpected patterns. Hopefully, as the mismatch ratio gets smaller, we should tend towards the correct topology (as there is no injected sequencing error in this example code).
Here's another case from the same example, in which the focal ancestor (14) is actually missing as a node from the entire inferred tree sequence (i.e. not in any tree). Although the topology isn't quite right, we could still trivially reduce the number of mutations to 2 rather than 3 by judicious placement of mutations. I guess we might get a better result if we somehow caused decent mapping to the focal ancestor.
# Same simulation as before, this time inspecting the 7th site that lacks a
# mutation above its focal ancestor.
r = 1e-6
ts = msprime.sim_mutations(
    msprime.sim_ancestry(
        samples=5, sequence_length=1e8, recombination_rate=r, random_seed=2
    ),
    rate=r,  # same mutation rate as rec rate
    random_seed=2,
)
ancestors, inferred_ts, node_ancestor_map = plot_nonmatching(
    ts=ts,
    rec_rate=r,
    which_bad_site=7,  # Change me to look at different sites
    mismatch_ratio=1,
)
I wonder if it is simply coincidence that the MRCA node which unites the clades exists at the same epoch as the actual (correct) focal ancestor? Another example of the same sort is the 19th bad site in this example:
# Same simulation as before, this time inspecting the 19th site that lacks a
# mutation above its focal ancestor.
r = 1e-6
ts = msprime.sim_mutations(
    msprime.sim_ancestry(
        samples=5, sequence_length=1e8, recombination_rate=r, random_seed=2
    ),
    rate=r,  # same mutation rate as rec rate
    random_seed=2,
)
ancestors, inferred_ts, node_ancestor_map = plot_nonmatching(
    ts=ts,
    rec_rate=r,
    which_bad_site=19,  # Change me to look at different sites
    mismatch_ratio=1,
)
I wonder if it is simply coincidence that the MRCA node which unites the clades exists at the same epoch as the actual (correct) focal ancestor? Another example of the same sort is the 19th bad site in this example:
The code above allows us to plug in a different ancestor-generation function. For instance, this will treat sites of the same age as potentially containing the derived allele:
class ModifiedAncestorBuilder(tsinfer.algorithm.AncestorBuilder):
    """
    Ancestor builder in which sites *at least as old* as the focal site
    (``>=``) are given a consensus state, where the unmodified comparison
    was a strict ``>``. The two comparisons that differ are marked MODIFIED.
    """

    def compute_ancestral_states(self, a, focal_site, sites):
        """
        Extend the haplotype ``a`` away from ``focal_site`` over ``sites``.

        ``sites`` is iterated in the order given (usually every site to the
        left, or every site to the right, of the focal site). Together with
        ``make_ancestor``, which calls this method, this describes the main
        algorithm as implemented in Fig S2 of the preprint, with the buffer.

        Returns the index of the last site visited (``focal_site`` itself if
        ``sites`` is empty).
        """
        focal_time = self.sites[focal_site].time
        # Samples carrying the derived allele at the focal site.
        carriers = set(np.where(self.sites[focal_site].genotypes == 1)[0])
        # Stop extending once half of the carriers have been lost.
        stop_size = len(carriers) // 2
        # Samples that disagreed at the previous consensus site; they are only
        # dropped if they disagree again at the next one.
        pending_removal = []
        last_site = focal_site
        for site_index in sites:
            a[site_index] = 0
            last_site = site_index
            # MODIFIED: the unmodified code used a strict `>` here.
            old_enough = self.sites[site_index].time >= focal_time
            if not old_enough:
                continue
            genotypes = self.sites[site_index].genotypes
            num_ones = sum(1 for u in carriers if genotypes[u] == 1)
            num_zeros = sum(1 for u in carriers if genotypes[u] == 0)
            if num_ones + num_zeros == 0:
                # No carrier has a 0 or 1 here, so the state is unknown.
                a[site_index] = tskit.MISSING_DATA
                continue
            consensus = 1 if num_ones >= num_zeros else 0
            for u in pending_removal:
                if genotypes[u] != consensus and genotypes[u] != tskit.MISSING_DATA:
                    carriers.remove(u)
            a[site_index] = consensus
            if len(carriers) <= stop_size:
                break
            pending_removal = [
                u
                for u in carriers
                if genotypes[u] != consensus and genotypes[u] != tskit.MISSING_DATA
            ]
        return last_site

    def make_ancestor(self, focal_sites, a):
        """
        Fill the array ``a`` with the ancestral haplotype for ``focal_sites``
        and return the ``(start, end)`` site indexes of the ancestor, with
        ``end`` one past the last site reached on the right.

        Raises ValueError if no sample carries the derived allele at the
        first focal site.
        """
        focal_time = self.sites[focal_sites[0]].time
        # All focal sites of one ancestor must share a single time point.
        assert all(self.sites[fs].time == focal_time for fs in focal_sites)
        a[:] = tskit.MISSING_DATA
        for fs in focal_sites:
            a[fs] = 1
        carriers = set(np.where(self.sites[focal_sites[0]].genotypes == 1)[0])
        if not carriers:
            raise ValueError("Cannot compute ancestor for a site at freq 0")

        # Interpolate the haplotype inside the focal region, i.e. between each
        # pair of adjacent focal sites.
        for left_focal, right_focal in zip(focal_sites, focal_sites[1:]):
            for site_index in range(left_focal + 1, right_focal):
                a[site_index] = 0
                # MODIFIED: the unmodified code used a strict `>` here.
                if self.sites[site_index].time >= focal_time:
                    genotypes = self.sites[site_index].genotypes
                    num_ones = sum(1 for u in carriers if genotypes[u] == 1)
                    num_zeros = sum(1 for u in carriers if genotypes[u] == 0)
                    if num_ones + num_zeros == 0:
                        a[site_index] = tskit.MISSING_DATA
                    elif num_ones >= num_zeros:
                        a[site_index] = 1

        # Extend rightwards from the rightmost focal site.
        rightmost = focal_sites[-1]
        end = 1 + self.compute_ancestral_states(
            a, rightmost, range(rightmost + 1, self.num_sites)
        )
        # Extend leftwards from the leftmost focal site.
        leftmost = focal_sites[0]
        start = self.compute_ancestral_states(
            a, leftmost, range(leftmost - 1, -1, -1)
        )
        return start, end
class ModifiedAncestorsGenerator(tsinfer.inference.AncestorsGenerator):
    """
    Ancestors generator that builds ancestors with ``ModifiedAncestorBuilder``.

    Only the Python engine is supported, as the modified builder has no C
    counterpart.
    """

    def __init__(
        self,
        sample_data,
        ancestor_data,
        num_threads=0,
        engine=tsinfer.constants.PY_ENGINE,
        progress_monitor=None,
    ):
        """
        Set up the generator state directly; the base-class ``__init__`` is
        not called.

        :param sample_data: Source of ``num_sites`` and ``num_samples``.
        :param ancestor_data: Stored as ``self.ancestor_data``.
        :param num_threads: Stored as ``self.num_threads``.
        :param engine: Must be ``tsinfer.constants.PY_ENGINE``.
        :param progress_monitor: Passed to tsinfer's ``_get_progress_monitor``.
        :raises NotImplementedError: If the C engine is requested.
        :raises ValueError: If ``engine`` is not a recognised engine.
        """
        self.sample_data = sample_data
        self.ancestor_data = ancestor_data
        self.progress_monitor = tsinfer.inference._get_progress_monitor(
            progress_monitor, generate_ancestors=True
        )
        self.max_sites = sample_data.num_sites
        self.num_sites = 0
        self.num_samples = sample_data.num_samples
        self.num_threads = num_threads
        if engine == tsinfer.constants.C_ENGINE:
            # The original had an unreachable `_tsinfer.AncestorBuilder(...)`
            # assignment after this raise; `_tsinfer` was never imported.
            raise NotImplementedError("Modified version only in python")
        elif engine == tsinfer.constants.PY_ENGINE:
            self.ancestor_builder = ModifiedAncestorBuilder(
                self.num_samples, self.max_sites
            )
        else:
            raise ValueError(f"Unknown engine:{engine}")
def new_generate_ancestors(
    sample_data,
    *,
    path=None,
    exclude_positions=None,
    num_threads=0,
    # Deliberately undocumented parameters below
    engine=tsinfer.constants.PY_ENGINE,
    progress_monitor=False,
    **kwargs,
):
    """
    Generate ancestors from ``sample_data`` using ``ModifiedAncestorsGenerator``.

    ``path`` and any extra keyword arguments are passed on to
    ``tsinfer.formats.AncestorData``; ``exclude_positions`` is passed to the
    generator's ``add_sites``. Returns the resulting ancestor data object,
    with a "modified-generate-ancestors" provenance record added.
    """
    sample_data._check_finalised()
    generator_options = {
        "num_threads": num_threads,
        "engine": engine,
        "progress_monitor": progress_monitor,
    }
    with tsinfer.formats.AncestorData(sample_data, path=path, **kwargs) as ancestor_data:
        anc_generator = ModifiedAncestorsGenerator(
            sample_data, ancestor_data, **generator_options
        )
        anc_generator.add_sites(exclude_positions)
        anc_generator.run()
        ancestor_data.record_provenance("modified-generate-ancestors")
    return ancestor_data
# Repeat the 19th-bad-site example, but build the ancestors with the modified
# generator defined above instead of tsinfer's default.
r = 1e-6
ts = msprime.sim_mutations(
    msprime.sim_ancestry(
        samples=5, sequence_length=1e8, recombination_rate=r, random_seed=2
    ),
    rate=r,  # same mutation rate as rec rate
    random_seed=2,
)
ancestors, inferred_ts, node_ancestor_map = plot_nonmatching(
    ts=ts,
    rec_rate=r,
    which_bad_site=19,  # Change me to look at different sites
    mismatch_ratio=1,
    generate_ancestors=new_generate_ancestors,
)
In this case this modification actually does worse, creating 1631 rather than 1501 mutations and 124 rather than 118 trees:
Using simulated TS with 949 trees and 1207 sites
1631 mutations, 124 trees. 60.05 % of inference sites have a mutation above the focal ancestor
314 sites have no mutation above the focal ancestor (which could be absent in the tree)
Plotting nth example (n=19) of a site (pos 93867449.0) without a focal ancestor mutation.
We know that we place too many mutations when inferring with a non-zero mismatch ratio (i.e. when specifying a recombination rate). We should investigate if there are any easy ways to reduce this.