tskit-dev / tsinfer

Infer a tree sequence from genetic variation data.
GNU General Public License v3.0

Robustness of tsinfer & tsdate to selection #877

Open hyanwong opened 9 months ago

hyanwong commented 9 months ago

We claim that tsinfer, not being particularly model-based, should be robust to selection. We can test this by carrying out a selective sweep and inferring + dating the result.

The plots below show that there are errors in dating around the position of the sweep.

Screenshot 2023-11-28 at 09 04 46

However, if we repeat the dating using the true tree sequence, there are, as we might hope, no dating biases.

Screenshot 2023-11-28 at 09 10 51

This implies that we are doing something wrong in tsinfer. The lower right-hand plots suggest that the arity of the node below the mutation is not the issue (@nspope suggested this might explain the oddness).

Code to recreate these plots is:

import msprime
from matplotlib import pyplot as plt
import numpy as np
import scipy.stats
import tsinfer
import tsdate

def make_sweep_ts(n, Ne, L, rho=1e-8, mu=1e-8):
    sweep_model = msprime.SweepGenicSelection(
        position=L/2, start_frequency=0.0001, end_frequency=0.9999, s=0.25, dt=1e-6)
    models = [sweep_model, msprime.StandardCoalescent()]
    ts = msprime.sim_ancestry(
        n, model=models, population_size=Ne, sequence_length=L, recombination_rate=rho, random_seed=1234)
    return msprime.sim_mutations(ts, rate=mu, random_seed=4321)

def common_mutation_node_times(ts1, ts2):
    # Return times of nodes below mutations in ts1 and their corresponding nodes in ts2
    # index of "first" mutation at each site: assume most sites have 1 mutation
    _, muts_to_use = np.unique(ts1.mutations_site, return_index=True)
    sites = ts1.mutations_site[muts_to_use]
    nodes = ts1.mutations_node[muts_to_use]
    pos_to_index = {ts1.sites_position[s]: i for i, s in enumerate(sites)}

    node_below_mut = np.full(len(muts_to_use), -1)
    node_time_below_mut = np.full(len(muts_to_use), np.nan)
    for s in ts2.sites():
        idx = pos_to_index[s.position]
        if len(s.mutations) > 0:
            node_time_below_mut[idx] = ts2.nodes_time[s.mutations[0].node]
            node_below_mut[idx] = s.mutations[0].node
    return ts1.nodes_time[nodes], node_time_below_mut, ts1.sites_position[sites], node_below_mut

def plots(orig_ts, dated_ts, cutoffs, label="Tsinfer + tsdate times"):
    fig, axes = plt.subplots(2, 2, figsize=(10,7), gridspec_kw={'hspace':0.5})
    axes[0, 0].stairs(orig_ts.diversity(windows="trees", mode="branch"), orig_ts.breakpoints(as_array=True), baseline=None)
    for x in cutoffs:
        axes[0, 0].axvline(x, ls=":", c="tab:orange")
    axes[0, 0].set_xlabel("Genome position (bp)")
    axes[0, 0].set_title("Branch-length-based diversity")
    axes[0, 0].set_yscale("log")

    x, y, pos, nodes = common_mutation_node_times(orig_ts, dated_ts)
    nonzero = np.logical_and(x > 0, y > 0)
    axes[0, 1].set_title(label)
    axes[0, 1].scatter(
        x[nonzero], y[nonzero],
        alpha=0.1, s=5,
        c=np.where(np.cumprod(pos - cutoffs.reshape((2,1)), axis=0)[1,:] < 0, "tab:orange", "tab:blue")[nonzero])
    axes[0, 1].set_xscale("log")
    axes[0, 1].set_yscale("log")
    axes[0, 1].set_xlabel("True node time")
    axes[0, 1].set_ylabel("Inferred node time")
    axes[0, 1].text(0.05, 0.92, f"Spearmans = {scipy.stats.spearmanr(np.log(x[nonzero]), np.log(y[nonzero])).statistic:.5f}", transform=axes[0, 1].transAxes)
    axes[0, 1].text(0.05, 0.82, f"$R^2 = {np.corrcoef(np.log(x[nonzero]), np.log(y[nonzero]))[0, 1]**2:.5f}$", transform=axes[0, 1].transAxes)
    axes[0, 1].plot(np.logspace(1, 5), np.logspace(1, 5), "-", c="red")

    axes[1, 0].set_title("Residuals vs pos")
    axes[1, 0].scatter(
        pos[nonzero], (np.log(y[nonzero])-np.log(x[nonzero])),
        alpha=0.1, s=5,
        c=np.where(np.cumprod(pos - cutoffs.reshape((2,1)), axis=0)[1,:] < 0, "tab:orange", "tab:blue")[nonzero])

    arity = np.zeros(dated_ts.num_nodes)
    node_spans = np.zeros(dated_ts.num_nodes)
    for tree in dated_ts.trees():
        po = tree.postorder()
        arity[po] += tree.num_children_array[po] * tree.span
        node_spans[po] += tree.span
    arity /= node_spans

    axes[1, 1].set_title("Residuals vs arity")
    axes[1, 1].scatter(
        arity[nodes][nonzero], np.log(y[nonzero])-np.log(x[nonzero]),
        alpha=0.1, s=5,
        c=np.where(np.cumprod(pos - cutoffs.reshape((2,1)), axis=0)[1,:] < 0, "tab:orange", "tab:blue")[nonzero])
    axes[1, 1].set_xlabel("Average arity")
    axes[1, 1].set_ylabel("Residuals")
    axes[1, 1].set_xscale("log")
    axes[1, 1].set_xlim(1.9, 5);
    return axes

mu = 1e-8
pop_size = 10_000
sim_ts = make_sweep_ts(100, Ne=pop_size, L=5_000_000, mu=mu)

i_ts = tsinfer.infer(tsinfer.SampleData.from_tree_sequence(sim_ts))
s_ts = i_ts.simplify()
d_ts = tsdate.date(
    s_ts,  # use sim_ts here instead to look at a dated original tree sequence
    mutation_rate=mu, population_size=pop_size, method="variational_gamma", max_shape=1000)

print(sim_ts.num_sites, "sites")

plots(sim_ts, d_ts, sim_ts.sequence_length/2 + np.array([-1, 1]) * 200_000)
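The orange/blue colouring in `plots()` relies on `np.cumprod(pos - cutoffs.reshape((2,1)), axis=0)[1,:] < 0`. The second row of that cumulative product is `(pos - cutoffs[0]) * (pos - cutoffs[1])`, which is negative only for positions strictly between the two cutoffs. A more readable equivalent, sketched here with a hypothetical helper `in_sweep_window`, is:

```python
import numpy as np


def in_sweep_window(pos, cutoffs):
    """True where cutoffs[0] < pos < cutoffs[1] (hypothetical helper)."""
    pos = np.asarray(pos, dtype=float)
    return (pos > cutoffs[0]) & (pos < cutoffs[1])


def in_sweep_window_cumprod(pos, cutoffs):
    """The expression used in plots(): (pos - c0) * (pos - c1) < 0."""
    pos = np.asarray(pos, dtype=float)
    cutoffs = np.asarray(cutoffs, dtype=float)
    return np.cumprod(pos - cutoffs.reshape((2, 1)), axis=0)[1, :] < 0


# Same window as in the example: 200 kb either side of the centre of a 5 Mb genome
cutoffs = 5_000_000 / 2 + np.array([-1, 1]) * 200_000
pos = np.array([100_000, 2_300_000, 2_400_000, 2_700_000, 4_900_000])
colours = np.where(in_sweep_window(pos, cutoffs), "tab:orange", "tab:blue")
print(colours)
```

Both versions treat a position lying exactly on a cutoff as outside the window.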
hyanwong commented 9 months ago

If we iterate, by using the dates from tsdate to feed as site times for tsinfer, we can improve the fit. Here's what happens after 5 rounds of iteration:

Screenshot 2023-11-28 at 12 18 30

And here's the increase in r^2 and Spearman's rank correlation over the rounds of iteration:

Screenshot 2023-11-28 at 12 19 16
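Because the logarithm is strictly increasing, it does not change ranks, so the Spearman values in these plots are the same whether or not the times are logged; only the Pearson-based r^2 depends on the log transform (it measures linear agreement on the log scale). A numpy-only sketch, assuming no tied times and using made-up lognormal times purely for illustration:

```python
import numpy as np


def ranks(a):
    """Ranks 0..n-1 of the values in a (assumes no ties, for simplicity)."""
    order = np.argsort(a)
    r = np.empty(len(a), dtype=float)
    r[order] = np.arange(len(a))
    return r


def spearman(x, y):
    """Spearman's rank correlation as the Pearson correlation of the ranks."""
    return np.corrcoef(ranks(x), ranks(y))[0, 1]


def log_r_squared(x, y):
    """Squared Pearson correlation of log times, as reported in plots()."""
    return np.corrcoef(np.log(x), np.log(y))[0, 1] ** 2


# Simulated "true" and noisy "inferred" times (toy data, not from this thread)
rng = np.random.default_rng(1)
true_t = rng.lognormal(mean=6, sigma=1.5, size=500)
inferred_t = true_t * rng.lognormal(sigma=0.5, size=500)

# Spearman is unchanged by the log; the Pearson-based r^2 is not
print(spearman(true_t, inferred_t), spearman(np.log(true_t), np.log(inferred_t)))
print(log_r_squared(true_t, inferred_t), np.corrcoef(true_t, inferred_t)[0, 1] ** 2)
```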

Incidentally, if we use the tsdated node times to order the ancestors for matching purposes (rather than also using the node times to deduce the ancestral haplotype), we do badly for older nodes:

Screenshot 2023-11-28 at 13 05 20 Screenshot 2023-11-28 at 13 06 31

Here's the code for these:

import tqdm

ts = sim_ts
dv_ts_order = []
r_squared_order = []
spearmans_order = []
sd = tsinfer.SampleData.from_tree_sequence(sim_ts)
ad = tsinfer.generate_ancestors(sd)
for rounds in tqdm.trange(6):
    if rounds > 0:
        ancestor_time = {0: ts.max_time + 2, 1: ts.max_time + 1}
        # Uncomment the line below to use sites_time for haplotype reconstruction as well as matching
        # ad = tsinfer.generate_ancestors(tsinfer.SampleData.from_tree_sequence(ts, use_sites_time=True))
        with tsinfer.AncestorData(sd) as ts_anc:
            pos = ad.sites_position[:]
            ts_anc.set_inference_sites(np.where(np.isin(sd.sites_position[:], pos))[0])
            ts_site = {s.position: s.id for s in ts.sites()}
            for i, sites in enumerate(ad.ancestors_focal_sites[:]):
                if i in ancestor_time:
                    continue
                if len(sites) == 0:
                    print(ad.ancestor(i))
                    raise ValueError
                time = []
                for s in sites:
                    p = pos[s]
                    muts = ts.site(ts_site[p]).mutations
                    assert len(muts) > 0
                    time.append(max(ts.nodes_time[m.node] for m in muts))
                ancestor_time[i] = np.max(time)
            for i in sorted(ancestor_time.keys(), key=lambda k: -ancestor_time[k]):
                a = ad.ancestor(i)
                ts_anc.add_ancestor(a.start, a.end, ancestor_time[i], a.focal_sites, a.haplotype)
    else:
        ts_anc = ad
    a_ts = tsinfer.match_ancestors(sd, ts_anc)
    i_ts = tsinfer.match_samples(sd, a_ts)
    simplified_ts = i_ts.simplify()
    dv_ts_order.append(
        tsdate.date(simplified_ts, mutation_rate=1e-8, population_size=pop_size, method="variational_gamma", max_shape=1000))
    ts = dv_ts_order[-1]
    x, y, pos, nodes = common_mutation_node_times(sim_ts, ts)
    nonzero = np.logical_and(x > 0, y > 0)
    spearmans_order.append(scipy.stats.spearmanr(np.log(x[nonzero]), np.log(y[nonzero])))
    r_squared_order.append(np.corrcoef(np.log(x[nonzero]), np.log(y[nonzero])))

plots(sim_ts, dv_ts_order[-1], sim_ts.sequence_length/2 + np.array([-1, 1]) * 200_000)
plt.show()

plt.scatter(np.arange(len(spearmans_order)), [r[0,1] **2 for r in r_squared_order], label="$r^2$")
plt.scatter(np.arange(len(spearmans_order)), [s.statistic for s in spearmans_order], label="Spearmans")
plt.xlabel("round of iteration")
plt.legend()
plt.show()
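Stripped of the tsinfer calls, the re-timing step in the loop above gives each ancestor the oldest tsdate time found below a mutation at any of its focal sites, keeps the two virtual-root ancestors (ids 0 and 1) older than everything else, and then adds ancestors oldest first. A pure-Python sketch of just that logic, on made-up numbers (`ancestor_times_from_dates` is a hypothetical name, not a tsinfer function):

```python
def ancestor_times_from_dates(focal_sites, site_times, fixed):
    """
    Hypothetical helper mirroring the re-timing step in the loop above.

    focal_sites: list, for each ancestor, of the focal site indexes it carries
    site_times:  per-site time, e.g. the oldest tsdate node time below a mutation there
    fixed:       preset times for special ancestors (e.g. the two virtual roots, ids 0 and 1)
    Returns a dict of ancestor id -> time, and the ancestor ids ordered oldest first.
    """
    times = dict(fixed)
    for i, sites in enumerate(focal_sites):
        if i in times:
            continue
        if len(sites) == 0:
            raise ValueError(f"Ancestor {i} has no focal sites")
        # Use the oldest time among this ancestor's focal sites
        times[i] = max(site_times[s] for s in sites)
    order = sorted(times, key=lambda k: -times[k])
    return times, order


max_time = 300.0
focal_sites = [[], [], [0, 2], [1], [3]]
site_times = [10.0, 250.0, 40.0, 5.0]
times, order = ancestor_times_from_dates(
    focal_sites, site_times, fixed={0: max_time + 2, 1: max_time + 1})
print(order)  # [0, 1, 3, 2, 4]
```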
hyanwong commented 9 months ago

If, as @nspope suggests, we cut the inferred tree sequence into three parts (to the left, centre, and right of the orange zone), we get something like the iterated version within the sweep zone. This implies that indeed, we are porting too much information from adjacent trees (i.e. our ancestral nodes within the swept region are too long, extending into regions of the genome where they shouldn't exist, and we are matching into those regions):

Screenshot 2023-11-28 at 13 59 52

So while Relate probably doesn't tie nodes together enough, tsinfer might tie them together too much? This might be exacerbated because, if there are several equally likely matches during match_ancestors, tsinfer always picks the oldest. This might cause edges to find the same parent more often than you might expect, given a selection of potential ancestors that look identical within the matching region.

import tskit

# alternatively, split up the tree sequence so that internal nodes are broken either side of the
# cutoff values, so that adjacent regions are not dragging the nodes down
cutoffs = sim_ts.sequence_length/2 + np.array([-1, 1]) * 200_000
i_ts = tsinfer.infer(tsinfer.SampleData.from_tree_sequence(sim_ts))
i1_ts = i_ts.keep_intervals([[0, cutoffs[0]]])
i2_ts = i_ts.keep_intervals([[cutoffs[0], cutoffs[1]]])
i3_ts = i_ts.keep_intervals([[cutoffs[1], sim_ts.sequence_length]])
tables = i1_ts.dump_tables()
# Append the centre and right-hand chunks, giving each chunk its own copies of the
# non-sample nodes (sample node IDs are shared by all three chunks)
for chunk_ts in (i2_ts, i3_ts):
    node_map = np.arange(chunk_ts.num_nodes)
    site_map = np.arange(chunk_ts.num_sites)
    for n in chunk_ts.nodes():
        if not n.is_sample():
            node_map[n.id] = tables.nodes.append(n)
    for e in chunk_ts.edges():
        tables.edges.append(e.replace(parent=node_map[e.parent], child=node_map[e.child]))
    for s in chunk_ts.sites():
        site_map[s.id] = tables.sites.append(s)
    for m in chunk_ts.mutations():
        # Old mutation parent IDs are invalid in the combined tables; they are recomputed below
        tables.mutations.append(
            m.replace(site=site_map[m.site], node=node_map[m.node], parent=tskit.NULL))
tables.sort()
tables.build_index()
tables.compute_mutation_parents()
i_all_ts = tables.tree_sequence()
d_split_ts = tsdate.date(
    i_all_ts, mutation_rate=mu, population_size=pop_size, method="variational_gamma", max_shape=1000)
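The concern that equally good matches are always resolved in favour of the oldest candidate can be illustrated with a toy rule. In this hypothetical sketch (it ignores the copying model that actually produces the likelihoods, and is not tsinfer's traceback code), every child that sees the same set of tied candidates ends up with the same, oldest parent, which is the kind of over-connection suspected in the swept region:

```python
import math


def choose_parent(candidates):
    """
    Toy version of the tie-breaking rule described above: among the
    candidates with the best log-likelihood, pick the oldest.

    candidates: list of (ancestor_id, time, log_likelihood) tuples
    """
    best = max(ll for _, _, ll in candidates)
    tied = [c for c in candidates if math.isclose(c[2], best)]
    return max(tied, key=lambda c: c[1])[0]


# Three ancestors that look identical over the matching region, plus a worse match
candidates = [("A", 150.0, -2.0), ("B", 900.0, -2.0), ("C", 4000.0, -2.0), ("D", 50.0, -7.5)]
children = ["x", "y", "z"]
parents = {child: choose_parent(candidates) for child in children}
print(parents)  # {'x': 'C', 'y': 'C', 'z': 'C'}
```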
hyanwong commented 9 months ago

I suspect that what is going on when we iteratively improve the situation is that we are forcing ancestors within the swept region to adopt slightly more ancestral states in the regions flanking the focal sites and also, crucially, to be shorter (the left-and-right building process will terminate earlier, because the allele distributions at adjacent sites are more often incompatible with the original pattern of variation at the focal site).

That explains why you need to take the dates into account when actually building ancestors, not just when matching.
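One way to see why re-dating shortens ancestors: in the ancestor-building heuristic described in the tsinfer paper, an ancestor is extended outwards from its focal site using only sites considered older than the focal site, and extension stops once too many of the samples carrying the focal allele disagree with the consensus. The toy below is a simplified, rightward-only version of that idea (`extend_right` is a hypothetical name, and details such as the stopping rule are assumptions, not tsinfer's implementation). When the focal mutation is re-dated as younger than its frequency suggests, more flanking sites count as older, there are more chances to hit incompatible sites, and the ancestor ends sooner:

```python
def extend_right(genotypes, site_times, focal):
    """
    Toy, rightward-only sketch of frequency-ordered ancestor building.

    genotypes:  list of per-site lists of 0/1 alleles, one entry per sample
    site_times: per-site time proxy (e.g. derived allele frequency, or a tsdate time)
    focal:      index of the focal site
    Returns the ancestral haplotype from the focal site rightwards, and the
    (exclusive) index of the site at which the ancestor ends.
    """
    carriers = {i for i, g in enumerate(genotypes[focal]) if g == 1}
    min_carriers = len(carriers) / 2
    haplotype = [1]
    for j in range(focal + 1, len(genotypes)):
        if site_times[j] <= site_times[focal]:
            haplotype.append(0)  # sites not older than the focal site get the ancestral state
            continue
        # Consensus allele among the samples still consistent with this ancestor
        derived = sum(genotypes[j][i] for i in carriers)
        allele = 1 if derived >= len(carriers) / 2 else 0
        haplotype.append(allele)
        carriers = {i for i in carriers if genotypes[j][i] == allele}
        if len(carriers) < min_carriers:
            return haplotype, j + 1  # too many carriers disagree: stop extending
    return haplotype, len(genotypes)


genotypes = [
    [1, 1, 1, 1, 0, 0],  # focal site (frequency 4)
    [1, 1, 0, 0, 1, 1],
    [1, 0, 0, 0, 1, 1],
    [1, 1, 1, 1, 1, 0],
    [1, 1, 0, 1, 1, 1],
]
freq_times = [sum(g) for g in genotypes]   # [4, 4, 3, 5, 5]
redated = [2] + freq_times[1:]             # focal mutation re-dated as younger
print(extend_right(genotypes, freq_times, 0))  # ([1, 0, 0, 1, 1], 5)
print(extend_right(genotypes, redated, 0))     # ([1, 1, 1], 3)
```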

Note that Relate might fall into the same trap, because it might use topology to accidentally group edges or nodes between the swept and non-swept regions, when actually they should be separate.

jeromekelleher commented 9 months ago

Can you show a tsqc plot of the ancestor haplotype lengths and edges in the inferred ts here @hyanwong? Should help intuition.

Does this use truncate ancestors?

hyanwong commented 9 months ago

It doesn't use "truncate_ancestors". Here are the edge spans from the "edges tab" in the original tree sequence (time of edge child on the Y axis is log-time via a hack). It is pretty obvious where the selective sweep part of the simulation is (at log time 0 to 5), and above, where the standard neutral coalescent takes over.

Screenshot 2023-11-28 at 15 29 46

and a vertical zoom-in on the dense edge region just at the oldest part of the sweep simulation looks like this (you can see a higher density of edges in the swept region, which is less obvious in the zoomed out version)

Screenshot 2023-11-28 at 15 59 19

Here are the edges from the initial inferred (undated) tree sequence (partially simplified, i.e. with unary nodes), where the "time" is just the frequency of the focal site associated with each ancestor

Screenshot 2023-11-28 at 15 32 11

Here they are from the inferred + dated tree sequence (log time), which is fully simplified. Note that we get the pattern more-or-less right, but there are some long edges which span the sweep region in what should be the neutral part of the timespan

Screenshot 2023-11-28 at 15 33 56

And finally, here are the edges from the iteratively reinforced/redated tree sequence. We have managed to shorten / remove a couple of those long spanning edges with log-time 6 and greater within the sweep region.

Screenshot 2023-11-28 at 15 49 17
jeromekelleher commented 9 months ago

Nice. Do you think it's the long old edges that are causing the problem, or more recent ones that are closely associated with the sweep?

hyanwong commented 9 months ago


Well, the badly dated nodes are all at a time of ~1.5e2 (log time = 5) and are being wrongly dated as time 1e4 (log time ~ 9), so I think the problem is that some of the medium-length edges in the sweep region are being pushed way back in time. You can see these in the edges plot, in the middle of what should be a blank triangle, around (2.5e6, 10), (2.5e6, 8.8), and (2.5e6, 7.5). As @nspope surmised, this is likely to be because the ancestors associated with these edges are attached via other edges to correctly dated ancient nodes.
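The log times quoted here are natural logarithms (ln 150 is about 5.0 and ln 10^4 about 9.2). A quick conversion, which also turns the coordinates of the suspicious edges back into generations:

```python
import math

# Natural-log times mentioned above, converted in both directions
for t in (1.5e2, 1e4):
    print(f"{t:g} generations -> log time {math.log(t):.2f}")
for log_t in (7.5, 8.8, 10):
    print(f"log time {log_t} -> {math.exp(log_t):.0f} generations")

# 150 generations -> log time 5.01
# 10000 generations -> log time 9.21
# log time 7.5 -> 1808 generations
# log time 8.8 -> 6634 generations
# log time 10 -> 22026 generations
```

So the misplaced edges sit at roughly 1,800 to 22,000 generations, well above the ~150-generation nodes they should be attached to.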

My suggestion is that when we iteratively refine the ancestor dates, we change the recent ancestral haplotypes in the swept region so that although they might contain high frequency derived variants, they are placed more recently than their frequency suggests. This means that sites which are adjacent to the focal sites will not contain older derived alleles, and so the recent haplotype won't be treated as e.g. a parent of a relatively old haplotype, or an immediate child of an old haplotype.

hyanwong commented 9 months ago

It's true that we are getting a smear of ancestors in the selected region between log times of 5 and 6, which we shouldn't do. I don't think this is the cause of the outliers, though: it just shows that we don't have much precision in the sweep region. Here's what we get if we use the correct (simulated) topology, FWIW. It looks really good (obviously!) and captures the discrete boundary between the sweep and the neutral part of history, albeit not quite at the right time:

Screenshot 2023-11-28 at 17 30 07

The "cutting into 3 chunks" idea clearly produces a wacky tree sequence, where we mess up the recent times because we don't have sensible spanning over recent ancestors:

Screenshot 2023-11-28 at 17 45 06
hyanwong commented 9 months ago

Aaand - I just realised that I should have looked at the inference we get when the correct site times are used:

i_ts = tsinfer.infer(tsinfer.SampleData.from_tree_sequence(sim_ts, use_sites_time=True))
s_ts = i_ts.simplify()
d_orig_times_ts = tsdate.date(
    s_ts, mutation_rate=mu, population_size=pop_size, method="variational_gamma", max_shape=1000)
plots(sim_ts, d_orig_times_ts, sim_ts.sequence_length/2 + np.array([-1, 1]) * 200_000)

This is basically as good as the infer->date->infer iteration loop can ever go. And it's amazing! If we get the site times in the correct order, redating the nodes using tsdate is almost as good as using the correct topology:

Screenshot 2023-11-28 at 18 09 34

The distribution of edge-child times is nowhere near as neat as when we use the true topology, but it's still pretty clean.

Screenshot 2023-11-28 at 18 11 21
hyanwong commented 9 months ago

Here's the most recent Relate version on the same data, FWIW:

Screenshot 2023-11-29 at 13 56 19 Screenshot 2023-11-29 at 13 49 28
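import subprocess
import tempfile

import tskit

# `path_to_relate`, `path_to_relatelib` and `ts_to_haps_sample` are assumed to be defined elsewhere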

def run_relate(ts, population_size, mu, rho, random_seed=111):
    dir = "examples/"
    prefix = "test"
    with open(dir + prefix + ".haps", "wt") as haps, open(dir + prefix + ".sample", "wt") as sample:
        # ts_to_haps_sample routine from https://github.com/tskit-dev/tsconvert/issues/55#issuecomment-1831959994
        ts_to_haps_sample(ts, haps, sample)
    with tempfile.NamedTemporaryFile("wt") as temp:
        cM_per_MB = rho * 1e8
        print("pos", "COMBINED_rate", "Genetic_Map", sep=" ", file=temp)
        print(0, f"{cM_per_MB:.5f}", 0, sep=" ", file=temp)
        print(
            int(ts.sequence_length),
            f"{cM_per_MB:.5f}",
            ts.sequence_length / 1e6 * cM_per_MB,
            sep=" ",
            file=temp)
        temp.flush()

        params = [
            path_to_relate + "bin/Relate",
            "--haps", prefix+".haps",
            "--sample", prefix+".sample",
            "--map", temp.name,
            "-o", "out",
            "--mode", "All",
            "-m", f"{mu}",
            "-N", f"{population_size}",
            "--seed", f"{random_seed}",
        ]
        print(f"running `{' '.join(params)}`")
        subprocess.run(params, cwd=dir)

    # Convert to tree sequence format
    params = [
        path_to_relatelib + "/bin/Convert",
        "--mode", "ConvertToTreeSequence",
        "--anc", "out.anc",
        "--mut", "out.mut",
        "-o", "out",
        "--compress",
    ]
    print(f"running `{' '.join(params)}`")
    subprocess.run(params, cwd=dir)
    return tskit.load(dir + "out.trees")

relate_ts = run_relate(sim_ts, pop_size/2, mu, rho=1e-8)
plots(sim_ts,relate_ts, sim_ts.sequence_length/2 + np.array([-1, 1]) * 200_000, label="Relate inference")
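The map that `run_relate` writes encodes the per-bp, per-generation rate `rho` as a constant rate in cM/Mb: `rho` Morgans per bp is `100 * rho` cM per bp, i.e. `1e8 * rho` cM per Mb (so 1 cM/Mb for `rho=1e-8`). The ARG-Needle map in the next comment instead lists the cumulative position `pos * rho * 100` cM at every site, and the two should agree at the end of the genome. A small check of both conventions (the helper names are hypothetical):

```python
import math


def relate_map_lines(sequence_length, rho):
    """
    Lines of the uniform-rate map written by run_relate():
    header, then position (bp), rate (cM/Mb), cumulative genetic position (cM).
    rho is the per-bp, per-generation recombination rate.
    """
    cM_per_Mb = rho * 1e8  # rho Morgans/bp = 100 * rho cM/bp = 1e8 * rho cM/Mb
    return [
        "pos COMBINED_rate Genetic_Map",
        f"0 {cM_per_Mb:.5f} 0",
        f"{int(sequence_length)} {cM_per_Mb:.5f} {sequence_length / 1e6 * cM_per_Mb}",
    ]


def argneedle_cM(pos, rho):
    """Cumulative genetic position in cM, as written to the ARG-Needle map."""
    return pos * rho * 100


print("\n".join(relate_map_lines(5_000_000, 1e-8)))
print(argneedle_cM(5_000_000, 1e-8))  # the same 5 cM end point, up to floating-point rounding
```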
hyanwong commented 9 months ago

Here are the "dates" from ARG-Needle (mutation times are slightly hacked, as I needed to re-lay the mutations using parsimony). Note that ARG-Needle will only run with at least 300 samples, so I had to change the number of samples in the example to 300 rather than 200. It takes about 30 minutes to run on this dataset (~10,000 sites) and produces a tree sequence with 64 times more edges (1,089,859 versus 16,995 for the tsinferred output).

The r_squared for the tsinferred + tsdate approach in this simulation is 0.807 (compared to the 0.762 below)

Screenshot 2023-11-30 at 12 30 12

Since I had to hack the mutation times, the edges plot might be a more reasonable thing to inspect here. This isn't actually too bad:

Screenshot 2023-11-30 at 10 48 05

Here's the code to run ARG-Needle with sequence data:

import subprocess
import tempfile

import arg_needle_lib
import tskit

def run_argneedle(ts, population_size, mu, rho):
    """
    Run ARG-Needle. Note that the population size is used to create a bespoke "demography"
    file for ARG normalisation. `ts_to_haps_sample` is the helper linked in the Relate
    example, assumed here to return the IDs of the sites it writes.
    """
    dir = "examples/"
    prefix = "test"
    with open(dir + prefix + ".haps", "wt") as haps, open(dir + prefix + ".sample", "wt") as sample:
        sites = ts_to_haps_sample(ts, haps, sample)
    with tempfile.NamedTemporaryFile("wt") as map_file, tempfile.NamedTemporaryFile("wt") as demo:
        # Make the required mapfile (one line per variant)
        # https://palamaralab.github.io/software/argneedle/manual/#genetic-map-mapmapgz
        # chromosome SNP_name genetic_position_cM physical_position_bp
        for s in sites:
            pos = ts.site(s).position
            print("1", f"Site{s}", f"{pos * rho * 100}", f"{pos}", sep="\t", file=map_file)
        map_file.flush()

        # e.g. from NE10K.demo in
        # https://github.com/PalamaraLab/ASMC_data/blob/main/demographies/NE10K.demo
        print("\t".join(["0.0", str(population_size)]), file=demo)
        print("\t".join(["5000.0", str(population_size)]), file=demo)
        demo.flush()

        params = [
            "arg_needle",
            "--hap_gz", prefix+".haps",
            "--map", map_file.name,
            "--mode", "sequence",
            "--normalize_demography", demo.name,
            "--out", "arg_needle"
        ]
        print(f"running `{' '.join(params)}`")
        subprocess.run(params, cwd=dir)
        argneedle = arg_needle_lib.deserialize_arg(dir + "arg_needle.argn")
        argn_ts = arg_needle_lib.arg_to_tskit(argneedle)

        # The ARG-Needle inference doesn't seem to include mutations
        # in the ARG (see https://github.com/PalamaraLab/arg-needle-lib/issues/2)
        # so for the time being we place them using parsimony
        tables = argn_ts.dump_tables()
        tables.sequence_length = ts.sequence_length
        assert argn_ts.num_sites == 0
        tables.mutations.clear()  # Clear all mutations in the table collection copy
        variant = tskit.Variant(ts)  # Reuse the same Variant object
        i = 0
        variant.decode(i)  # Efficient if ids are sequential
        for tree in argn_ts.trees():
            # Map every remaining site that falls within this tree's interval
            while i < ts.num_sites and variant.site.position < tree.interval.right:
                _, mutations = tree.map_mutations(
                    variant.genotypes, variant.alleles, ancestral_state=variant.site.ancestral_state)
                s = tables.sites.append(variant.site)
                for mut in mutations:
                    tables.mutations.append(mut.replace(site=s, parent=tskit.NULL))
                i += 1
                if i < ts.num_sites:
                    variant.decode(i)
        tables.edges.squash()
        # Sort and index before computing mutation parents and times
        tables.sort()
        tables.build_index()
        tables.compute_mutation_parents()
        tables.compute_mutation_times()
        return tables.tree_sequence()
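The parsimony step works in a single pass because both the sites and the trees are ordered by position: each site is mapped onto the tree whose interval contains it, and the site pointer only ever moves forwards. The pairing logic on its own, sketched with plain lists in place of tskit objects (`assign_sites_to_trees` is a hypothetical name):

```python
def assign_sites_to_trees(site_positions, breakpoints):
    """
    Two-pointer sweep pairing sorted site positions with trees.

    breakpoints: sorted tree boundaries [0, b1, ..., L]; tree k covers [b_k, b_{k+1})
    Returns, for each site, the index of the tree that contains it.
    Assumes positions are sorted and every position satisfies 0 <= position < L.
    """
    tree_index = []
    k = 0
    for pos in site_positions:
        # Advance to the tree whose half-open interval contains this site
        while pos >= breakpoints[k + 1]:
            k += 1
        tree_index.append(k)
    return tree_index


print(assign_sites_to_trees([0.5, 3.0, 3.5, 9.9], [0, 3, 7, 10]))  # [0, 1, 1, 2]
```

This takes time proportional to the number of sites plus the number of trees, and trees containing no sites are simply skipped over.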
jeromekelleher commented 9 months ago

Edge plot is definitely showing it's picking something up re the sweep.

When you look closely there are also some odd curves in there, where the ends of edges seem to form patterns:

Screenshot from 2023-11-30 13-20-33

Nothing to do with the sweep, just an interesting qualitative property of the ARG-Needle edges that shows up in this view.

cc @savitakartik - cool to see tsqc working here on Relate and ARG-Needle!

savitakartik commented 9 months ago

Very interesting to see the patterns in ARGneedle edges, and the edges plot picking up the sweep!

Super interesting work, Yan. I really enjoyed reading this thread.

hyanwong commented 9 months ago

To feed into subsequent rounds of tsinfer topology inference, we could use either the tsdated tree sequence node times (which are constrained by the topology), or the mean times stored in the tsdate node metadata. It appears from the plot below that (if anything) the unconstrained times are a little worse, but this may vary depending on the simulation, I suppose.

Screenshot 2023-12-09 at 21 48 52
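The two options differ in that the node times stored in the dated tree sequence must respect the topology (every parent older than its children), whereas per-node means need not. To make the distinction concrete, one generic way to turn unconstrained estimates into constrained ones is to visit nodes children-first and push any parent that is too young just above its oldest child. This is an assumed, illustrative procedure, not necessarily what tsdate does:

```python
from collections import defaultdict, deque


def constrain_times(times, edges, eps=1e-6):
    """
    Push parents above their children so that every parent is at least eps
    older than each of its children (a generic sketch, not tsdate's method).

    times: dict of node -> unconstrained time estimate (e.g. a posterior mean)
    edges: list of (parent, child) pairs; genomic intervals are ignored here
    """
    n_children = defaultdict(int)   # children of each node not yet processed
    parents_of = defaultdict(list)
    for parent, child in edges:
        n_children[parent] += 1
        parents_of[child].append(parent)
    constrained = dict(times)
    # Kahn's algorithm: visit a node only once all of its children are final
    queue = deque(node for node in times if n_children[node] == 0)
    while queue:
        child = queue.popleft()
        for parent in parents_of[child]:
            constrained[parent] = max(constrained[parent], constrained[child] + eps)
            n_children[parent] -= 1
            if n_children[parent] == 0:
                queue.append(parent)
    return constrained


# Node 3's unconstrained mean (30) is younger than its child, node 2 (50)
means = {0: 0.0, 1: 0.0, 2: 50.0, 3: 30.0}
edges = [(2, 0), (2, 1), (3, 2)]
print(constrain_times(means, edges))  # node 3 is pushed to just above 50
```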
hyanwong commented 9 months ago

Also a shocking difference between Nate's "variational gamma" approach and the default "inside-outside" algorithm, only really seen after iterations. I have no idea why the inside-outside does so badly here.

Screenshot 2023-12-09 at 22 33 30
jeromekelleher commented 9 months ago

Interesting! Possible explanation for why previous attempts at iteration didn't seem to work?

hyanwong commented 7 months ago

A major improvement is seen if we use @nspope's split_disconnected_segments routine. We think this is probably because you definitely don't want to tie together inferred ancestors that continue through the swept region. Here's the equivalent plot:

Screenshot 2024-02-03 at 19 13 02

And here's what it looks like after 6 rounds of iteration:

Screenshot 2024-02-03 at 19 13 25

Edit: here's the edge plot too, e.g. via

from matplotlib import pyplot as plt
import numpy as np
from matplotlib import collections as mc

tm = ts.nodes_time[ts.edges_child]
lines = np.array([[ts.edges_left, ts.edges_right], [tm, tm]]).T

lc = mc.LineCollection(lines, linewidths=1)
fig, ax = plt.subplots(figsize=(10, 5))
ax.add_collection(lc)
ax.autoscale()
ax.margins(0)
ax.set_yscale("log")
ax.set_ylabel("Time of edge child (generations)")
ax.set_xlabel("Genome position")
Screenshot 2024-02-03 at 19 36 38
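The code for `split_disconnected_segments` isn't included in this thread. Judging only from its name and the explanation above, the idea is that an inferred ancestor whose edges cover genomic stretches that never join up (for example, either side of the swept region) should be treated as several separate nodes, one per connected stretch. Finding those stretches is an interval-merging problem; a hypothetical sketch (not the actual routine):

```python
def connected_segments(intervals):
    """
    intervals: list of (left, right) genomic intervals covered by a node's edges.
    Returns the maximal connected segments, joining intervals that touch or overlap.
    A node with more than one segment would be a candidate for splitting.
    """
    merged = []
    for left, right in sorted(intervals):
        if merged and left <= merged[-1][1]:
            merged[-1][1] = max(merged[-1][1], right)  # extend the current segment
        else:
            merged.append([left, right])  # a gap: start a new segment
    return [tuple(seg) for seg in merged]


node_intervals = [(2_700_000, 3_000_000), (0, 100_000), (100_000, 250_000)]
print(connected_segments(node_intervals))  # [(0, 250000), (2700000, 3000000)]
```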
nspope commented 7 months ago

This looks great -- there's no visible artefacts in the residuals from the sweep

hyanwong commented 7 months ago


Not quite as good as using the true site times (see https://github.com/tskit-dev/tsinfer/issues/877#issuecomment-1830420672), and there is a slight band of orange residuals just below 2, but yes, generally excellent.

hyanwong commented 7 months ago

Interestingly, if we also split the root nodes (see #850), we initially do better, but then start doing worse on repeated tsinfer/tsdate iteration:

Screenshot 2024-02-06 at 12 07 02
hyanwong commented 7 months ago

The iterative process also seems to improve the mismatch between site and branch-length measures of diversity (see https://github.com/tskit-dev/tsdate/issues/366).

Iteration 0: Site diversity / mu: 20417.261306533586 .   Branch diversity: 15714.866570488393
Iteration 1: Site diversity / mu: 20417.261306533586 .   Branch diversity: 15746.165977907318
Iteration 2: Site diversity / mu: 20417.261306533586 .   Branch diversity: 16230.011702909274
Iteration 3: Site diversity / mu: 20417.261306533586 .   Branch diversity: 16278.369732967092
Iteration 4: Site diversity / mu: 20417.261306533586 .   Branch diversity: 16732.636886473156
Iteration 5: Site diversity / mu: 20417.261306533586 .   Branch diversity: 16408.55652048713

And a closer match is seen if the split_disjoint_nodes function is applied during each round:

Iteration 1: Site diversity / mu: 20417.261306533586 .   Branch diversity: 16311.370935689405
Iteration 2: Site diversity / mu: 20417.261306533586 .   Branch diversity: 16509.801923503386
Iteration 3: Site diversity / mu: 20417.261306533586 .   Branch diversity: 16800.79751714908
Iteration 4: Site diversity / mu: 20417.261306533586 .   Branch diversity: 16806.62455520075
Iteration 5: Site diversity / mu: 20417.261306533586 .   Branch diversity: 16630.727115545968
Iteration 6: Site diversity / mu: 20417.261306533586 .   Branch diversity: 16953.59453426605
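For reference, the site diversity in these logs is the mean number of pairwise differences per unit of sequence length computed from the genotypes (what tskit's `ts.diversity()` returns in its default `mode="site"`). Divided by the mutation rate, it should in expectation match the branch-mode diversity, which depends only on the inferred branch lengths. On these numbers, branch diversity is about 77% of site diversity / mu at iteration 0 and about 83% by iteration 6 of the split_disjoint_nodes runs. A numpy sketch of the site-based quantity for biallelic sites (an illustration, not tskit's implementation):

```python
import numpy as np


def site_diversity(genotypes, sequence_length):
    """
    Mean number of pairwise differences per unit length, from a
    (num_sites x num_samples) matrix of biallelic 0/1 genotypes.
    """
    G = np.asarray(genotypes)
    n = G.shape[1]
    k = G.sum(axis=1)                 # derived-allele count at each site
    differing_pairs = k * (n - k)     # sample pairs that differ at each site
    total_pairs = n * (n - 1) / 2
    return differing_pairs.sum() / total_pairs / sequence_length


G = [[1, 0, 0, 0],
     [1, 1, 0, 0]]
pi = site_diversity(G, sequence_length=10)
print(pi)
```

Dividing such a value by the mutation rate gives the "Site diversity / mu" figures above.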
hyanwong commented 7 months ago

Another observation: If we set a uniform imperfect prior, using prior_mixture_dim=0, we actually do better in this example (this is with split_disjoint_nodes but without split_root_nodes):

Screenshot 2024-02-13 at 21 21 17

We do better in the first iteration if we use the unconstrained times, but then we tail off more quickly:

Screenshot 2024-02-13 at 21 48 56

In this case we actually overestimate the branch lengths:

Iteration 0: Site diversity / mu: 20417.261306533586 .   Branch diversity: 26534.573807403387
Iteration 1: Site diversity / mu: 20399.170854272277 .   Branch diversity: 27097.63552142596
Iteration 2: Site diversity / mu: 20411.039195980822 .   Branch diversity: 27302.670510059197
Iteration 3: Site diversity / mu: 20411.039195980822 .   Branch diversity: 27333.172768378907
Iteration 4: Site diversity / mu: 20411.039195980822 .   Branch diversity: 27561.999931131995
Iteration 5: Site diversity / mu: 20411.039195980822 .   Branch diversity: 24297.54486174968

(I'm not sure here why the site diversity is changing)

hyanwong commented 4 months ago

The fix for normalisation to account for polytomies (i.e. the new default for tsdate, not using mutational area) noticeably improves matters (especially in the edges plot). We do better here using the unconstrained times (here I have used split_disjoint_nodes too).

Screenshot 2024-04-29 at 09 35 16

The main mutation dates in the swept area are much less biased, although there is a slight orange outlier in the residuals plot: I suspect this is our standard problem with unconstrained root time.

Screenshot 2024-04-29 at 09 35 30 Screenshot 2024-04-29 at 09 37 37