tskit-dev / tsdate

Infer the age of ancestral nodes in a tree sequence.
MIT License

Returned posteriors incorrect for 2 tip tree with mutation_rate=None #230

Open hyanwong opened 1 year ago

hyanwong commented 1 year ago

For example:

import tsdate
import tskit
ts = tskit.Tree.generate_balanced(2).tree_sequence
grid = tsdate.build_prior_grid(ts, Ne=40, prior_distribution="gamma", timepoints=2000)
ts, posteriors = tsdate.date(ts, mutation_rate=None, priors=grid, return_posteriors=True)
print(posteriors[2])

All the posterior probabilities are one in this case. I think that's wrong.

hyanwong commented 1 year ago

Actually, this is not a bug in the mutation_rate=None path: it only occurs when we date a 2-tip tree. The following code with a 3-tip tree seems to work fine.

import tsdate
import tskit
ts = tskit.Tree.generate_comb(3).tree_sequence
display(ts.draw_svg())  # display() assumes an IPython/Jupyter session
grid = tsdate.build_prior_grid(ts, Ne=40, prior_distribution="gamma", timepoints=20)
ts, posteriors = tsdate.date(ts, mutation_rate=None, priors=grid, return_posteriors=True)
print(posteriors[ts.first().root])

hyanwong commented 1 year ago

I think this is the same bug that we identified previously, although I can't find the issue just now. The problem is that the priors do not properly account for tips at time 0. Here's a trivial example:

import tsdate
import tskit
ts = tskit.Tree.generate_comb(2).tree_sequence
grid = tsdate.build_prior_grid(ts, Ne=40, prior_distribution="gamma", timepoints=20)
print(grid[2]) # Gives array([0., 1., 1., ..., 1., 1., 1.]) rather than a proper distribution

The reason for this is that we place a zero as the first element of the priors array here: https://github.com/tskit-dev/tsdate/blob/c9db1d918ec179dded119ce5d26a6aba33c1358a/tsdate/prior.py#L979

This means that any node whose children have their entire weight at time zero gets probabilities of (essentially) zero passed up from those children.

I think the fix should be to treat the probabilities as lying between each pair of timepoints, with the last probability covering the chance of the node landing in the timeslice from timepoints[-1] to infinity. However, this is a rather major change, and would need careful testing to check that it actually improves tsdate performance overall.
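
A minimal sketch of that scheme, assuming a generic gamma prior (scipy's gamma here is a stand-in for tsdate's conditional coalescent prior, and the grid is arbitrary):

import numpy as np
from scipy.stats import gamma

# Give each timeslice [t_i, t_{i+1}) the prior mass between its endpoints,
# with the final slice running from timepoints[-1] to infinity.
timepoints = np.linspace(0, 200, 20)  # arbitrary example grid
prior = gamma(a=2.0, scale=10.0)      # stand-in prior, not tsdate's

cdf = prior.cdf(timepoints)
slice_mass = np.append(np.diff(cdf), 1.0 - cdf[-1])  # tail mass closes the grid
assert np.isclose(slice_mass.sum(), 1.0)  # a proper distribution, no zero at t=0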

hyanwong commented 1 year ago

It might be the same issue being picked up in https://github.com/hyanwong/molecular-sample-dating/blob/main/notebooks/uniform_prior.ipynb

hyanwong commented 1 year ago

Another place that we ignore the zeroth timeslice is when normalising: https://github.com/tskit-dev/tsdate/blob/e9bd0982fc4a6cbe9afad8f3a867cb7b8bcc37cb/tsdate/base.py#L135
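
For illustration, a hypothetical version of that pattern (a made-up array, not tsdate's actual normalisation code):

import numpy as np

# Normalising while skipping the zeroth timeslice: the leading zero never
# enters the scaling, so it stays at zero.
probs = np.array([0.0, 0.2, 0.5, 0.3])
probs[1:] = probs[1:] / probs[1:].max()  # only slices 1..n are rescaled
print(probs)  # [0.  0.4 1.  0.6]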

nspope commented 1 year ago

I actually don't think this is a bug, or related to having zero mass at the first timepoint: it is expected behavior given how the default timegrid is constructed from quantiles. That is, the default timegrid takes equally spaced quantiles of the prior as the timepoints (here), concatenates these quantiles across nodes (with some thinning), then assigns mass that is the difference in the prior CDF between adjacent timepoints (here). When there's only one internal node, there's only one prior distribution from which to take quantiles, and because the quantiles are equally spaced the CDF differences are all the same (1/(num_timepoints - 1)). In other words, all the information from the prior in the two-tip case is in the spacing of the timepoints, not in the prior probabilities themselves.
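
A sketch of that quantile construction, using scipy's gamma as a stand-in for the single internal node's prior (the parameters are made up):

import numpy as np
from scipy.stats import gamma

# Equally spaced quantiles of the (single) node prior become the timepoints,
# so the CDF differences between adjacent timepoints are all identical.
prior = gamma(a=2.0, scale=10.0)  # stand-in for the node's marginal prior
num_timepoints = 6
timepoints = prior.ppf(np.linspace(0, 1, num_timepoints))  # last point is inf
print(np.diff(prior.cdf(timepoints)))  # [0.2 0.2 0.2 0.2 0.2] = 1/(6 - 1)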

For example, with a non-default timegrid, we don't get uniform probabilities because the timepoints don't correspond to equally-spaced quantiles:

import tsdate
import tskit
import numpy as np
ts = tskit.Tree.generate_comb(2).tree_sequence
prior = tsdate.prior.MixturePrior(ts, prior_distribution="lognorm")
grid_default = prior.make_discretized_prior(population_size=40, timepoints=6)
grid_custom = prior.make_discretized_prior(
    population_size=40, timepoints=np.array([0, 0.1, 1, 10, 100, 1000])
)
print(grid_default[2])
# [0. 1. 1. 1. 1. 1. 1.]
print(grid_custom[2])
# [0.00000000e+00 1.44197669e-02 1.00000000e+00 5.18212919e-01 1.12178648e-03 2.08604998e-09]

It may seem a bit pathological, but it works well enough in practice. An alternative, which wouldn't have this behavior, would be to use the prior PDF at each timepoint rather than the difference in CDF between adjacent timepoints. This would, in effect, define the discretized prior in terms of discrete atoms rather than time intervals, which is consistent with how the likelihoods are calculated.
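
Continuing the stand-in gamma sketch from above, the PDF-based alternative would look something like this (arbitrary finite grid, hypothetical parameters):

import numpy as np
from scipy.stats import gamma

# Atoms at the timepoints: evaluate the prior density and renormalise,
# instead of differencing the CDF over timeslices.
prior = gamma(a=2.0, scale=10.0)  # stand-in prior
timepoints = np.array([0.0, 0.1, 1.0, 10.0, 100.0, 1000.0])
atom_mass = prior.pdf(timepoints)
atom_mass /= atom_mass.sum()      # discrete distribution over the atoms
print(atom_mass)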

I think this would give very similar results (to the current approach) when there are a lot of timepoints. I do remember seeing some discrepancies (between PDF vs CDF priors) while working on a quadrature scheme, but there could be other issues at play there.

nspope commented 1 year ago

Conceptually, I think it's cleaner to have both the prior and posterior defined over discrete timepoints (all this would require is calculating the prior probabilities from the PDF rather than the CDF). Then, to extend the algorithm to timeslices, we'd use this discretization with a quadrature scheme. However, the current setup seems to work pretty well, and is super convenient for the variable Ne prior.
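
For concreteness, one way such a quadrature scheme could look, as a purely illustrative sketch (trapezoid rule over the discrete timepoints; not tsdate's code):

import numpy as np

# Trapezoid weights turn function values at discrete timepoints into an
# approximate integral over the timeslices they span.
def trapezoid_weights(timepoints):
    dt = np.diff(timepoints)
    w = np.zeros_like(timepoints, dtype=float)
    w[:-1] += dt / 2  # each point takes half of each adjacent interval
    w[1:] += dt / 2
    return w

timepoints = np.array([0.0, 0.1, 1.0, 10.0, 100.0, 1000.0])
weights = trapezoid_weights(timepoints)
# the integral of f over the grid is approximately sum(f(timepoints) * weights)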