tskit-dev / tsinfer

Infer a tree sequence from genetic variation data.
GNU General Public License v3.0
56 stars 13 forks source link

Investigate causes of multiple mutations #568

Open hyanwong opened 3 years ago

hyanwong commented 3 years ago

We know that we place too many mutations when inferring with a non-zero mismatch ratio (i.e. when specifying a recombination rate). We should investigate if there are any easy ways to reduce this.

hyanwong commented 3 years ago

A good place to start is the example in https://github.com/tskit-dev/tsinfer/issues/429#issuecomment-891818721, where we could have placed a single mutation higher up the tree, and reduced the number of mutations required without actually changing the topology.

Here's another odd example (in this code we also specify a non-zero mismatch ratio in the match_samples phase). We get the topology more or less right, but also place a mutation, and then a reversion back to the original state, on the lone branch leading to sample 3. This happens because the intermediate nodes against which sample 3 is matching (83 and 230) have the derived state at site 382.

import msprime
import tsinfer
import tskit
import json
import collections
import numpy as np
from IPython.display import HTML, SVG

def plot_nonmatching(
    ts, rec_rate, which_bad_site,
    mismatch_ratio=1, path_compression=True, generate_ancestors=tsinfer.generate_ancestors
):
    """
    Infer a tree sequence from ``ts`` and plot one inference site whose focal
    ancestor does not carry a mutation in the inferred tree sequence.

    The simulated ``ts`` is converted to sample data (without site times).
    Ancestors are built with ``generate_ancestors`` using the Python engine.
    Ancestors and then samples are matched with the given recombination rate,
    mismatch ratio and path-compression setting. Summary counts are printed.
    The inferred and the original trees at the chosen site are then drawn
    side by side.

    :param ts: The simulated tree sequence to infer from.
    :param rec_rate: Recombination rate passed to both matching steps.
    :param int which_bad_site: 1-based index into the inference sites whose
        focal ancestor has no mutation above it; selects which one to plot.
    :param mismatch_ratio: Mismatch ratio passed to both matching steps.
    :param bool path_compression: Passed to both matching steps.
    :param generate_ancestors: Callable used to build the ancestors; it is
        called as ``generate_ancestors(sample_data, engine=tsinfer.PY_ENGINE)``.
    :return: ``(ancestors, inferred_ts, node_ancestor_map)``.
    :raises ValueError: If ``which_bad_site`` is not between 1 and the number
        of sites lacking a focal-ancestor mutation.
    """
    # `display` is only an implicit global inside IPython/Jupyter; import it so
    # the function also runs outside a notebook.
    from IPython.display import display

    print("Using simulated TS with", ts.num_trees, "trees and", ts.num_sites, "sites")
    sample_data = tsinfer.SampleData.from_tree_sequence(ts, use_sites_time=False)
    ancestors = generate_ancestors(sample_data, engine=tsinfer.PY_ENGINE)
    #ancestors = ancestors.truncate_ancestors(0.4, 0.6)
    match_kwargs = dict(
        recombination_rate=rec_rate,
        mismatch_ratio=mismatch_ratio,
        path_compression=path_compression,
    )
    ancestors_ts = tsinfer.match_ancestors(sample_data, ancestors, **match_kwargs)
    inferred_ts = tsinfer.match_samples(sample_data, ancestors_ts, **match_kwargs)

    # which focal ancestors have a mutation above them?
    node_ancestor_map, ancestor_node_map = _map_nodes_to_ancestors(inferred_ts)
    # For each site: the IDs of ancestors whose node carries a mutation there.
    site_ancestors = {
        s.id: {node_ancestor_map[m.node] for m in s.mutations if m.node in node_ancestor_map}
        for s in inferred_ts.sites()
    }

    # Map each inference site (as a sample-data site ID) to its focal ancestor.
    site_focal_ancestors = {}
    site_map = np.where(np.isin(sample_data.sites_position[:], ancestors.sites_position[:]))[0]
    for ancestor in ancestors.ancestors():
        for true_site_id in site_map[ancestor.focal_sites]:
            assert true_site_id not in site_focal_ancestors  # check only 1 focal anc per site
            site_focal_ancestors[true_site_id] = ancestor.id
    # Every inference site should be focal for exactly one ancestor.
    assert len(site_focal_ancestors) == ancestors_ts.num_sites, (
        f"{len(site_focal_ancestors)} focal sites vs "
        f"{ancestors_ts.num_sites} inference sites"
    )

    # Inference sites whose focal ancestor carries no mutation there
    # (insertion order follows site_focal_ancestors).
    sites_without_focal = {
        site_id: focal_ancestor
        for site_id, focal_ancestor in site_focal_ancestors.items()
        if focal_ancestor not in site_ancestors[site_id]
    }
    num_inference_sites_with_focal = len(site_focal_ancestors) - len(sites_without_focal)
    if site_focal_ancestors:
        percent_with_focal = num_inference_sites_with_focal / len(site_focal_ancestors) * 100
    else:
        percent_with_focal = 0.0

    print(
        inferred_ts.num_mutations,
        "mutations,",
        inferred_ts.num_trees,
        "trees.",
        f"{percent_with_focal:.2f}",
        "% of inference sites have a mutation above the focal ancestor"
    )
    print(
        len(sites_without_focal),
        "sites have no mutation above the focal ancestor (which could be absent in the tree)"
    )

    # `which_bad_site` is 1-based: 1 selects the first such site.
    if not 1 <= which_bad_site <= len(sites_without_focal):
        raise ValueError(
            f"which_bad_site must be between 1 and {len(sites_without_focal)}, "
            f"got {which_bad_site}"
        )
    bad_site_id, bad_anc_id = list(sites_without_focal.items())[which_bad_site - 1]

    bad_pos = inferred_ts.site(bad_site_id).position
    print(f"Plotting nth example (n={which_bad_site}) of a site (pos {bad_pos}) without a focal ancestor mutation.")
    print(" Magenta nodes in the inferred TS have the derived allele (yellow are PC nodes for which info is unknown)")
    tables = inferred_ts.dump_tables()
    # Set the last two nodes in the node table to sensible times for plotting.
    # NOTE(review): this assumes they are the two oldest nodes -- confirm for
    # the tsinfer version in use.
    time = tables.nodes.time
    time[-2] = 1
    time[-1] = 1.001
    tables.nodes.time = time
    if bad_anc_id not in ancestor_node_map:
        print(
            f" (a node corresponding to focal ancestor {bad_anc_id} at site {bad_site_id} is not present in the inferred ts)")
        focal_node = None
    else:
        print(
            " (The magenta-labeled node in the first plot below is the ancestor " +
            f"on which the focal site {bad_site_id} exists)")
        focal_node = ancestor_node_map[bad_anc_id]
        # Mark the focal ancestor as a sample, so it is plotted
        flags = tables.nodes.flags
        flags[focal_node] |= tskit.NODE_IS_SAMPLE
        tables.nodes.flags = flags
    temp_inferred_ts = tables.tree_sequence()

    # Index of the bad site within the ancestors' (inference) site list.
    matching_site_ids = np.where(np.isin(ancestors.sites_position[:], bad_pos))[0]
    assert len(matching_site_ids) == 1
    bad_anc_site_id = int(matching_site_ids[0])
    focal_time = ancestors.ancestors_time[bad_anc_id]
    print(f" (inferred site id {bad_anc_site_id}, focal ancestor id {bad_anc_id},"
          f" line drawn at time of focal ancestor = {focal_time})")

    # Nodes of ancestors that span the bad site and whose haplotype there is not 0.
    has_derived = [
        ancestor_node_map[ancestor.id]
        for ancestor in ancestors.ancestors()
        if ancestor.start <= bad_anc_site_id < ancestor.end
        and ancestor.haplotype[bad_anc_site_id - ancestor.start] != 0
        and ancestor.id in ancestor_node_map
    ]
    # also mark the path compressed nodes, as they won't have haplotypes
    pc_nodes = [nd.id for nd in inferred_ts.nodes() if nd.flags & tsinfer.NODE_IS_PC_ANCESTOR]

    style = (
        ".inf svg .node > .sym {fill: grey} .inf svg .node > .lab {fill: grey} "
        ".background {display:none} .y-axis .tick .grid {stroke: lightgrey}"
    )
    if focal_node is not None:
        style += f".inf svg .n{focal_node} > .lab {{fill: magenta}}"
    style += _css_rule([f".inf svg .n{nd} > .sym" for nd in has_derived], "fill: magenta")
    style += _css_rule([f".inf svg .n{nd} > .sym" for nd in pc_nodes], "fill: yellow")

    lims = [bad_pos, np.nextafter(bad_pos, bad_pos + 1)]
    inferred_svg = temp_inferred_ts.draw_svg(
        x_lim=lims, size=(400, 400), x_label="Inferred tree sequence", time_scale="time", y_axis=True,
        y_label="", y_ticks=[focal_time], y_gridlines=True)
    original_svg = ts.draw_svg(
        x_lim=lims, size=(400, 400), x_label="Original (simulated) tree sequence", time_scale="rank")
    display(HTML(
        f"<style>{style}</style>"
        f"<table><tr><td class='inf'>{inferred_svg}</td><td>{original_svg}</td></tr></table>"
    ))

    return ancestors, inferred_ts, node_ancestor_map


