Open awohns opened 2 years ago
We can pass the sample node to date as a "datable node", but we will need to specify a mean and variance for the distribution somehow, or modify the contents of the returned NodeGridValues object.
I think the best thing would be to set any nodes as "datable" if they have a non-zero variance (rather than if they are not samples). Then we simply need to figure out how to give samples a non-zero variance (we can take the mean for the prior as the time of the node in the tree sequence).
It looks like base_priors.get_mixture_prior_params returns an array of alpha and beta params, which contains nan for fixed nodes. I think we should modify this to provide alpha and beta values which specify a mean of the sample time and a variance of 0 for fixed nodes. I think this is only possible for the lognormal distribution, however: there does not exist a gamma distribution that has an arbitrary mean but zero variance (although we could approximate it with a minuscule variance, e.g. 1e-20).
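To make the zero-variance point concrete, here is a minimal moment-matching sketch. The function names and the sample time of 100 are made up for illustration and are not tsdate's API. It shows that the gamma shape and rate blow up as the variance shrinks, whereas the lognormal parameters stay finite.

```python
import math


def gamma_from_mean_var(mean, var):
    """Moment-match a gamma distribution: returns (shape, rate)."""
    return mean**2 / var, mean / var


def lognormal_from_mean_var(mean, var):
    """Moment-match a lognormal: returns (mu, sigma2) of the underlying normal."""
    # log1p keeps precision when var / mean**2 is far below machine epsilon
    sigma2 = math.log1p(var / mean**2)
    mu = math.log(mean) - 0.5 * sigma2
    return mu, sigma2


sample_time = 100.0  # hypothetical fixed-node time
tiny_var = 1e-20     # the "minuscule variance" suggested above

shape, rate = gamma_from_mean_var(sample_time, tiny_var)
mu, sigma2 = lognormal_from_mean_var(sample_time, tiny_var)
print("gamma shape, rate:", shape, rate)
print("lognormal mu, sigma2:", mu, sigma2)
```

Note that a plain np.log(var / mean**2 + 1) would round to exactly 0 here, so any code relying on a tiny variance may want log1p instead.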
I have made some progress with https://github.com/tskit-dev/tsdate/pull/214/commits/786cf12a10d126031ca8b67a251abe0050a601e0. Here is some code to test:
import tsinfer
import msprime
import tskit
import tsdate
import numpy as np
import matplotlib.pyplot as plt
Ne = 10000
samples = [
    msprime.SampleSet(2),
    msprime.SampleSet(1, time=100),
]
mutated_ts = msprime.sim_ancestry(
    samples=samples,
    population_size=Ne,
    sequence_length=2e4,
    recombination_rate=0,  # For testing, just have a single tree
    random_seed=1,
)
mutated_ts = msprime.mutate(mutated_ts, rate=1e-8, random_seed=1)
def create_sampledata_with_individual_times(ts):
    """
    The tsinfer.SampleData.from_tree_sequence function doesn't allow different time
    units for sites and individuals. This function adds individual times by hand
    """
    # sampledata file with times-as-frequencies
    sd = tsinfer.SampleData.from_tree_sequence(ts)
    # Set individual times separately - warning: this mixes time units
    # so that sites have TIME_UNCALIBRATED but individuals have meaningful times
    individual_time = np.full(sd.num_individuals, -1)
    for sample, node_id in zip(sd.samples(), ts.samples()):
        if individual_time[sample.individual] >= 0:
            assert individual_time[sample.individual] == ts.node(node_id).time
        individual_time[sample.individual] = ts.node(node_id).time
    assert np.all(individual_time >= 0)
    sd = sd.copy()
    sd.individuals_time[:] = individual_time
    sd.finalise()
    return sd

def set_times_for_historical_samples(ts):
    """
    Use the times stored in the individuals metadata of an inferred tree sequence
    to constrain the times.
    """
    tables = ts.dump_tables()
    tables.individuals.metadata_schema = tskit.MetadataSchema.permissive_json()
    ts = tables.tree_sequence()
    times = np.zeros(ts.num_nodes)
    # set sample node times of historic samples
    for node_id in ts.samples():
        individual_id = ts.node(node_id).individual
        if individual_id != tskit.NULL:
            times[node_id] = ts.individual(individual_id).metadata.get("sample_data_time", 0)
    constrained_times = tsdate.core.constrain_ages_topo(ts, times, eps=1e-1)
    tables.nodes.time = constrained_times
    tables.mutations.time = np.full(ts.num_mutations, tskit.UNKNOWN_TIME)
    tables.sort()
    return tables.tree_sequence()

sampledata = create_sampledata_with_individual_times(mutated_ts)
inferred_ts = tsinfer.infer(sampledata)
inferred_ts_w_times = set_times_for_historical_samples(inferred_ts).simplify()
print(inferred_ts_w_times.node(5))
prior = tsdate.build_prior_grid(inferred_ts_w_times, Ne=10000, allow_historical_samples=True, truncate_priors=True, node_var_override={5:1000})
dated_ts = tsdate.date(inferred_ts_w_times, priors=prior, mutation_rate=1e-8)
This fails when truncating priors, however:
/usr/local/lib/python3.9/site-packages/tsdate/prior.py in _truncate_priors(ts, priors, progress)
1062 ):
1063 if index + 1 != len(truncate_nodes):
-> 1064 children_index = np.arange(parent_indices[index], parent_indices[index + 1])
1065 else:
1066 children_index = np.arange(parent_indices[index], ts.num_edges)
IndexError: index 3 is out of bounds for axis 0 with size 3
I can't quite figure out the logic in that function. Perhaps @awohns can talk me through it and we can see what is not working. It should be perfectly possible to truncate on the basis of a few fixed sample nodes.
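Reading the traceback, the loop appears to run over more nodes than parent_indices has entries, which is what one would expect if the node list now includes a sample that is never a parent. The following is only a sketch of one way to make such a lookup tolerant of childless nodes. It uses toy arrays instead of a real edge table, assumes edges are grouped by parent, and children_by_parent is a hypothetical helper, not tsdate code.

```python
import numpy as np


def children_by_parent(edges_parent, edges_child):
    """Map each parent ID to its child IDs, for edges grouped by parent."""
    edges_parent = np.asarray(edges_parent)
    edges_child = np.asarray(edges_child)
    if len(edges_parent) == 0:
        return {}
    # Start of each contiguous run of identical parent IDs
    starts = np.flatnonzero(np.r_[True, edges_parent[1:] != edges_parent[:-1]])
    # Sentinel: the last run ends at the total number of edges
    ends = np.r_[starts[1:], len(edges_parent)]
    return {int(edges_parent[s]): edges_child[s:e] for s, e in zip(starts, ends)}


# Toy edge table: nodes 0-2 are leaves, 3 and 4 are internal
edges_parent = np.array([3, 3, 4, 4])
edges_child = np.array([0, 1, 2, 3])
lookup = children_by_parent(edges_parent, edges_child)

empty = np.array([], dtype=edges_child.dtype)
for node in [2, 3, 4]:  # node 2 is a leaf that we still want to iterate over
    print(node, lookup.get(node, empty))
```

Looking children up by node ID rather than by loop position means a node with no child edges simply gets an empty array, and the special case for the last index is replaced by the sentinel.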
We can test the pathway without truncation using the code above via
prior = tsdate.build_prior_grid(inferred_ts_w_times, Ne=10000, allow_historical_samples=True, truncate_priors=False, node_var_override={5:10})
dated_ts = tsdate.date(inferred_ts_w_times, priors=prior, mutation_rate=1e-8)
With https://github.com/tskit-dev/tsdate/pull/214/commits/786cf12a10d126031ca8b67a251abe0050a601e0 this now complains about dangling nodes on the inside pass, which is correct, as the node corresponding to the sample-to-date will appear as if it is dangling.
tsdate/core.py in inside_pass(self, normalize, cache_inside, progress)
682 # Child appears fixed, or we have not visited it. Either our
683 # edge order is wrong (bug) or we have hit a dangling node
--> 684 raise ValueError(
685 "The input tree sequence includes "
686 "dangling nodes: please simplify it"
ValueError: The input tree sequence includes dangling nodes: please simplify it
The inside[edge.child] array is full of np.nan in the case of an undated sample node. I presume that we simply need to fill it with the appropriate values from the prior?
The key line is here, where we fill the inside either with np.nan for a node of unknown date, or with the identity value (i.e. prob=1) for the fixed nodes:
inside = self.priors.clone_with_new_data(  # store inside matrix values
    grid_data=np.nan, fixed_data=self.lik.identity_constant
)
Note that my changes simply create a lognormal distribution (with a user-specified variance) for the prior on an undated sample node. If a more complicated prior is needed, I guess it can be created by hand. We can show an example of this in the docs.
Wow, with https://github.com/tskit-dev/tsdate/pull/214/commits/979f55c2f864d47145f7115fabd8aba2c9477479 it's almost working with the outside_maximization method. The only issue now is setting the times so that they are topologically constrained:
prior = tsdate.build_prior_grid(inferred_ts_w_times, Ne=10000, allow_historical_samples=True, truncate_priors=False, node_var_override={5:1000})
dated_ts = tsdate.date(inferred_ts_w_times, priors=prior, mutation_rate=1e-8, method="maximization")
tsdate/core.py in constrain_ages_topo(ts, post_mn, eps, nodes_to_date, progress)
943 ):
944 if index + 1 != len(nodes_to_date):
--> 945 children_index = np.arange(parent_indices[index], parent_indices[index + 1])
946 else:
947 children_index = np.arange(parent_indices[index], ts.num_edges)
IndexError: index 3 is out of bounds for axis 0 with size 3
Fixed with https://github.com/tskit-dev/tsdate/pull/214/commits/da586445c3900ab360187aff4f85252e6260a8cd and https://github.com/tskit-dev/tsdate/pull/214/commits/02a9b674da6911efe8d2c1c57331bda28e30bbd7
The current PR https://github.com/tskit-dev/tsdate/pull/214 works, but only with the outside maximisation method, which won't return posteriors.
Here's what we get when trying the inside-outside:
import tsinfer
import msprime
import tskit
import tsdate
import numpy as np
Ne = 10000
samples = [
    msprime.SampleSet(2),
    msprime.SampleSet(1, time=100),
]
mutated_ts = msprime.sim_ancestry(
    samples=samples,
    population_size=Ne,
    sequence_length=2e4,
    recombination_rate=0,  # For testing, just have a single tree
    random_seed=1,
)
mutated_ts = msprime.mutate(mutated_ts, rate=1e-8, random_seed=1)
def create_sampledata_with_individual_times(ts):
    """
    The tsinfer.SampleData.from_tree_sequence function doesn't allow different time
    units for sites and individuals. This function adds individual times by hand
    """
    # sampledata file with times-as-frequencies
    sd = tsinfer.SampleData.from_tree_sequence(ts)
    # Set individual times separately - warning: this mixes time units
    # so that sites have TIME_UNCALIBRATED but individuals have meaningful times
    individual_time = np.full(sd.num_individuals, -1)
    for sample, node_id in zip(sd.samples(), ts.samples()):
        if individual_time[sample.individual] >= 0:
            assert individual_time[sample.individual] == ts.node(node_id).time
        individual_time[sample.individual] = ts.node(node_id).time
    assert np.all(individual_time >= 0)
    sd = sd.copy()
    sd.individuals_time[:] = individual_time
    sd.finalise()
    return sd

def set_times_for_historical_samples(ts):
    """
    Use the times stored in the individuals metadata of an inferred tree sequence
    to constrain the times.
    """
    tables = ts.dump_tables()
    tables.individuals.metadata_schema = tskit.MetadataSchema.permissive_json()
    ts = tables.tree_sequence()
    times = np.zeros(ts.num_nodes)
    # set sample node times of historic samples
    for node_id in ts.samples():
        individual_id = ts.node(node_id).individual
        if individual_id != tskit.NULL:
            times[node_id] = ts.individual(individual_id).metadata.get("sample_data_time", 0)
    # Just need to make the ts consistent
    constrained_times = tsdate.core.constrain_ages_topo(ts, times, eps=1e-1)
    tables.nodes.time = constrained_times
    tables.mutations.time = np.full(ts.num_mutations, tskit.UNKNOWN_TIME)
    tables.sort()
    return tables.tree_sequence()

sampledata = create_sampledata_with_individual_times(mutated_ts)
inferred_ts = tsinfer.infer(sampledata)
inferred_ts_w_times = set_times_for_historical_samples(inferred_ts).simplify()
prior = tsdate.build_prior_grid(inferred_ts_w_times, Ne=10000, allow_historical_samples=True, truncate_priors=False, node_var_override={5:1000})
dated_ts, posteriors = tsdate.date(inferred_ts_w_times, priors=prior, mutation_rate=1e-8, method="maximization", return_posteriors=True) # WORKS!
dated_ts, posteriors = tsdate.date(inferred_ts_w_times, priors=prior, mutation_rate=1e-8, return_posteriors=True) # FAILS
/usr/local/lib/python3.9/site-packages/tsdate/core.py in outside_pass(self, normalize, ignore_oldest_root, progress, probability_space_returned)
792
793 # vv[0] = 0 # Seems a hack: internal nodes should be allowed at time 0
--> 794 assert self.norm[edge.child] > self.lik.null_constant
795 outside[child] = self.lik.reduce(val, self.norm[child])
796 if normalize:
AssertionError:
It's failing because self.norm[edge.child] is nan in this case. If we can fix this, I think we should have a working computational molecular dating method. Any ideas how to get the outside pass working @awohns? Can we simply set the normalisation constant to 1 here?
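For reference, the empty AssertionError is what a NaN produces here: any ordering comparison involving NaN evaluates to False, so the assert cannot pass whatever null_constant is. A small sketch follows. check_norm is a hypothetical helper, not part of tsdate, and the default null_constant of 0.0 is an assumption.

```python
import numpy as np


def check_norm(norm_value, null_constant=0.0):
    """Raise a descriptive error instead of a bare AssertionError."""
    if np.isnan(norm_value):
        raise ValueError("normalisation constant is NaN")
    if not norm_value > null_constant:
        raise ValueError(
            f"normalisation constant {norm_value} is not above {null_constant}"
        )
    return norm_value


# Ordering comparisons with NaN are always False
print(np.nan > 0.0)   # False
print(np.nan <= 0.0)  # False

try:
    check_norm(np.nan)
except ValueError as err:
    print("caught:", err)
```

An explicit NaN check of this kind would at least distinguish "never computed" from "computed but too small" when debugging the outside pass.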
https://github.com/tskit-dev/tsdate/pull/214/commits/974038dd54e08b037d71167c563f7e2d32182298 sets the normalization constant to unity for non-fixed leaf nodes.
However, I'm having second thoughts about the sum_to_unity function. Since we have different width time bins, I suspect that we want the cumulative sum to be one, right? We can't simply sum up all the probabilities for the grid slices.
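Whether a plain sum is appropriate depends on whether the grid values are probability masses per interval or densities. The sketch below illustrates the difference on a made-up uneven grid with an exponential CDF. It says nothing about which of the two sum_to_unity actually operates on.

```python
import numpy as np

timepoints = np.array([0.0, 100.0, 300.0, 1000.0, 5000.0])  # uneven widths
scale = 1000.0
cdf = 1.0 - np.exp(-timepoints / scale)  # exponential CDF, illustration only

mass = np.diff(cdf)             # probability falling in each interval
widths = np.diff(timepoints)
density = mass / widths         # average density over each interval

total_mass = mass.sum()                  # equals cdf[-1] - cdf[0]
integral = (density * widths).sum()      # same quantity, via the density
naive = density.sum()                    # ignores widths: not a probability

print(total_mass, integral, naive)
```

If the stored values are masses (differences of the CDF), summing them is already width-aware. If they are densities, each value has to be weighted by its bin width before normalising.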
It's technically working but there's a bug, I think. I reckon the following should give a relatively flat prior for node 5:
import tsdate
import matplotlib.pyplot as plt
variance = 1e8 # a big number
prior = tsdate.build_prior_grid(inferred_ts_w_times, Ne=10000, allow_historical_samples=True, truncate_priors=False, node_var_override={5:variance})
prior.force_probability_space("linear")
print(prior[5])
plt.stairs(prior[5][:-1], prior.timepoints)
It doesn't for me. The variance logic must be wrong, I think.
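One possible explanation, offered as a sketch and not a diagnosis: for a lognormal with the mean held fixed, increasing the variance does not flatten the distribution. The median is mean / sqrt(1 + var / mean**2), so it shrinks towards zero while a long right tail carries the mean. The prior mean of 100 below is a hypothetical value, not one read from the tree sequence.

```python
import math


def lognormal_summary(mean, var):
    """Return (mu, sigma, median) for a lognormal with the given mean and variance."""
    sigma2 = math.log1p(var / mean**2)
    mu = math.log(mean) - 0.5 * sigma2
    return mu, math.sqrt(sigma2), math.exp(mu)


node_mean = 100.0  # hypothetical prior mean for the node
for var in (1e2, 1e4, 1e8):
    mu, sigma, median = lognormal_summary(node_mean, var)
    print(f"var={var:g}: sigma={sigma:.3f}, median={median:.3f}")
```

If this is what is happening, a very large node_var_override would pile most of the prior mass into the first few time slices instead of spreading it evenly.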
Here's some code to discuss:
import scipy.stats
import numpy as np
def lognorm_approx(mean, var):
    """
    alpha is mean of underlying normal distribution
    beta is variance of underlying normal distribution
    """
    beta = np.log(var / (mean ** 2) + 1)
    alpha = np.log(mean) - 0.5 * beta
    return alpha, beta

def shape_scale_from_mean_var(mean, var):
    a, b = lognorm_approx(mean, var)
    return np.sqrt(b), np.exp(a)

timepoints = np.array(
[ 0. , 422.2655105 , 596.44205246, 752.38907357,
904.41565448, 1058.56279063, 1218.65683564, 1387.84069925,
1569.19515032, 1766.1152649 , 1982.64596477, 2223.88220297,
2496.53641723, 2809.82966009, 3177.00084615, 3618.05796915,
4165.24283193, 4875.1600594 , 5860.22741522, 6725.01537605,
7392.4550719 , 8583.20477957, 10441.70306644, 11706.4971667 ,
13503.29402783, 14535.1094196 , 15645.51778133, 17737.96380629,
20558.5674799 , 23651.27728082, 27051.19988133, 29023.72411745,
31155.88653047, 33497.40292709, 36116.46074264, 39112.19711119,
42638.72817761, 46956.92775847, 52564.36877884, 60607.17771601,
74895.98339996])
#timepoints = np.arange(16) * 5000
print(timepoints)
shape, scale = shape_scale_from_mean_var(10000, 1e8)
cdf_func = scipy.stats.lognorm.cdf
prior_node = cdf_func(timepoints, shape, scale=scale)
print("cdf", prior_node)
#prior_node = np.divide(prior_node, np.max(prior_node))
p = np.concatenate([np.array([0]), np.diff(prior_node)])
print("pdf (prior)", p)
import matplotlib.pyplot as plt
plt.stairs(p[:-1]/np.diff(timepoints), timepoints)
plt.xscale("log")
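A note on alignment when turning these values into a plotted density, assuming p[i] is read as the mass in the interval ending at timepoints[i] (which is how the concatenate/diff construction above builds it). Under that reading, p[1:] pairs with np.diff(timepoints), whereas p[:-1] shifts every mass one bin to the right and drops the last one. A toy check with made-up numbers:

```python
import numpy as np

timepoints = np.array([0.0, 1.0, 3.0, 7.0])
cdf = np.array([0.0, 0.2, 0.5, 0.9])  # made-up CDF values at the timepoints

p = np.concatenate([[0.0], np.diff(cdf)])  # leading 0 pads p to len(timepoints)
widths = np.diff(timepoints)

aligned = p[1:] / widths   # mass in (t[i], t[i+1]] over that interval's width
shifted = p[:-1] / widths  # each mass divided by the following interval's width

print("aligned:", aligned, "integrates to", (aligned * widths).sum())
print("shifted:", shifted, "integrates to", (shifted * widths).sum())
```

With the aligned version, plt.stairs(aligned, timepoints) draws a histogram whose area matches the CDF range covered by the grid.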
Currently tsdate only allows sample nodes which have a known date. We want to rewire tsdate so sample nodes can have an unknown date, allowing for "molecular sampling".