tskit-dev / tsdate

Infer the age of ancestral nodes in a tree sequence.

Use of, and alternatives to, a "global" (i.e. node-agnostic) prior #292

Open hyanwong opened 12 months ago

hyanwong commented 12 months ago

@nspope has found problems with the conditional coalescent mixture priors used by tsdate. Surprisingly, better results are gained by averaging the priors over all the nodes, and using this single distribution as a prior for all nodes (e.g. see https://github.com/tskit-dev/tsdate/pull/257#issuecomment-1532446409).

I have opened up this issue as a place to discuss what the best prior strategy is. We should be able to do better than the naive global approach, but I'm not sure how!

hyanwong commented 12 months ago

Here's an interesting way to plot the expected node times in a 4-tip tree sequence, depending on the mixture of sample descendants. We can use a barycentric (i.e. triangular) plot to represent the ratios of the spans over which a node has 2 : 3 : 4 descendants, and then show separate plots for different total node span lengths. Something like this, for 100 replicates of 100 Mb simulated tree sequences:

[image: barycentric plots of mean node time, one panel per node-span quantile]

code:

import itertools

import msprime
import numpy as np
import scipy.stats
import tqdm
from matplotlib import pyplot as plt

reps = 100

# Save into a massive matrix with 4 columns:
# 2span 3span 4span time
nodes = None
for ts in tqdm.tqdm(
    msprime.sim_ancestry(4, ploidy=1, population_size=1e4, sequence_length=1e8, recombination_rate=1e-8, num_replicates=reps)
):
    node_arr = np.zeros((ts.num_nodes - ts.num_samples, ts.num_samples))
    node_arr[:, 3] = ts.nodes_time[ts.num_samples:]
    for tree in ts.trees():
        for u in tree.nodes():
            if tree.is_internal(u):
                node_arr[u - ts.num_samples, tree.num_samples(u) - 2] += tree.interval.span
    if nodes is None:
        nodes = node_arr
    else:
        nodes = np.vstack((nodes, node_arr))

print("Total data (nodes, spans + time)", nodes.shape)
print("Mean node time is", np.mean(nodes[:,3]), ts.time_units)
a = nodes[:,0]
b = nodes[:,1]
c = nodes[:,2]
tot = a + b + c
times = nodes[:,3]

def get_cartesian_from_barycentric(b, t):
    return t.dot(b)

t = np.transpose(np.array([[0,0],[1,0],[1/2,np.sqrt(3)/2]])) # Triangle
abc = np.array([a/tot, b/tot, c/tot])
xy = get_cartesian_from_barycentric(abc, t)
x = xy[0,:]
y = xy[1,:]
z = np.log(times)

fig, axes = plt.subplots(nrows=3, figsize=(7, 15))
# Compute shared contour levels across all three panels
zz = scipy.stats.binned_statistic_2d(x, y, z, bins=50, statistic="mean")[0]
levels = np.linspace(np.nanmin(np.exp(zz)), np.nanmax(np.exp(zz)), 10)
for ax, (lower, upper) in zip(
    axes,
    itertools.pairwise(np.quantile(tot, [0, 0.7, 0.9, 1])),
):
    use = np.logical_and(tot >= lower, tot < upper)
    bins = 50
    zz = scipy.stats.binned_statistic_2d(x[use], y[use], z[use], bins=50, statistic="mean")[0].T
    cntr1 = ax.contourf(np.exp(zz), levels=levels)
    ax.set_ylim(-bins * 0.1, bins * 1.1)
    ax.text(*(t[:,0] / np.max(t, axis=1) * bins),
            f"100% 2 sample\ndescendants\n(single tree theory: 2500) {ts.time_units[:3]}",
            ha="left", va="top")
    ax.text(*(t[:,1] / np.max(t, axis=1) * bins),
            f"100% 3 sample\ndescendants\n(single tree theory: 5000) {ts.time_units[:3]}",
            ha="right", va="top")
    ax.text(*(t[:,2] / np.max(t, axis=1) * bins),
            f"100% 4 sample descendants (single tree theory: 15000) {ts.time_units[:3]}",
            ha="center")
    ax.set_axis_off()
    ax.set_title(f"{np.sum(use)} nodes spanning {lower/1000:.5g}-{upper/1000:.5g} kb")
    cbar = fig.colorbar(cntr1, ax=ax)
    cbar.ax.set_ylabel(f"Mean node time ({ts.time_units})")
plt.show()

hyanwong commented 12 months ago

We can see what our (poorly performing) conditional coalescent mixture prior looks like by substituting the expected times for the real times, i.e. setting z = np.log(expected_time) above, where the expected time comes from:

import tsdate
nm = {n: i for i, n in enumerate(tsdate.prior.PriorParams._fields)}
cc = tsdate.prior.ConditionalCoalescentTimes(100)
cc.add(4)

def mixture_expect_and_var(mixture, cond_coal):
    """
    Return the expectation and variance of a coalescent mixture
    mixture is a dict of the form N:{'descendant_tips': [tips], 'weight': [weights]}
    """
    expectation = 0
    first = secnd = 0
    for N, tip_dict in mixture.items():
        # assert 1 not in tip_dict.descendant_tips
        mean = cond_coal[N][tip_dict["descendant_tips"], nm["mean"]]
        var = cond_coal[N][tip_dict["descendant_tips"], nm["var"]]
        # Mixture expectation
        expectation += np.sum(mean * tip_dict["weight"])
        # Mixture variance
        first += np.sum(var * tip_dict["weight"])
        secnd += np.sum(mean**2 * tip_dict["weight"])
    mean = expectation
    var = first + secnd - (expectation**2)
    return mean, var

expected_time = np.zeros(len(times))
for i, n in enumerate(nodes):
    params = {4: {'descendant_tips': [2, 3, 4], 'weight': n[:3]/np.sum(n[:3])}}
    expected_time[i] = mixture_expect_and_var(params, cc)[0] * 1e4

This gives an equivalent set of plots which are identical to each other (as the mixture prior does not account for node spans). It looks like this:

[image: the same barycentric plots, using the mixture-prior expected times]

hyanwong commented 12 months ago

It looks to me as though, for a 4-tip tree, the pattern (the sloping diagonal) in the mixture prior is correct, but the expectation is completely swamped by the effect of node span.

It should be relatively easy to fit a statistical model to find a decent set of predictors, I think (especially since we can generate as much data as we like). We should do the same for variance, of course.

As a start, we could fit a simple linear model relating the observed time ($T$) to the expected time from the mixture prior for that node ($m$) and the span of the node ($s$). This might work if the mixture of CC priors has a simple linear effect on top of the span (as hinted by the plots above):

$$ \log(T) \sim \log(m) + \log(s) + \log(m) \cdot \log(s) $$

This would be easy to test as a statistical model on much larger numbers of samples.
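
For concreteness, such a fit by ordinary least squares might look like this, reusing times, expected_time and tot (node spans) from the code above (a sketch, untested):

import numpy as np

logT = np.log(times)
logm = np.log(expected_time)
logs = np.log(tot)
# Design matrix: intercept, log(m), log(s), and their interaction
X = np.column_stack([np.ones_like(logT), logm, logs, logm * logs])
beta, *_ = np.linalg.lstsq(X, logT, rcond=None)
resid = logT - X @ beta
print("Coefficients (intercept, log m, log s, interaction):", beta)
print("R^2:", 1 - np.var(resid) / np.var(logT))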

I think @a-ignatieva has some predictions linking edge span (which she calls "duration") with time, so we might be able to use that, although it's not a direct measure of node span.

hyanwong commented 12 months ago

Here's a useful plot, showing how we deviate from the mixture prior as a function of node span. As you can see, the longer the node span, the younger the observed time relative to what we expect.

We can repeat this for larger tree sequences, of course.

[image: deviation from log expected time vs node span, hexbin and scatter panels]

import matplotlib.colors

fig, axes = plt.subplots(1, 2, sharex=True, sharey=True, figsize=(10, 5))
axes[0].hexbin(np.log(times)-np.log(expected_time), tot, yscale="log", norm=matplotlib.colors.LogNorm(1, 5000))
axes[0].set_ylabel("Node span (bp)")
axes[0].set_xlabel("Deviation from log expected time")
axes[1].scatter(np.log(times)-np.log(expected_time), tot, alpha=0.01)
axes[1].set_yscale("log")
axes[1].set_xlabel("Deviation from log expected time");

nspope commented 12 months ago

Very nice. Something else that might be useful: we can take a matrix $L$ of normalized spans per count of descendants (e.g. each row of $L$ corresponds to a particular count of descendants, each column to a node, and the entries are spans normalized such that rows sum to 1). Then, if $t$ is the vector of node ages and $e$ is the vector of marginal expectations for each descendant count, we know that $Lt = e$. This holds for any raw moment (e.g. $e$ could contain $\mathbb{E}[t]$ or $\mathbb{E}[t^2]$ or $\mathbb{E}[\log t]$, and so on). This is an underdetermined system of equations, so it doesn't generate a node-specific prior outright, but it does provide a linear constraint. The constraint gets weaker the more nodes there are (i.e. the longer the sequence is), but it could be applied to windows of a fixed length. And the CLT should apply here, so each dot product $\sum_j L_{ij} t_j$ is normally distributed with mean $e_i$.
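
As a rough sketch of that constraint in code (reusing a, b, c, times, cc and nm from the 4-tip example above; untested):

import numpy as np

spans = np.array([a, b, c])                   # rows: 2, 3, 4 descendants
L = spans / spans.sum(axis=1, keepdims=True)  # normalise rows to sum to 1
e_hat = L @ times                             # empirical L t
# Marginal expectations for 2, 3, 4 descendants of 4 tips, scaled by Ne
e_theory = cc[4][np.array([2, 3, 4]), nm["mean"]] * 1e4
print(e_hat, e_theory)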

hyanwong commented 12 months ago

Hmm, interesting thought. Also, if our problem lies not so much with the mean prior assigned to each node as with the variance, then we should try a similar plot for variance too. Here's one attempt, using data from simulated tree sequences of 100 samples:

[image: deviation from expected time and expected variance, binned by node span]

If I've got the code right (below), then it looks like the variance problem is the more serious one, as you thought, @nspope. In particular, we estimate much too high a variance for nodes with long spans, and too low a variance (which is presumably worse) for nodes with short spans.

import itertools

import matplotlib.colors
import msprime
import numpy as np
import tqdm
from matplotlib import pyplot as plt

import tsdate

nm = {n: i for i, n in enumerate(tsdate.prior.PriorParams._fields)}
cc = tsdate.prior.ConditionalCoalescentTimes(1000)

def mixture_expect_and_var(mixture, cond_coal):
    """
    Return the expectation and variance of a coalescent mixture
    mixture is a dict of the form N:{'descendant_tips': [tips], 'weight': [weights]}
    """
    expectation = 0
    first = secnd = 0
    for N, tip_dict in mixture.items():
        # assert 1 not in tip_dict.descendant_tips
        mean = cond_coal[N][tip_dict["descendant_tips"], nm["mean"]]
        var = cond_coal[N][tip_dict["descendant_tips"], nm["var"]]
        # Mixture expectation
        expectation += np.sum(mean * tip_dict["weight"])
        # Mixture variance
        first += np.sum(var * tip_dict["weight"])
        secnd += np.sum(mean**2 * tip_dict["weight"])
    mean = expectation
    var = first + secnd - (expectation**2)
    return mean, var

reps = 10
n = 100
cc.add(n)
nstat = None
# Save into a massive matrix with 7 columns:
for ts in tqdm.tqdm(
    msprime.sim_ancestry(n, ploidy=1, population_size=1e4, sequence_length=5e7, recombination_rate=1e-8, num_replicates=reps)
):
    n_internal = ts.num_nodes - ts.num_samples
    node_arr = np.zeros((n_internal, ts.num_samples - 1))
    for tree in ts.trees():
        for u in tree.nodes():
            if tree.is_internal(u):
                node_arr[u-ts.num_samples, tree.num_samples(u) - 2] += tree.interval.span
    stat = np.array([
        ts.nodes_time[ts.num_samples:], # 0 actual time
        np.sum(node_arr, axis=1),       # 1 node span
        np.zeros(n_internal),           # 2 will be expected (mixture) time
        np.zeros(n_internal),           # 3 will be expected (mixture) variance
        np.zeros(n_internal),           # 4 will be mean ntips
        np.zeros(n_internal),           # 5 will be smallest nonzero span ntips
        np.zeros(n_internal),           # 6 will be largest nonzero span ntips
    ]).T
    ntips = np.arange(2, n+1)
    for u in np.arange(n_internal):
        weights = node_arr[u,:]/np.sum(node_arr[u,:])
        params = {n: {'descendant_tips': ntips, 'weight': weights}}
        mean_var = mixture_expect_and_var(params, cc)
        stat[u, 2] = mean_var[0] * 1e4        # rescale mean by Ne
        stat[u, 3] = mean_var[1] * 1e4 * 1e4  # rescale variance by Ne^2
        stat[u, 4] = np.sum(ntips * weights)
        stat[u, 5] = np.flatnonzero(node_arr[u,:])[0] + 2
        stat[u, 6] = np.flatnonzero(node_arr[u,:])[-1] + 2
    if nstat is None:
        nstat = stat
    else:
        nstat = np.vstack((nstat, stat))

fig, axes = plt.subplots(1, 2, sharey=True, figsize=(10, 5))
axes[0].hexbin(np.log(nstat[:,0])-np.log(nstat[:,2]), nstat[:,1], yscale="log", norm=matplotlib.colors.LogNorm(1, 100))
axes[0].set_ylabel("Node span")
axes[0].set_xlabel("Deviation from log expected time")

for lower, upper in itertools.pairwise(np.quantile(nstat[:,1], np.linspace(0, 1, 100))):
    use = np.logical_and(nstat[:,1] >= lower, nstat[:,1] < upper)
    var = np.var(nstat[:,0][use], ddof=1)
    expected_var = np.mean(nstat[:,3][use])
    axes[1].plot(np.log(var)-np.log(expected_var), (lower+upper)/2, "bo")
axes[1].set_xlabel("Deviation from log expected variance in time");

hyanwong commented 12 months ago

And here's the same plot with the average number of descendant samples on the Y axis, rather than the node span (i.e. using nstat[:,4] rather than nstat[:,1] in the code above; see the snippet after the image). The mean time does OK when we bin by this measure. As expected, if we repeat this on a tree sequence with no recombination (data not shown), the means and variances are distributed evenly around the X=0 mark.

[image: deviation plots binned by mean number of descendant samples]
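
The substitution for the left-hand panel is just (a sketch; the variance-binning loop changes in the same way):

axes[0].hexbin(np.log(nstat[:,0]) - np.log(nstat[:,2]), nstat[:,4],
               yscale="log", norm=matplotlib.colors.LogNorm(1, 100))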

hyanwong commented 12 months ago

So far we have been calculating the mixtures using weights proportional to the span over which a node has each number of descendant samples. But I wonder if it is better to use the log of the span as the weight instead, or even not to weight at all.
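
For example (a sketch, reusing node_arr from the code above; np.log1p avoids log(0) for absent counts):

import numpy as np

u = 0                             # an example internal node
spans = node_arr[u, :]
w_span = spans / spans.sum()      # current scheme: weight by span
w_log = np.log1p(spans)
w_log = w_log / w_log.sum()       # alternative: weight by log span
present = spans > 0
w_flat = present / present.sum()  # alternative: equal weight where present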

nspope commented 12 months ago

I've tried without weighting by span at all -- this still introduces artefacts. I think the fact that there is a dependence between span and age means that you'd need a correction that depends on both of these (like the linear model you've spelled out above).

hyanwong commented 12 months ago

I think that's right, but a linear model that incorporates the (log) node span might still fit better with a log-weighted (or unweighted) mixture component. By the way, from a few statistical fits I have tried, it seems that we don't really need an interaction term between log(span) and the mixture value.

Anyway, this is possibly getting into the weeds a little, if we are happy to use a global prior!

nspope commented 12 months ago

Well, I think this is great to keep in mind for future improvements. For now, the global prior seems to work reasonably well. But it'd be great to come up with a way to generate the global prior without the very costly variance calculation for the mixture. I think, for example, that if we just calculate the first moment (linear complexity in the number of tips) for each node, then calculate the mean/variance across all nodes, we'll get a reasonable global prior. This would work for trees with 10s or 100s of thousands of tips, without interpolating.
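
A rough sketch of that idea, reusing node_arr, ntips, cc, nm, n and n_internal from the code above (moment-matching to a gamma is just one possible choice for the global prior):

import numpy as np

# First (raw) moment per node: linear in the number of tips
mean_t = np.zeros(n_internal)
for u in range(n_internal):
    w = node_arr[u, :] / np.sum(node_arr[u, :])
    mean_t[u] = np.sum(cc[n][ntips, nm["mean"]] * w)
# Summarise across nodes, then moment-match a gamma global prior
mu = np.mean(mean_t)
var = np.var(mean_t, ddof=1)
alpha, beta = mu**2 / var, mu / var  # shape and rate of the global prior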

hyanwong commented 12 months ago

FWIW, I have just tested the log-weighting, and it makes both the mean time and the variance in time fit better. The mixture variances cluster more closely around the observed variances, and the fit of expectations to observed times is better, both in a linear model with node span and in one without. So I think it is worth testing this in a full tsdate simulation anyway. There's no harm in weighting by the log of the span, and I think that node spans are closer to being exponentially distributed than uniform. I assume there is some theory about the sizes of chunks when a line is repeatedly cut at random positions - it feels somewhat exponential to me.
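
That intuition is easy to eyeball (a sketch, reusing nstat from the code above): compare the empirical span distribution to an exponential with the same mean, e.g. via a Kolmogorov-Smirnov test:

import scipy.stats

spans = nstat[:, 1]
# KS test against an exponential with loc=0 and scale equal to the mean span
print(scipy.stats.kstest(spans, "expon", args=(0, spans.mean())))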

hyanwong commented 12 months ago

There's something I'm not quite understanding here, because my plots above imply that the mixture prior assigns too great a variance (and slightly too great an expected mean) for old, long nodes. But the plot at https://github.com/tskit-dev/tsdate/pull/257#issuecomment-1532446409 implies that we set the time of those old nodes too low when using the vanilla mixture prior. Have I got things back-to-front somehow?