def _map_nodes_to_ancestors(inferred_ts):
    """
    Return ``(node_ancestor_map, ancestor_node_map)`` built from each node's
    JSON metadata key ``ancestor_data_id``. Nodes with empty metadata are skipped.
    """
    node_ancestor_map = {}
    ancestor_node_map = {}
    for node in inferred_ts.nodes():
        if node.metadata:
            ancestor_id = json.loads(node.metadata)['ancestor_data_id']
            assert ancestor_id not in ancestor_node_map
            node_ancestor_map[node.id] = ancestor_id
            ancestor_node_map[ancestor_id] = node.id
    return node_ancestor_map, ancestor_node_map


def _css_rule(selectors, declaration):
    """Return a CSS rule applying ``declaration`` to ``selectors``, or "" if none."""
    if not selectors:
        return ""
    return ",".join(selectors) + "{" + declaration + "}"

# Simulate with equal recombination and mutation rates, then plot the 3rd
# inference site whose focal ancestor lacks a mutation.
r = 1e-6
ts = msprime.sim_mutations(
    msprime.sim_ancestry(5, sequence_length=1e8, recombination_rate=r, random_seed=2),
    rate=r,  # same mutation rate as rec rate
    random_seed=2,
)
ancestors, inferred_ts, node_ancestor_map = plot_nonmatching(
    ts, rec_rate=r, which_bad_site=3, mismatch_ratio=1  # Change which_bad_site to look at different sites
)
Screenshot 2021-08-03 at 21 57 31

