tskit-dev / tsdate

Infer the age of ancestral nodes in a tree sequence.
MIT License

How should I understand the node age gamma distributions? #420

Closed YunDeng98 closed 3 months ago

YunDeng98 commented 3 months ago

Hi developers, I've been having a great experience with tsdate 0.2.0 so far. There is one thing I don't fully understand: each node has an associated gamma distribution (with mean and variance stored in the metadata). This distribution should represent the algorithm's uncertainty about the node's age. I am interested in evaluating how well calibrated this uncertainty is (that is, whether the distribution is too wide or too narrow). To do so, I simulated with a high mutation rate (4 times the recombination rate) and fed the true topology into tsdate. For each node, I computed the CDF of the associated gamma distribution at the true node age. Ideally these values should be uniform between 0 and 1 if the distribution is "at the right width". Here is the code:

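import matplotlib.pyplot as plt
import msprime
import numpy as np
import tsdate
from scipy import stats

def extract_gamma_parameters(ts):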
    alphas = []
    thetas = []
    for n in ts.nodes():
        if n.time > 0:
            mn = n.metadata['mn']
            vr = n.metadata['vr']
            alphas.append(vr/mn)
            thetas.append(mn*mn/vr)
    return alphas, thetas

def compute_quantiles(alphas, thetas, times):
    quantiles = []
    n = len(alphas)
    for i in range(n):
        quantiles.append(stats.gamma.cdf(times[i], a=alphas[i], scale=thetas[i]))
    return quantiles

def get_all_node_ages(ts):
    times = []
    for n in ts.nodes():
        if n.time > 0:
            times.append(n.time)
    return np.array(times)

def kl_divergence(total_counts):
    # Normalize the bin counts to get a probability distribution
    p = total_counts / np.sum(total_counts)

    # Uniform reference distribution over the same bins
    q = np.ones_like(p) / len(p)

    # KL divergence from uniform; empty bins contribute 0 (0 * log 0 := 0)
    nonzero = p > 0
    kl_div = np.sum(p[nonzero] * np.log(p[nonzero] / q[nonzero]))

    return kl_div

dense_ts = msprime.sim_ancestry(samples=50, population_size=1e4, sequence_length=1e7, recombination_rate=1.2e-8, random_seed=1000)
dense_mts = msprime.sim_mutations(dense_ts, rate=4.8e-8, random_seed=1000)
tsdate_dense_ts = tsdate.date(dense_mts, mutation_rate=4.8e-8)
true_dense_node_times = get_all_node_ages(dense_mts)
alphas, thetas = extract_gamma_parameters(tsdate_dense_ts)
tsdate_dense_quantiles = compute_quantiles(alphas, thetas, true_dense_node_times)
tsdate_dense_rank_density, _ = np.histogram(tsdate_dense_quantiles, bins=np.linspace(0, 1, 101));
plt.figure(figsize=[6, 6]);
plt.scatter(np.arange(1, 101), np.log10(tsdate_dense_rank_density), color='orange', linewidth=1);
plt.axhline(y=np.log10(np.mean(tsdate_dense_rank_density)), color='black', linestyle='--', linewidth=3);
plt.ylim([0.5, 4]);
kld = kl_divergence(tsdate_dense_rank_density)
plt.text(35, 3.5, f"KLD={kld:.3f}", fontsize=15);

and the result:

[image: scatter plot of log10 counts per quantile bin, with the KLD annotated]

which has a weird shape. I wonder if the developers have any ideas about this. Thanks a lot!

nspope commented 3 months ago

Hi Yun!

        mn = n.metadata['mn']
        vr = n.metadata['vr']
        alphas.append(vr/mn)
        thetas.append(mn*mn/vr)

I think you want $\alpha = m^2 / v$ and $\theta = v / m$? So that $m = \alpha \theta$ and $v = \alpha \theta^2$.
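As a quick sanity check of those formulas, here is a minimal sketch in plain Python. The helper name `gamma_from_moments` and the numbers are made up for illustration; they are not part of tsdate.

```python
def gamma_from_moments(mean, variance):
    """Moment-match a gamma distribution; returns (shape, scale)."""
    if mean <= 0 or variance <= 0:
        raise ValueError("mean and variance must be positive")
    shape = mean * mean / variance  # alpha = m^2 / v
    scale = variance / mean         # theta = v / m
    return shape, scale


# Made-up moments for illustration (not taken from a real tree sequence).
m, v = 1200.0, 90000.0
alpha, theta = gamma_from_moments(m, v)

# Round trip: a gamma has mean alpha * theta and variance alpha * theta**2.
print(alpha, theta)
print(alpha * theta, alpha * theta**2)

# Swapping the two expressions keeps the mean but changes the variance to m**3 / v.
swapped_alpha, swapped_theta = v / m, m * m / v
print(swapped_alpha * swapped_theta, swapped_alpha * swapped_theta**2)
```

The returned pair corresponds to the `a` and `scale` arguments of `stats.gamma.cdf` used in `compute_quantiles`. Since the swapped version still has the right mean, the mistake only shows up in the width of the distribution, which is exactly what the quantile plot is sensitive to.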

But in general, yes, the actual coverage is going to be less than the nominal coverage for a given interval width (the posterior will be too concentrated), even on true trees. AFAICT this is an issue with expectation propagation: it generally does a great job at matching the mean, but underestimates the variance. Increasing max_iterations can help, as it takes longer for higher order moments to converge.
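To make "actual coverage less than nominal coverage" concrete, here is a rough synthetic sketch. It does not run tsdate or expectation propagation. It only assumes a hypothetical setup where true ages are drawn from one gamma and the "posterior" has the same mean but a quarter of the variance. The helpers `erlang_cdf` and `empirical_coverage` are illustrative names; `erlang_cdf` is a closed-form stand-in for `stats.gamma.cdf` that only works for integer shape, used here to avoid dependencies.

```python
import math
import random


def erlang_cdf(x, shape, scale):
    """Gamma CDF for integer shape (closed form)."""
    z = x / scale
    term, total = 1.0, 1.0  # n = 0 term of sum_{n < shape} z**n / n!
    for n in range(1, shape):
        term *= z / n
        total += term
    return 1.0 - math.exp(-z) * total


def empirical_coverage(pit_values, level):
    """Fraction of CDF values that fall inside the central interval of the given level."""
    lo = (1.0 - level) / 2.0
    hi = 1.0 - lo
    return sum(lo <= u <= hi for u in pit_values) / len(pit_values)


rng = random.Random(1)
# Hypothetical truth: ages follow Gamma(shape=4, scale=250), so mean 1000.
true_ages = [rng.gammavariate(4, 250.0) for _ in range(20000)]

# Calibrated case: evaluate the CDF of the distribution the ages came from.
pit_calibrated = [erlang_cdf(t, 4, 250.0) for t in true_ages]
# Overconfident case: same mean (16 * 62.5 = 1000), a quarter of the variance.
pit_narrow = [erlang_cdf(t, 16, 62.5) for t in true_ages]

for level in (0.5, 0.9, 0.95):
    print(level,
          empirical_coverage(pit_calibrated, level),
          empirical_coverage(pit_narrow, level))
```

In the calibrated case the empirical coverage should sit close to each nominal level, while in the too-narrow case it falls well below. Applied to the quantiles from `compute_quantiles`, the same `empirical_coverage` function would give a per-level summary alongside the histogram and KLD.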

nspope commented 3 months ago

Here's your fig with that change to $\alpha$ and $\theta$:

[image: yun_test]

YunDeng98 commented 3 months ago

Thanks a lot! Sorry for the coding bug of mine.