tskit-dev / tsinfer

Infer a tree sequence from genetic variation data.
GNU General Public License v3.0

Investigate why sample nodes attach to deep roots / ultimate ancestor #903

Open hyanwong opened 8 months ago

hyanwong commented 8 months ago

In trees from real data, we often find sample nodes attached via long branches to deep in the trees, often to the root. We should investigate why this is happening.

I wondered at first if some of these deep-rooted samples would disappear if we got the sites in the correct order, which seems to produce much better dating results. But some experimentation suggests this doesn't help much (and, unsurprisingly, neither does a round of dating + reinference). The plot below shows that in the true trees, regardless of the sample size, very few sample nodes attach to the root (solid lines). But as soon as any inference is involved, even with the true site dates (dot-dashed lines), we get roughly the same distribution of samples attached to their local root as with the other inference attempts (dotted and dashed lines).

Screenshot 2024-02-14 at 08 29 21
import tsinfer
import tsdate
import numpy as np
import stdpopsim
species = stdpopsim.get_species("HomSap")
model = species.get_demographic_model("AmericanAdmixture_4B11")
contig = species.get_contig("chr22", mutation_rate=model.mutation_rate, length_multiplier=0.05)
engine = stdpopsim.get_engine("msprime")
true_data = {}
inferred_data = {}
truedate_data = {}
infdated_data = {}

n_reps = 5
for s in (1, 10, 100):
    samples = {"AFR": s, "EUR": s, "ASIA": s, "ADMIX": s}
    num_samples = sum(samples.values()) * 2
    true_data[num_samples] = np.zeros(num_samples)
    inferred_data[num_samples] = np.zeros(num_samples)
    truedate_data[num_samples] = np.zeros(num_samples)
    infdated_data[num_samples] = np.zeros(num_samples)
    for rep in range(n_reps):
        sim_ts = engine.simulate(model, contig, samples, seed=rep+123)
        print(sim_ts.num_samples, "samples,", sim_ts.num_trees, "trees,", sim_ts.num_sites, "sites")
        tTS = tsinfer.infer(tsinfer.SampleData.from_tree_sequence(sim_ts, use_sites_time=True)).simplify()
        iTS = tsinfer.infer(tsinfer.SampleData.from_tree_sequence(sim_ts, use_sites_time=False)).simplify()
        dTS = tsdate.variational_gamma(
            tsdate.util.split_disjoint_nodes(iTS.simplify()),
            mutation_rate=model.mutation_rate,
            population_size=sim_ts.diversity() / (4 * model.mutation_rate),
        )
        dTS = tsinfer.infer(tsinfer.SampleData.from_tree_sequence(dTS, use_sites_time=True)).simplify()
        for ts, data in zip(
            (sim_ts, tTS, iTS, dTS),
            (true_data, truedate_data, inferred_data, infdated_data),
        ):
            sample_set = set([u for u in ts.samples()])
            d = np.zeros(num_samples)
            tot = 0
            for tree in ts.trees():
                if tree.num_edges > 0:
                    tot += tree.interval.span
                    n_root_attached = len(set(tree.children(tree.root)) & sample_set)
                    d[n_root_attached] += tree.interval.span
            data[num_samples] += d / tot
    for data in (true_data, truedate_data, inferred_data, infdated_data):
        data[num_samples] /= n_reps

### Plot
from matplotlib import pyplot as plt
for data, label, linetype in zip(
    (true_data, truedate_data, inferred_data, infdated_data),
    ("Actual", "True site dates", "Plain inferred", "Inf+date"),
    ("solid", "dashdot", "dotted", "dashed"),
):
    for i, (k, d) in enumerate(data.items()):
        plt.plot(np.arange(len(d)), d, ls=linetype, c=f"C{i}", label=f"{label} ({k} samples)")
plt.xlabel("Number of root samples")
plt.xscale("log")
plt.xlim(0.9, 100)
plt.ylabel("proportion of genome")
plt.yscale("log")
# For tidiness, change the order of the legend items
plt.legend(*(
    [ x[i] for i in [r*len(data)+i for i in range(len(data)) for r in range(0, 4)] ]
    for x in plt.gca().get_legend_handles_labels()
),  labelspacing=0.25)
hyanwong commented 8 months ago

I imagined that the effect would be made worse by mispolarization of ancestral states. However, a brief test indicates that this is probably not the case:

Screenshot 2024-02-13 at 22 22 14
import tsinfer
import tsdate
import numpy as np
import stdpopsim
species = stdpopsim.get_species("HomSap")
model = species.get_demographic_model("AmericanAdmixture_4B11")
contig = species.get_contig("chr22", mutation_rate=model.mutation_rate, length_multiplier=0.05)
engine = stdpopsim.get_engine("msprime")
true_data = {}
inferred_data = {}
truedate_data = {}
infdated_data = {}

n_reps = 4
s = 12
mispolarise_proportions = [0, 0.01, 0.1]
for mispolarise_proportion in mispolarise_proportions:
    samples = {"AFR": s, "EUR": s, "ASIA": s, "ADMIX": s}
    num_samples = sum(samples.values()) * 2
    true_data[mispolarise_proportion] = np.zeros(num_samples)
    inferred_data[mispolarise_proportion] = np.zeros(num_samples)
    infdated_data[mispolarise_proportion] = np.zeros(num_samples)
    for rep in range(n_reps):
        sim_ts = engine.simulate(model, contig, samples, seed=rep+123)
        sd = tsinfer.SampleData.from_tree_sequence(sim_ts, use_sites_time=False)
        aa = sd.sites_ancestral_allele[:] # perfect polarisation has these all 0 
        to_switch = np.random.choice(np.arange(len(aa)), round(mispolarise_proportion * len(aa)), replace=False)
        aa[to_switch] = 1  # assume all sites have 2 or more alleles: just switch to the first alt
        sd_mispol = sd.copy() # make editable
        sd_mispol.sites_ancestral_allele[:] = aa
        sd_mispol.finalise()
        print(sim_ts.num_samples, "samples,", sim_ts.num_trees, "trees,", sim_ts.num_sites, "sites", mispolarise_proportion, "mp")
        iTS = tsinfer.infer(sd_mispol).simplify()
        dTS = tsdate.variational_gamma(
            tsdate.util.split_disjoint_nodes(iTS.simplify()),
            mutation_rate=model.mutation_rate,
            population_size=sim_ts.diversity() / (4 * model.mutation_rate),
        )
        dTS = tsinfer.infer(tsinfer.SampleData.from_tree_sequence(dTS, use_sites_time=True)).simplify()
        for ts, data in zip(
            (sim_ts, iTS, dTS),
            (true_data, inferred_data, infdated_data),
        ):
            sample_set = set([u for u in ts.samples()])
            d = np.zeros(num_samples)
            tot = 0
            for tree in ts.trees():
                if tree.num_edges > 0:
                    tot += tree.interval.span
                    n_root_attached = len(set(tree.children(tree.root)) & sample_set)
                    d[n_root_attached] += tree.interval.span
            data[mispolarise_proportion] += d / tot
    for data in (true_data, inferred_data, infdated_data):
        data[mispolarise_proportion] /= n_reps

from matplotlib import pyplot as plt
for data, label, linetype in zip(
    (true_data, inferred_data, infdated_data),
    ("Actual", "Plain inferred", "Inf+date"),
    ("solid", "dotted", "dashed"),
):
    for i, (k, d) in enumerate(data.items()):
        plt.plot(np.arange(len(d)), d, ls=linetype, c=f"C{i}", label=f"{label} ({k*100}% mispolarised)")
plt.xlabel("Number of root samples")
plt.xscale("log")
plt.xlim(0.9, 100)
plt.ylabel("proportion of genome")
plt.yscale("log")
# For tidiness, change the order of the legend items
plt.legend(*(
    [ x[i] for i in [r*len(data)+i for i in range(len(data)) for r in range(0, 3)] ]
    for x in plt.gca().get_legend_handles_labels()
),  labelspacing=0.25)
jeromekelleher commented 8 months ago

I think it would also be interesting to count the number of parent=ultimate ancestor edges (and total span) for samples. I think it's clear that we don't expect many samples to connect to the root (a max of 1, for a binary tree), so I don't see much value in comparing back to the source trees here.
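
The count suggested here could be sketched as follows. This is a minimal, dependency-free illustration rather than code from the thread: the function name is hypothetical, and it works on plain sequences. With tskit, the inputs could presumably be taken from `ts.edges_parent`, `ts.edges_child`, `ts.edges_left`, `ts.edges_right` and `ts.samples()`; which node ID represents the ultimate ancestor in a given tsinfer output is left to the caller.

```python
def sample_edges_to_node(edges_parent, edges_child, edges_left, edges_right, samples, target):
    """Count edges that join a sample directly to `target`, and sum their genomic span."""
    sample_set = set(samples)
    n_edges = 0
    total_span = 0.0
    for parent, child, left, right in zip(edges_parent, edges_child, edges_left, edges_right):
        if parent == target and child in sample_set:
            n_edges += 1
            total_span += right - left
    return n_edges, total_span


# Toy edge table (made up for illustration): node 4 stands in for the ultimate
# ancestor, nodes 0-2 are samples, node 3 is an ordinary internal node.
edges_parent = [3, 3, 4, 4, 4]
edges_child = [0, 1, 2, 3, 0]
edges_left = [0.0, 0.0, 0.0, 0.0, 60.0]
edges_right = [60.0, 100.0, 100.0, 100.0, 100.0]

n_edges, total_span = sample_edges_to_node(
    edges_parent, edges_child, edges_left, edges_right, samples=[0, 1, 2], target=4
)
print(n_edges, total_span)
```

Working from the edge table rather than iterating over trees means each sample-to-ancestor edge is counted once, however many trees it spans.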

hyanwong commented 8 months ago

After some discussion, we think it would be useful (a) to see if these deep-rooted samples (especially the ultimate-ancestor ones) are concentrated at the flanks, and (b) whether the switch to the root is caused by running out of ancestral haplotype.
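
One way to put a number on point (a) is to compare the span-weighted mean count of root-attached samples in the two flanks against the middle of the sequence. The sketch below is an assumption-laden illustration: the function name, the `(left, right, n_root_samples)` input format and the flank width are all hypothetical choices, not something defined in tsinfer or tskit.

```python
def mean_root_samples_by_region(trees, sequence_length, flank_width):
    """
    Span-weighted mean number of root-attached samples in the left flank,
    the middle and the right flank of the sequence.

    `trees` is a list of (left, right, n_root_samples) tuples, one per non-empty tree.
    """
    regions = {
        "left": (0, flank_width),
        "middle": (flank_width, sequence_length - flank_width),
        "right": (sequence_length - flank_width, sequence_length),
    }
    result = {}
    for name, (start, end) in regions.items():
        weighted = 0.0
        covered = 0.0
        for left, right, n_root in trees:
            # Length of the part of this tree's interval that falls inside the region
            overlap = min(right, end) - max(left, start)
            if overlap > 0:
                weighted += n_root * overlap
                covered += overlap
        result[name] = weighted / covered if covered > 0 else float("nan")
    return result


# Made-up per-tree counts, with root attachment concentrated at the right-hand end
trees = [(0, 10, 0), (10, 50, 1), (50, 90, 0), (90, 100, 4)]
by_region = mean_root_samples_by_region(trees, sequence_length=100, flank_width=10)
print(by_region)
```

A flank mean much larger than the middle mean would point to edge effects; similar values would suggest the problem is spread along the genome.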

hyanwong commented 7 months ago

(a) to see if these deep-rooted samples (especially the ultimate-ancestor ones) are concentrated at the flanks

It appears that the right-hand side does have an increased number of root-attached samples, but they are also present elsewhere. Below is a demo using an admixed stdpopsim model (which I assume might reasonably reflect dynamic and mixed demographies).

It also appears that reinferring after dating, using the site times, substantially increases the number of root-attached samples. I wonder if this is because we are allocating different times to sites with otherwise identical mutation patterns, and therefore changing the ancestor-building steps. In particular, we might get shorter ancestors because of identical-patterned sites conflicting during build-ancestors:

Screenshot 2024-03-12 at 16 33 50

(grey plots are where we are showing attachment to the local MRCA rather than the gMRCA)


import tsinfer
import tsdate
import numpy as np
import stdpopsim
import scipy
import collections
from matplotlib import pyplot as plt

TsAnc = collections.namedtuple("TsAnc", "ts, anc")

def infer(sd, progress=None, **kwargs):
    anc = tsinfer.generate_ancestors(sd, progress_monitor=progress, **kwargs)
    ats = tsinfer.match_ancestors(sd, anc, progress_monitor=progress)
    final = tsinfer.match_samples(sd, ats, progress_monitor=progress)
    return TsAnc(final, anc)

species = stdpopsim.get_species("HomSap")
model = species.get_demographic_model("AmericanAdmixture_4B11")
contig = species.get_contig("chr22", mutation_rate=model.mutation_rate, length_multiplier=0.05)
engine = stdpopsim.get_engine("msprime")
s = 100
samples = {"AFR": s, "EUR": s, "ASIA": s, "ADMIX": s}
sim_ts = engine.simulate(model, contig, samples, seed=123)
print(sim_ts.num_samples, "samples,", sim_ts.num_trees, "trees,", sim_ts.num_sites, "sites")
true = infer(tsinfer.SampleData.from_tree_sequence(sim_ts, use_sites_time=True), progress=True)
inferred = infer(tsinfer.SampleData.from_tree_sequence(sim_ts, use_sites_time=False), progress=True)
inferred_simp_ts = inferred.ts.simplify(filter_sites=False)
dated_ts = tsdate.variational_gamma(inferred_simp_ts, mutation_rate=model.mutation_rate)
reinferred = infer(tsinfer.SampleData.from_tree_sequence(dated_ts, use_sites_time=True), progress=True)
redated_ts = tsdate.variational_gamma(reinferred.ts.simplify(filter_sites=False), mutation_rate=model.mutation_rate)

# PLOTTING CODE (hacky!)

MockAncestors = collections.namedtuple("MockAncestors", "ancestors_length, ancestors_time")

def plot_root_attached_samples_ts(true_ts, test_ts, dated_test_ts, ax, ylabel="", inset_label="", is_gMRCA=False, max_y=None, log_times=False):
    def get_mut_times(ts):
        t = ts.nodes_time
        return np.array([t[s.mutations[0].node] if len(s.mutations) > 0 else 0 for s in ts.sites()])

    root_str = "grand_MRCA" if is_gMRCA else "MRCA"
    facecolor = "white" if is_gMRCA else "#DDD"

    x, y, v = [], [], []
    dated_times = []
    if hasattr(test_ts, "ts"):
        ts, anc = test_ts.ts, test_ts.anc
    else:
        ts, anc = test_ts, None
    tree_seqs = [ts]
    if dated_test_ts is not None:
        tree_seqs.append(dated_test_ts)

    true_mut_times = None if true_ts is None else get_mut_times(true_ts) 
    for tmp_ts in tree_seqs:
        mut_times = get_mut_times(tmp_ts)
        if true_mut_times is not None:
            use = np.logical_and(true_mut_times > 0, mut_times > 0)
            v.append(scipy.stats.spearmanr(np.log(mut_times[use]), np.log(true_mut_times[use])).statistic)
    L = 0
    mean_rs = 0
    sample_to_root_edges = collections.Counter()
    sample_set = set([u for u in ts.samples()])
    bad_site_positions = []
    for tree in ts.trees():
        if tree.num_edges == 0:
            continue
        L += tree.span
        root_samples = set(tree.children(tree.root)) & sample_set
        if len(root_samples) > 1:
            bad_site_positions += [s.position for s in tree.sites()]
        for m in tree.mutations():
            if m.node in root_samples:
                sample_to_root_edges[m.edge] += 1
        mean_rs += len(root_samples) * tree.span
        y.append(len(root_samples))
        if len(x) == 0:
            x.append(tree.interval.left)
        x.append(tree.interval.right)

    if max_y is None:
        max_y = max(y)  # doesn't work well because max(y) differs between plots

    if len(v) > 0:
        if len(v) == 1:
            txt = f"Spearmans-r mut logtimes:\n   {v[0]:.5f} (tsinfer only)"
        else:
            txt = f"Spearmans-r mut logtimes:\n   {v[0]:.5f} (tsinfer only)\n   {v[1]:.5f} (+tsdate)"
        ax.text(x[0], max_y - max_y/4, txt, va="top")

    ax_r = ax.twinx()
    ax_r.sharey(ax)
    ax_r.scatter(ts.sites_position, [len(s.mutations) for s in ts.sites()], s=1, c="red", alpha=0.1)
    ax_r.set_ylabel("Number of mutations per site", color="r")
    ax.stairs(y, x)
    ax.set_facecolor(facecolor)
    ax.set_ylim(0, max_y)
    ax_r.set_ylim(0, max_y)
    ax.set_ylabel(ylabel)
    ax.text(x[0], max_y - max_y/15, f"Av # sample→{root_str} per tree: {mean_rs/L:.6f}")
    if len(sample_to_root_edges) > 0:
        ax.text(x[0], max_y - max_y/5, f"Av # sites per sample→root edge: {sample_to_root_edges.total()/len(sample_to_root_edges):.6f}")

    # INSET AXES
    date_ax = None
    anc_ax=None
    if anc is not None:
        anc_ax = ax_r.inset_axes([0.75, 0.4, 0.19, 0.57])
        order = np.argsort(anc.ancestors_time)
        anc_ax.hexbin(anc.ancestors_length[order], np.arange(len(anc.ancestors_time)), bins="log", xscale="log")
        anc_ax.set_ylabel("Time rank", fontsize=8)
        anc_ax.set_xlabel(f"{inset_label} length", fontsize=8)
        anc_ax.tick_params(labelsize=7)
    if len(tree_seqs) == 2 or ts.time_units != "uncalibrated":
        date_ax = ax_r.inset_axes([0.40, 0.4, 0.29, 0.57])
        tts = tree_seqs[1 if len(tree_seqs) == 2 else 0]
        use = mut_times > 0
        date_ax.scatter(tts.sites_position[use], mut_times[use], s=2, alpha=0.3, c="tab:blue" if len(tree_seqs) == 2 else "tab:green")
        if log_times:
            date_ax.set_yscale("log")
        date_ax.set_ylabel("Dated TS muttime", fontsize=8)
        date_ax.tick_params(labelsize=7)
    return ax, date_ax, anc_ax, bad_site_positions

fig, axes = plt.subplots(5, sharey=True, sharex=True, figsize=(15,15))
returned_data = []

for ax, ylabel, (test, dated_test) in zip(
    axes,
    ["true topology", "infer: true time order", "infer: freq order", "infer: simplify", "infer: tsdate order"],
    [(sim_ts, None), (true, None), (inferred, None), (inferred_simp_ts, dated_ts), (reinferred, redated_ts)],
):
    params = {"is_gMRCA": True}
    try:
        if sim_ts == test:
            params["inset_label"] = "Node"
            # Mock up the true ancestors
            _ndx = np.argsort(sim_ts.edges_child)
            _id, _pos  = np.unique(sim_ts.edges_child[_ndx], return_index=True)
            use = np.logical_not(np.isin(_id, sim_ts.samples()))
            anc = MockAncestors(
                (np.maximum.reduceat(sim_ts.edges_right[_ndx], _pos) - np.minimum.reduceat(sim_ts.edges_left[_ndx], _pos))[use],
                sim_ts.nodes_time[_id][use]
            )
            test = TsAnc(test, anc)
            params["is_gMRCA"] = False
        elif test == inferred_simp_ts:
            params["is_gMRCA"] = False
    except AttributeError:
        pass
    print(ylabel)
    ylabel = "Num root samples\n" + ylabel
    returned_data.append(
        plot_root_attached_samples_ts(sim_ts, test, dated_test, ax, ylabel, max_y=8, **params)
    )

# Nasty logic here because .sharey only shares max range *from* one ax to another
max_date_y = np.argmax([0 if d[1] is None else d[1].get_ylim()[1] for d in returned_data])
max_anc_y = np.argmax([0 if d[2] is None else d[2].get_ylim()[1] for d in returned_data])
for i, (ax, ax_date, ax_anc, _) in enumerate(returned_data):
    if ax_date and i != max_date_y:
        ax_date.sharex(returned_data[max_date_y][1])
        ax_date.sharey(returned_data[max_date_y][1])
    if ax_anc and i != max_anc_y:
        ax_anc.sharex(returned_data[max_anc_y][2])
        ax_anc.sharey(returned_data[max_anc_y][2])
ax.set_xlabel("Genome position")
nspope commented 7 months ago

I wonder if this is because we are allocating different times to sites with otherwise identical mutation patterns, and therefore changing the ancestor-building steps.

It's interesting this doesn't happen when using the true site times, if I'm reading the second panel correctly

hyanwong commented 7 months ago

Here's the plot we get if we change to min_sample_set_size = 0 in the ancestor builder, which should build ancestors as far as possible, until the site is completely incompatible:

image

It reduces the number of root-attached samples a lot when we first use freq order, but actually increases them when we use the true time order?? And although the sample->grandMRCA edges are removed (3rd row), we still get a lot of sample->root ancestors (i.e. after simplification, 4th row). So I'm not sure that building longer ancestors is going to fix the problem.

Also not much effect once we iterate (bottom panel)

hyanwong commented 7 months ago

I adjusted the plots above to add the ancestor lengths too. It's clear that the sample->grandMRCA nodes when true times are used (2nd row, which is as good as we are going to get by repeated iteration) are not caused by running out of ancestral material, as all of the ancestors are of full length.

I wonder if this effect is caused by the fact that using the true time of each site breaks the "nestedness" assumption of the ancestor builder logic? E.g. if an adjacent site is older than the focal site, but at lower freq, it's still treated as potentially contributing an ancestral state to the haplotype.
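
To gauge how often this "older but lower frequency" situation arises, one could count such neighbours for every focal site. The following is a rough sketch with a hypothetical function name and a hypothetical `window` parameter (number of sites either side to inspect); it takes per-site times and derived-allele frequencies as plain lists and does not reproduce the ancestor builder itself.

```python
def count_old_low_freq_neighbours(times, freqs, window=1):
    """
    For each focal site, count sites within `window` positions either side that
    are older than the focal site but have a lower derived-allele frequency.
    """
    n = len(times)
    counts = []
    for focal in range(n):
        lo = max(0, focal - window)
        hi = min(n, focal + window + 1)
        counts.append(
            sum(
                1
                for j in range(lo, hi)
                if j != focal and times[j] > times[focal] and freqs[j] < freqs[focal]
            )
        )
    return counts


# Made-up site times and frequencies: site 0 is old but rare, so it breaks
# the nestedness expectation for its neighbour, site 1.
times = [50.0, 10.0, 30.0, 5.0]
freqs = [0.2, 0.6, 0.7, 0.1]
violations = count_old_low_freq_neighbours(times, freqs)
print(violations)
```

If sites were ordered by frequency, every count would be zero by construction, so non-zero counts only appear once true (or inferred) times are used.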

jeromekelleher commented 7 months ago

Hmm, that is interesting. We need to really investigate what's happening here when we order sites by time. Maybe a slightly different heuristic is called for.

hyanwong commented 7 months ago

Two minor extra tests to check the source of this behaviour (both use min_sample_set_size = 0 so have full-length ancestors). I didn't expect these tweaks to make much difference, but worth checking, I think.

Do not break long ancestors (makes no difference)

If there are "incompatible" sites between two focal sites with identical patterns, we break those ancestors apart. What if we don't do this breaking? In the case of using true site times, the breaking is probably unnecessary, because each node will have a unique time.

Screenshot 2024-03-06 at 10 16 07

True mutation time used (actually a little worse for # sample->root nodes per tree)

This is what we get if we use the mutation time (not the node time) in the 2nd plot:

Screenshot 2024-03-06 at 12 08 26
hyanwong commented 7 months ago

And for reference, when we use the true ancestral haplotypes in the correct order we don't get any sample-to-root problems (code below the figure). So it looks like it is coming from bad reconstruction of ancestral haplotypes.

Screenshot 2024-03-06 at 13 01 14
import tsinfer
import numpy as np
import stdpopsim
import scipy
import collections
from matplotlib import pyplot as plt

TsAnc = collections.namedtuple("TsAnc", "ts, anc")

def infer(sd, progress=None, **kwargs):
    #anc = tsinfer.generate_ancestors(sd, progress_monitor=progress, **kwargs)
    anc = tsinfer.formats.AncestorData(
            sd.sites_position, sd.sequence_length
        )
    tsinfer.eval_util.build_simulated_ancestors(sd, anc, sim_ts)
    anc.finalise()

    ats = tsinfer.match_ancestors(sd, anc, progress_monitor=progress, num_threads=4)
    final = tsinfer.match_samples(sd, ats, progress_monitor=progress, num_threads=4)
    return TsAnc(final, anc)

species = stdpopsim.get_species("HomSap")
model = species.get_demographic_model("AmericanAdmixture_4B11")
contig = species.get_contig("chr22", mutation_rate=model.mutation_rate, length_multiplier=0.05)
engine = stdpopsim.get_engine("msprime")
s = 100
samples = {"AFR": s, "EUR": s, "ASIA": s, "ADMIX": s}
sim_ts = engine.simulate(model, contig, samples, seed=123, msprime_model="smc_prime")
sim_ts = sim_ts.delete_sites([s.id for s in sim_ts.sites() if len(s.mutations) > 1])
print(sim_ts.num_samples, "samples,", sim_ts.num_trees, "trees,", sim_ts.num_sites, "sites")
true = infer(tsinfer.SampleData.from_tree_sequence(sim_ts, use_sites_time=True), progress=True)
hyanwong commented 7 months ago

So we can partially solve the problem by setting derived states when building the ancestral haplotype only if both the time of the node associated with the mutation is older than the focal node and the frequency of that mutation is greater than the frequency of the mutation at the focal node. We can alter the Python generator to do this by using engine="P" when calling tsinfer.generate_ancestors, and adding the frequency as well as the time into the Site class defined here:

https://github.com/tskit-dev/tsinfer/blob/fa6c364da573ebdfa7be7b14cd1ce6724e8ee436/tsinfer/algorithm.py#L53

(and calling self.sites.append(Site(site_id, time, np.mean(genotypes))) on line 142).

We can then say if self.sites[site_index].time > focal_time and self.sites[site_index].freq > focal_freq here: https://github.com/tskit-dev/tsinfer/blob/fa6c364da573ebdfa7be7b14cd1ce6724e8ee436/tsinfer/algorithm.py#L222

This will only change the ancestor builder at sites to the left or right of the focal site which are old-but-low-frequency. In this case, (a) the generated ancestor will take the ancestral state, even if a majority of focal samples have the derived state (b) the sample set which serves as a list of samples to track will get whittled down to a slightly different subset of samples.
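
The difference between the two rules can be illustrated in isolation. The sketch below is not the tsinfer ancestor builder: the function name and the `use_freq` flag are hypothetical, and it only applies the two comparisons quoted above (`time > focal_time`, optionally combined with `freq > focal_freq`) to plain lists.

```python
def eligible_sites(times, freqs, focal, use_freq):
    """
    Indices of non-focal sites at which the ancestor being built for site `focal`
    could take the derived state, under the time-only rule or the time+frequency rule.
    """
    eligible = []
    for j, (t, f) in enumerate(zip(times, freqs)):
        if j == focal:
            continue
        older = t > times[focal]
        if older and (not use_freq or f > freqs[focal]):
            eligible.append(j)
    return eligible


# Made-up values: site 0 is old but low frequency relative to focal site 1
times = [50.0, 10.0, 30.0, 5.0]
freqs = [0.2, 0.6, 0.7, 0.1]
time_only = eligible_sites(times, freqs, focal=1, use_freq=False)
time_and_freq = eligible_sites(times, freqs, focal=1, use_freq=True)
print(time_only, time_and_freq)
```

Only the old-but-low-frequency site (index 0 here) is treated differently by the two rules, which matches the description of where the ancestor builder's behaviour changes.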

Screenshot 2024-03-06 at 13 25 45

This doesn't actually improve the redating, though (Spearman's r was 0.951 on redating, now 0.950: no real difference). I'll do some testing on how it affects the iterative procedure in https://github.com/tskit-dev/tsinfer/issues/877#issuecomment-1829826914 and also what happens when we combine this with building full-length ancestors.

hyanwong commented 7 months ago

(b) the sample set which serves as a list of samples to track will get whittled down to a slightly different subset of samples.

We can test if the change in sample_set definitions is the major contributor to the reduced sample->root edges by keeping if self.sites[site_index].time > focal_time: on line 222, but using if g_l[u] != consensus and g_l[u] != tskit.MISSING_DATA and self.sites[site_index].freq < focal_freq: on line 243 instead, which will stop removing the site from the sample set for old, low freq sites. In that case we get longer ancestors, but still lots of root-attached sample nodes when using the true times.

Screenshot 2024-03-08 at 10 07 01
hyanwong commented 7 months ago

If we do the same analysis of using both freq and time as a cutoff, but with the full-length ancestors (min_sample_set_size = 0) we get:

Screenshot 2024-03-06 at 17 10 17
hyanwong commented 7 months ago

It appears that we can largely fix the problem on the first iteration by using the mismatch ratio. Here, for example, is the result of setting mismatch_ratio=0.01, recombination_rate=contig.recombination_map.mean_rate, but using the standard ancestor builder:

Screenshot 2024-03-08 at 10 28 48

This implies again that it is bad haplotype reconstruction that is causing the root-jumps.

jeromekelleher commented 7 months ago

Brilliant - if we can characterise these bad haplotypes we can fix them! I'd rather not use mismatch in ancestor matching unless we have to (various reasons)

hyanwong commented 7 months ago

OK, so I think I know what's going on now.

Imagine you are trying to match a sample. You have basically the correct ancestral haplotype against which to match, and are happily trundling e.g. rightwards along the sample, copying from the correct ancestral haplotype. Now assume that the haplotype is wrongly reconstructed in one position, and has a "bad site": in particular, the sample has a 0 but we've reconstructed the ancestor with a 1 at that position, where in truth it should have been a 0. This could be e.g. because this particular site is older (or higher freq) than the focal site, and there's been a recent recombination in a large subset of the samples that has brought that site to >50% freq among the focal samples (rare, but possible).

Since we are not using the mismatch parameter, we are forced to switch to copying from another haplotype. But there aren't any similar haplotypes with a 0 at that position (because other similar ancestral haplotypes from whom you could copy are also likely to have made the same mistake during build_ancestors, as they are closely related). So you temporarily switch to the "best" zero-haplotype. In the absence of any decent haplotypes, you simply pick the oldest with a 0, which is the grand MRCA.

You might hope that after that root jump, the HMM would switch back to the original haplotype from which you were copying, in which case the mismatch parameter should allow us to soak up the ancestor-building-error quite nicely. However, the wrongly-built ancestor will be affected by the bad site, and will be kicking out some of the samples from the set it's using for reconstruction, so the haplotype reconstruction to the right of the bad site will also be a bit wacky, and you might not resume copying correctly from the same haplotype. You might hope that the "buffer" that we use would deal with this reasonably well: it would require 2 sites in a row to show this pattern before we kick out the wrong samples from the sample set.
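
The forced switch described above can be reproduced with a toy copying model. The sketch below is a deliberately simplified stand-in for the matching HMM, using unit costs for switches and mismatches rather than the Li and Stephens probabilities that tsinfer uses; the function name and the cost parameters are assumptions made for illustration. With `mismatch_cost=None`, mismatches are forbidden, so a single bad site forces a detour through whichever haplotype carries a 0 there.

```python
def min_cost_copy_path(query, panel, mismatch_cost=None, switch_cost=1):
    """
    Cheapest way to copy `query` from the haplotypes in `panel` (dynamic programming).
    Each switch between haplotypes costs `switch_cost`; each mismatched site costs
    `mismatch_cost`, or is forbidden when `mismatch_cost` is None.
    Returns the index of the haplotype copied from at each site.
    """
    INF = float("inf")
    n_sites, n_haps = len(query), len(panel)

    def emission(h, j):
        if panel[h][j] == query[j]:
            return 0
        return INF if mismatch_cost is None else mismatch_cost

    cost = [[emission(h, 0) for h in range(n_haps)]]
    back = [[-1] * n_haps]
    for j in range(1, n_sites):
        prev = cost[-1]
        best_prev = min(range(n_haps), key=prev.__getitem__)
        row, ptr = [], []
        for h in range(n_haps):
            stay = prev[h]
            switch = prev[best_prev] + switch_cost
            if stay <= switch:
                row.append(stay + emission(h, j))
                ptr.append(h)
            else:
                row.append(switch + emission(h, j))
                ptr.append(best_prev)
        cost.append(row)
        back.append(ptr)

    last = min(range(n_haps), key=cost[-1].__getitem__)
    if cost[-1][last] == INF:
        raise ValueError("no copying path exists without mismatch")
    # Trace back from the cheapest final state
    path = [last]
    for j in range(n_sites - 1, 0, -1):
        path.append(back[j][path[-1]])
    return path[::-1]


root = [0, 0, 0, 0, 0]      # index 0: all-ancestral haplotype, standing in for the grand MRCA
ancestor = [1, 1, 1, 1, 1]  # index 1: good ancestor, wrongly built with a 1 at site 2
sample = [1, 1, 0, 1, 1]    # the sample truly has a 0 at site 2

exact_path = min_cost_copy_path(sample, [root, ancestor])
mismatch_path = min_cost_copy_path(sample, [root, ancestor], mismatch_cost=1)
print(exact_path, mismatch_path)
```

In this toy case the exact-matching path jumps to the root for the single bad site and then returns, whereas allowing one mismatch keeps the sample copying from the same ancestor throughout. It does not model the knock-on effect on the sites to the right of the bad site.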

petrelharp commented 7 months ago

Hmph - I tried hard to get a simulated sequence to do the Bad Thing, but it ain't doing it: here I took Yan's simulation above, and (a) made sample sizes real unbalanced; (b) set up a crazy mutation map so there are lots of 1-bp spacings between mutations; and (c) turned on a bunch of gene conversion; but still:

Screenshot from 2024-03-11 13-40-52

nspope commented 7 months ago

How about mispolarisation and/or sequencing error @petrelharp ?

petrelharp commented 7 months ago

That's the next thing I'm working on...

petrelharp commented 7 months ago

Well, assuming I'm mispolarising correctly, that's not doing it either:

s = true.anc.sample_data
sd = tsinfer.SampleData()
mispol_frac = 0.1
rng = np.random.default_rng()  # assumed: `rng` was not defined in the snippet as posted
for var in s.variants():
    site = var.site
    if rng.uniform() < mispol_frac:
        # Pick an allele index at random; note this can re-pick the true ancestral allele
        aa = int(rng.choice(np.arange(len(site.alleles))))
    else:
        aa = site.ancestral_allele
    sd.add_site(site.position, genotypes=var.genotypes, alleles=site.alleles, ancestral_allele=aa)

sd.finalise()
inferred = infer(sd, progress=True)