Note that this code lets us choose which bad site to plot (`which_bad_site`) and the mismatch ratio, so we can experiment and look for unusual or unexpected patterns. Hopefully, as the mismatch ratio gets smaller, we should tend towards the correct topology (as this example code injects no sequencing error).

hyanwong commented 3 years ago

Here's another case from the same example, in which the focal ancestor (14) is actually missing as a node from the entire inferred tree sequence (i.e. it is not in any tree). Although the topology isn't quite right, we could still trivially reduce the number of mutations from 3 to 2 by judicious placement of mutations. I guess we might get a better result if we could somehow encourage better matching to the focal ancestor.

# Same simulation as before; this time plot the 7th inference site whose focal
# ancestor lacks a mutation.
r = 1e-6
ts = msprime.sim_mutations(
    msprime.sim_ancestry(5, sequence_length=1e8, recombination_rate=r, random_seed=2),
    rate=r,  # same mutation rate as rec rate
    random_seed=2,
)
ancestors, inferred_ts, node_ancestor_map = plot_nonmatching(
    ts, rec_rate=r, which_bad_site=7, mismatch_ratio=1  # Change which_bad_site to look at different sites
)
Screenshot 2021-08-03 at 18 58 25

I wonder whether it is simply a coincidence that the MRCA node which unites the clades exists at the same epoch as the actual (correct) focal ancestor. Another example of the same sort is the 19th bad site in this example:

# Same simulation again; plot the 19th inference site whose focal ancestor
# lacks a mutation.
r = 1e-6
ts = msprime.sim_mutations(
    msprime.sim_ancestry(5, sequence_length=1e8, recombination_rate=r, random_seed=2),
    rate=r,  # same mutation rate as rec rate
    random_seed=2,
)
ancestors, inferred_ts, node_ancestor_map = plot_nonmatching(
    ts, rec_rate=r, which_bad_site=19, mismatch_ratio=1  # Change which_bad_site to look at different sites
)
Screenshot 2021-08-03 at 17 47 16
hyanwong commented 3 years ago

I wonder whether it is simply a coincidence that the MRCA node which unites the clades exists at the same epoch as the actual (correct) focal ancestor. Another example of the same sort is the 19th bad site in this example:

The code above allows us to plug in a different ancestor-generation function. For instance, the following will treat sites of the same age as the focal site as potentially carrying the derived allele:

class ModifiedAncestorBuilder(tsinfer.algorithm.AncestorBuilder):
    """
    Python AncestorBuilder variant used to experiment with ancestor building.

    Overrides ``compute_ancestral_states`` and ``make_ancestor``. The lines
    marked ``### MODIFIED CODE`` are the experimental change. A non-focal site
    now contributes to an ancestor's haplotype when its time is greater than
    *or equal to* the focal time. The commented-out lines show the strict ``>``
    test they replace. As a result, sites of the same age as the focal site can
    be given the derived allele.

    NOTE(review): these methods presumably mirror the parent
    ``tsinfer.algorithm.AncestorBuilder`` implementation apart from the marked
    lines. They are kept near-verbatim so the two can be diffed; confirm
    against the tsinfer version in use.
    """

    def compute_ancestral_states(self, a, focal_site, sites):
        """
        For a given focal site, and set of sites to fill in (usually all the ones
        leftwards or rightwards), augment the haplotype array a with the inferred sites
        Together with `make_ancestor`, which calls this function, these describe the main
        algorithm as implemented in Fig S2 of the preprint, with the buffer.

        Sites are visited in the order given, and each visited site is first set
        to 0. A site whose time is >= the focal site's time is then set to the
        majority allele (ties -> 1) among the current sample set S. If every
        sample in S is missing at that site, it is set to tskit.MISSING_DATA
        instead. S starts as the samples carrying allele 1 at the focal site.
        A sample is only removed from S after disagreeing with the consensus at
        two successive such sites (the "buffer"). Extension stops once S has
        shrunk to at most half of its initial size.

        :param a: Haplotype array, modified in place.
        :param focal_site: Index of the focal site the extension starts from.
        :param sites: Iterable of site indexes to visit, in order.
        :return: Index of the last site visited (``focal_site`` if ``sites``
            is empty).
        """
        focal_time = self.sites[focal_site].time
        S = set(np.where(self.sites[focal_site].genotypes == 1)[0])
        # Break when we've lost half of S
        min_sample_set_size = len(S) // 2
        # Samples that disagreed with the consensus at the previous informative
        # site; they are removed if they disagree again at the next one.
        remove_buffer = []
        last_site = focal_site
        # print("Focal site", focal_site, "time", focal_time)
        for site_index in sites:
            # Default: sites younger than the focal site get the ancestral allele.
            a[site_index] = 0
            last_site = site_index
            #if self.sites[site_index].time > focal_time: ### MODIFIED CODE
            if self.sites[site_index].time >= focal_time:
                g_l = self.sites[site_index].genotypes
                ones = sum(g_l[u] == 1 for u in S)
                zeros = sum(g_l[u] == 0 for u in S)
                # print("pos", site_index, ". Ones:", ones, ". Zeros:", zeros)
                if ones + zeros == 0:
                    # No sample in S has data here: leave the buffer untouched.
                    a[site_index] = tskit.MISSING_DATA
                else:
                    consensus = 1 if ones >= zeros else 0
                    # print("\tP", site_index, "\t", len(S), ":ones=", ones, consensus)
                    # Second consecutive disagreement -> drop the sample from S.
                    for u in remove_buffer:
                        if g_l[u] != consensus and g_l[u] != tskit.MISSING_DATA:
                            # print("\t\tremoving", u)
                            S.remove(u)
                    a[site_index] = consensus
                    # print("\t", len(S), remove_buffer, consensus, sep="\t")
                    if len(S) <= min_sample_set_size:
                        # print("BREAKING", len(S), min_sample_set_size)
                        break
                    # First disagreement at this site -> candidates for removal.
                    remove_buffer.clear()
                    for u in S:
                        if g_l[u] != consensus and g_l[u] != tskit.MISSING_DATA:
                            remove_buffer.append(u)
        #assert a[last_site] != tskit.MISSING_DATA
        return last_site

    def make_ancestor(self, focal_sites, a):
        """
        Fills out the array a with the haplotype
        return the start and end of an ancestor

        Every entry of ``a`` starts as tskit.MISSING_DATA and each focal site is
        set to 1. Sites strictly between consecutive focal sites are set to 0.
        If such a site's time is >= the focal time, it is instead set to the
        majority allele (ties -> 1) among the samples carrying allele 1 at the
        first focal site, or to MISSING_DATA if all of those samples are missing.
        The haplotype is then extended rightwards from the last focal site and
        leftwards from the first with ``compute_ancestral_states``. Sites beyond
        where each extension stops remain MISSING_DATA.

        :param focal_sites: Site indexes of this ancestor's focal sites, all of
            which must have the same time (presumably in increasing order,
            given the interpolation ranges below -- confirm).
        :param a: Haplotype array with one entry per site, modified in place.
        :return: ``(start, end)``. ``start`` is the leftmost site visited and
            ``end`` is one past the rightmost site visited.
        :raises ValueError: If no sample carries allele 1 at the first focal site.
        """
        focal_time = self.sites[focal_sites[0]].time
        # check all focal sites in this ancestor are at the same timepoint
        assert all([self.sites[fs].time == focal_time for fs in focal_sites])

        a[:] = tskit.MISSING_DATA
        for focal_site in focal_sites:
            a[focal_site] = 1
        S = set(np.where(self.sites[focal_sites[0]].genotypes == 1)[0])
        if len(S) == 0:
            raise ValueError("Cannot compute ancestor for a site at freq 0")
        # Interpolate ancestral haplotype within focal region (i.e. region
        #  spanning from leftmost to rightmost focal site)
        for j in range(len(focal_sites) - 1):
            # Interpolate region between focal site j and focal site j+1
            for site_index in range(focal_sites[j] + 1, focal_sites[j + 1]):
                a[site_index] = 0
                #if self.sites[site_index].time > focal_time: ### MODIFIED CODE
                if self.sites[site_index].time >= focal_time:
                    g_l = self.sites[site_index].genotypes
                    ones = sum(g_l[u] == 1 for u in S)
                    zeros = sum(g_l[u] == 0 for u in S)
                    # print("\t", site_index, ones, zeros, sep="\t")
                    if ones + zeros == 0:
                        a[site_index] = tskit.MISSING_DATA
                    elif ones >= zeros:
                        a[site_index] = 1
        # Extend ancestral haplotype rightwards from rightmost focal site
        focal_site = focal_sites[-1]
        last_site = self.compute_ancestral_states(
            a, focal_site, range(focal_site + 1, self.num_sites)
        )
        #assert a[last_site] != tskit.MISSING_DATA
        end = last_site + 1
        # Extend ancestral haplotype leftwards from leftmost focal site
        focal_site = focal_sites[0]
        last_site = self.compute_ancestral_states(
            a, focal_site, range(focal_site - 1, -1, -1)
        )
        start = last_site
        return start, end

class ModifiedAncestorsGenerator(tsinfer.inference.AncestorsGenerator):
    """
    AncestorsGenerator that builds ancestors with ``ModifiedAncestorBuilder``.

    ``__init__`` sets its attributes directly rather than calling
    ``super().__init__()``. Presumably this is so the stock builder is never
    constructed. NOTE(review): confirm no other parent attributes are needed
    for the tsinfer version in use.

    Only the Python engine is supported, because the modification exists only
    in the Python builder.
    """

    def __init__(
        self,
        sample_data,
        ancestor_data,
        num_threads=0,
        engine=tsinfer.constants.PY_ENGINE,
        progress_monitor=None,
    ):
        """
        :param sample_data: Source ``SampleData``.
        :param ancestor_data: ``AncestorData`` to be filled in.
        :param num_threads: Stored as ``self.num_threads``.
        :param engine: Must be ``tsinfer.constants.PY_ENGINE``.
        :param progress_monitor: Passed to tsinfer's progress-monitor factory.
        :raises NotImplementedError: If ``engine`` is the C engine.
        :raises ValueError: If ``engine`` is unrecognised.
        """
        self.sample_data = sample_data
        self.ancestor_data = ancestor_data
        self.progress_monitor = tsinfer.inference._get_progress_monitor(
            progress_monitor, generate_ancestors=True
        )
        self.max_sites = sample_data.num_sites
        self.num_sites = 0
        self.num_samples = sample_data.num_samples
        self.num_threads = num_threads
        if engine == tsinfer.constants.C_ENGINE:
            raise NotImplementedError("Modified version only in python")
        if engine != tsinfer.constants.PY_ENGINE:
            raise ValueError(f"Unknown engine:{engine}")
        self.ancestor_builder = ModifiedAncestorBuilder(
            self.num_samples, self.max_sites
        )

def new_generate_ancestors(
    sample_data,
    *,
    path=None,
    exclude_positions=None,
    num_threads=0,
    # Deliberately undocumented parameters below
    engine=tsinfer.constants.PY_ENGINE,
    progress_monitor=False,
    **kwargs,
):
    """
    Generate ancestors using ``ModifiedAncestorsGenerator``, mirroring the
    ``tsinfer.generate_ancestors`` calling convention.

    :param sample_data: The ``SampleData`` to build ancestors from; it must be
        finalised.
    :param path: Passed to ``tsinfer.formats.AncestorData``.
    :param exclude_positions: Passed to the generator's ``add_sites``.
    :param num_threads: Passed to the generator.
    :param engine: Passed to the generator, which only accepts the Python engine.
    :param progress_monitor: Passed to the generator.
    :param kwargs: Extra keyword arguments for ``tsinfer.formats.AncestorData``.
    :return: The populated ``AncestorData``, with a
        "modified-generate-ancestors" provenance record.
    """
    sample_data._check_finalised()
    generator_options = {
        "num_threads": num_threads,
        "engine": engine,
        "progress_monitor": progress_monitor,
    }
    with tsinfer.formats.AncestorData(sample_data, path=path, **kwargs) as result:
        builder = ModifiedAncestorsGenerator(sample_data, result, **generator_options)
        builder.add_sites(exclude_positions)
        builder.run()
        result.record_provenance("modified-generate-ancestors")
    return result

# Same simulation as the earlier examples; plot the 19th bad site again, but
# build ancestors with the modified (>= focal time) ancestor generator.
r = 1e-6
ts = msprime.sim_mutations(
    msprime.sim_ancestry(5, sequence_length=1e8, recombination_rate=r, random_seed=2),
    rate=r,  # same mutation rate as rec rate
    random_seed=2,
)
ancestors, inferred_ts, node_ancestor_map = plot_nonmatching(
    ts,
    rec_rate=r,
    which_bad_site=19,  # Change me to look at different sites
    mismatch_ratio=1,
    generate_ancestors=new_generate_ancestors,
)

In this case this modification actually does worse, creating 1631 rather than 1501 mutations and 124 rather than 118 trees:

Using simulated TS with 949 trees and 1207 sites
1631 mutations, 124 trees. 60.05 % of inference sites have a mutation above the focal ancestor
314 sites have no mutation above the focal ancestor (which could be absent in the tree)
Plotting nth example (n=19) of a site (pos 93867449.0) without a focal ancestor mutation.