hyanwong / molecular-sample-dating

Methods for dating ancient samples using tree sequence methods

Rewire tsdate to allow nonfixed sample nodes #1

Open awohns opened 2 years ago

awohns commented 2 years ago

Currently tsdate only allows sample nodes which have a known date. We want to rewire tsdate so sample nodes can have an unknown date, allowing for "molecular sampling".

hyanwong commented 2 years ago

We can pass the sample node to date as a "datable node" here:

https://github.com/tskit-dev/tsdate/blob/53add5684443388f06e743d39edf05a5e3f33149/tsdate/prior.py#L979

but we will need to specify a mean and variance for the distribution somehow, or modify the contents of the returned NodeGridValues object.

I think the best thing would be to set any nodes as "datable" if they have a non-zero variance (rather than if they are not samples). Then we simply need to figure out how to give samples a non-zero variance (we can take the mean for the prior as the time of the node in the tree sequence).

It looks like the base_priors.get_mixture_prior_params returns an array of alpha and beta params which are used in

https://github.com/tskit-dev/tsdate/blob/53add5684443388f06e743d39edf05a5e3f33149/tsdate/prior.py#L956

And which contains nan for fixed nodes. I think we should modify this to provide alpha and beta values which specify a mean of the sample time and a variance of 0 for fixed nodes. I think this is only possible for the lognormal distribution, however: there does not exist a gamma distribution that has an arbitrary mean but zero variance (although we could approximate it with a minuscule variance, e.g. 1e-20).
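To make the gamma versus lognormal point concrete, here is a minimal sketch of moment-matching both distributions. The helper names are hypothetical and are not tsdate functions. It assumes the gamma is parameterised by shape and rate, and the lognormal by the mean and variance of the underlying normal.

```python
import numpy as np

def gamma_shape_rate(mean, var):
    """Shape and rate of a gamma distribution with the given mean and variance."""
    return mean**2 / var, mean / var

def lognormal_mu_sigma2(mean, var):
    """Mean and variance of the underlying normal for a lognormal with these moments."""
    sigma2 = np.log1p(var / mean**2)
    mu = np.log(mean) - 0.5 * sigma2
    return mu, sigma2

mean = 100.0
for var in (1.0, 1e-10, 1e-20):
    # The gamma shape and rate grow without bound as the variance shrinks
    print(var, gamma_shape_rate(mean, var), lognormal_mu_sigma2(mean, var))

# At exactly zero variance the lognormal parameters stay finite
# (mu = log(mean), sigma2 = 0), whereas the gamma formulas divide by zero.
mu, sigma2 = lognormal_mu_sigma2(mean, 0.0)
print(mu, sigma2)
```

With a variance of 1e-20 the gamma parameters are still finite but very large, so whether they are numerically usable would need checking.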

hyanwong commented 2 years ago

I have made some progress with https://github.com/tskit-dev/tsdate/pull/214/commits/786cf12a10d126031ca8b67a251abe0050a601e0. Here is some code to test:

import tsinfer
import msprime
import tskit
import tsdate

import numpy as np
import matplotlib.pyplot as plt

Ne = 10000
samples = [
        msprime.SampleSet(2),
        msprime.SampleSet(1, time=100),
    ]
mutated_ts = msprime.sim_ancestry(
    samples=samples,
    population_size=Ne,
    sequence_length=2e4,
    recombination_rate=0, # For testing, just have a single tree
    random_seed=1,
)
mutated_ts = msprime.mutate(mutated_ts, rate=1e-8, random_seed=1)

def create_sampledata_with_individual_times(ts):
    """
    The tsinfer.SampleData.from_tree_sequence function doesn't allow different time
    units for sites and individuals. This function adds individual times by hand
    """
    # sampledata file with times-as-frequencies
    sd = tsinfer.SampleData.from_tree_sequence(ts)
    # Set individual times separately - warning: this mixes time units
    # so that sites have TIME_UNCALIBRATED but individuals have meaningful times
    individual_time = np.full(sd.num_individuals, -1.0)  # float, so non-integer times are not truncated
    for sample, node_id in zip(sd.samples(), ts.samples()):
        if individual_time[sample.individual] >= 0:
            assert individual_time[sample.individual] == ts.node(node_id).time
        individual_time[sample.individual] = ts.node(node_id).time
    assert np.all(individual_time >= 0)
    sd = sd.copy()
    sd.individuals_time[:] = individual_time
    sd.finalise()
    return sd

def set_times_for_historical_samples(ts):
    """
    Use the times stored in the individuals metadata of an inferred tree sequence
    to constrain the times.
    """
    tables = ts.dump_tables()
    tables.individuals.metadata_schema = tskit.MetadataSchema.permissive_json()
    ts = tables.tree_sequence()
    times = np.zeros(ts.num_nodes)
    # set sample node times of historic samples
    for node_id in ts.samples():
        individual_id = ts.node(node_id).individual
        if individual_id != tskit.NULL:
            times[node_id] = ts.individual(individual_id).metadata.get("sample_data_time", 0)
    constrained_times = tsdate.core.constrain_ages_topo(ts, times, eps=1e-1)
    tables.nodes.time = constrained_times
    tables.mutations.time = np.full(ts.num_mutations, tskit.UNKNOWN_TIME)
    tables.sort()
    return tables.tree_sequence()

sampledata = create_sampledata_with_individual_times(mutated_ts)
inferred_ts = tsinfer.infer(sampledata)
inferred_ts_w_times = set_times_for_historical_samples(inferred_ts).simplify()

print(inferred_ts_w_times.node(5))

prior = tsdate.build_prior_grid(inferred_ts_w_times, Ne=10000, allow_historical_samples=True, truncate_priors=True, node_var_override={5:1000})
dated_ts = tsdate.date(inferred_ts_w_times, priors=prior, mutation_rate=1e-8)

This fails when truncating priors, however:

/usr/local/lib/python3.9/site-packages/tsdate/prior.py in _truncate_priors(ts, priors, progress)
   1062     ):
   1063         if index + 1 != len(truncate_nodes):
-> 1064             children_index = np.arange(parent_indices[index], parent_indices[index + 1])
   1065         else:
   1066             children_index = np.arange(parent_indices[index], ts.num_edges)

IndexError: index 3 is out of bounds for axis 0 with size 3

I can't quite figure out the logic in that function. Perhaps @awohns can talk me through it and we can see what is not working. It should be perfectly possible to truncate on the basis of a few fixed sample nodes.

hyanwong commented 2 years ago

We can test the pathway without truncation using the code above via

prior = tsdate.build_prior_grid(inferred_ts_w_times, Ne=10000, allow_historical_samples=True, truncate_priors=False, node_var_override={5:10})
dated_ts = tsdate.date(inferred_ts_w_times, priors=prior, mutation_rate=1e-8)

With https://github.com/tskit-dev/tsdate/pull/214/commits/786cf12a10d126031ca8b67a251abe0050a601e0 this now complains about dangling nodes on the inside pass, which is correct, as the node corresponding to the sample-to-date will appear as if it is dangling.

tsdate/core.py in inside_pass(self, normalize, cache_inside, progress)
    682                         # Child appears fixed, or we have not visited it. Either our
    683                         # edge order is wrong (bug) or we have hit a dangling node
--> 684                         raise ValueError(
    685                             "The input tree sequence includes "
    686                             "dangling nodes: please simplify it"

ValueError: The input tree sequence includes dangling nodes: please simplify it

The inside[edge.child] array is full of np.nan in the case of an undated sample node. I presume that we simply need to fill it with the appropriate values from the prior?

The key line is here, where we fill the inside either with np.nan for a node of unknown date, or with the identity value (i.e. prob=1) for the fixed nodes:

        inside = self.priors.clone_with_new_data(  # store inside matrix values
            grid_data=np.nan, fixed_data=self.lik.identity_constant
        )

hyanwong commented 2 years ago

Note that my changes simply create a lognormal distribution (with a user-specified variance) for the prior on an undated sample node. If a more complicated prior is needed, I guess it can be created by hand. We can show an example of this in the docs.

hyanwong commented 2 years ago

Wow, with https://github.com/tskit-dev/tsdate/pull/214/commits/979f55c2f864d47145f7115fabd8aba2c9477479 it's almost working with the outside_maximization method. The only issue now is setting the times so that they are topologically constrained:

prior = tsdate.build_prior_grid(inferred_ts_w_times, Ne=10000, allow_historical_samples=True, truncate_priors=False, node_var_override={5:1000})
dated_ts = tsdate.date(inferred_ts_w_times, priors=prior, mutation_rate=1e-8, method="maximization")

tsdate/core.py in constrain_ages_topo(ts, post_mn, eps, nodes_to_date, progress)
    943     ):
    944         if index + 1 != len(nodes_to_date):
--> 945             children_index = np.arange(parent_indices[index], parent_indices[index + 1])
    946         else:
    947             children_index = np.arange(parent_indices[index], ts.num_edges)

IndexError: index 3 is out of bounds for axis 0 with size 3

hyanwong commented 2 years ago

> The only issue now is setting the times so that they are topologically constrained:

Fixed with https://github.com/tskit-dev/tsdate/pull/214/commits/da586445c3900ab360187aff4f85252e6260a8cd and https://github.com/tskit-dev/tsdate/pull/214/commits/02a9b674da6911efe8d2c1c57331bda28e30bbd7
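For reference, the constraint itself is simple to state: every parent must be strictly older than each of its children. Below is a minimal sketch on plain edge arrays. It is not the tsdate implementation, and `constrain_ages` is a hypothetical name. Because it loops over edges rather than over a list of nodes to date, nodes that are never parents (such as an undated sample) need no special handling.

```python
import numpy as np

def constrain_ages(edge_parent, edge_child, times, eps=1e-1):
    """
    Return a copy of `times` in which each parent is at least `eps`
    older than all of its children. Repeats until nothing changes,
    which terminates because the edges form an acyclic graph.
    """
    times = np.array(times, dtype=float)
    changed = True
    while changed:
        changed = False
        for parent, child in zip(edge_parent, edge_child):
            if times[parent] <= times[child]:
                times[parent] = times[child] + eps
                changed = True
    return times

# Toy tree: node 3 is the parent of 0 and 1; node 4 is the parent of 2 and 3.
# Node 2 is a historical sample at time 100.
edge_parent = np.array([3, 3, 4, 4])
edge_child = np.array([0, 1, 2, 3])
result = constrain_ages(edge_parent, edge_child, [0, 0, 100, 0, 0])
print(result)
```

The repeated passes are wasteful on large tree sequences; visiting nodes in time order would need only one pass, at the cost of assuming that ordering.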

hyanwong commented 2 years ago

The current PR https://github.com/tskit-dev/tsdate/pull/214 works, but only with the outside maximisation method, which won't return posteriors.

Here's what we get when trying the inside-outside:

import tsinfer
import msprime
import tskit
import tsdate

import numpy as np

Ne = 10000
samples = [
        msprime.SampleSet(2),
        msprime.SampleSet(1, time=100),
    ]
mutated_ts = msprime.sim_ancestry(
    samples=samples,
    population_size=Ne,
    sequence_length=2e4,
    recombination_rate=0, # For testing, just have a single tree
    random_seed=1,
)
mutated_ts = msprime.mutate(mutated_ts, rate=1e-8, random_seed=1)

def create_sampledata_with_individual_times(ts):
    """
    The tsinfer.SampleData.from_tree_sequence function doesn't allow different time
    units for sites and individuals. This function adds individual times by hand
    """
    # sampledata file with times-as-frequencies
    sd = tsinfer.SampleData.from_tree_sequence(ts)
    # Set individual times separately - warning: this mixes time units
    # so that sites have TIME_UNCALIBRATED but individuals have meaningful times
    individual_time = np.full(sd.num_individuals, -1.0)  # float, so non-integer times are not truncated
    for sample, node_id in zip(sd.samples(), ts.samples()):
        if individual_time[sample.individual] >= 0:
            assert individual_time[sample.individual] == ts.node(node_id).time
        individual_time[sample.individual] = ts.node(node_id).time
    assert np.all(individual_time >= 0)
    sd = sd.copy()
    sd.individuals_time[:] = individual_time
    sd.finalise()
    return sd

def set_times_for_historical_samples(ts):
    """
    Use the times stored in the individuals metadata of an inferred tree sequence
    to constrain the times.
    """
    tables = ts.dump_tables()
    tables.individuals.metadata_schema = tskit.MetadataSchema.permissive_json()
    ts = tables.tree_sequence()
    times = np.zeros(ts.num_nodes)
    # set sample node times of historic samples
    for node_id in ts.samples():
        individual_id = ts.node(node_id).individual
        if individual_id != tskit.NULL:
            times[node_id] = ts.individual(individual_id).metadata.get("sample_data_time", 0)
    # Just need to make the ts consistent
    constrained_times = tsdate.core.constrain_ages_topo(ts, times, eps=1e-1)
    tables.nodes.time = constrained_times
    tables.mutations.time = np.full(ts.num_mutations, tskit.UNKNOWN_TIME)
    tables.sort()
    return tables.tree_sequence()

sampledata = create_sampledata_with_individual_times(mutated_ts)
inferred_ts = tsinfer.infer(sampledata)
inferred_ts_w_times = set_times_for_historical_samples(inferred_ts).simplify()

prior = tsdate.build_prior_grid(inferred_ts_w_times, Ne=10000, allow_historical_samples=True, truncate_priors=False, node_var_override={5:1000})
dated_ts, posteriors = tsdate.date(inferred_ts_w_times, priors=prior, mutation_rate=1e-8, method="maximization", return_posteriors=True)  # WORKS!
dated_ts, posteriors = tsdate.date(inferred_ts_w_times, priors=prior, mutation_rate=1e-8, return_posteriors=True)  # FAILS

/usr/local/lib/python3.9/site-packages/tsdate/core.py in outside_pass(self, normalize, ignore_oldest_root, progress, probability_space_returned)
    792 
    793             # vv[0] = 0  # Seems a hack: internal nodes should be allowed at time 0
--> 794             assert self.norm[edge.child] > self.lik.null_constant
    795             outside[child] = self.lik.reduce(val, self.norm[child])
    796             if normalize:

AssertionError: 

It's failing because self.norm[edge.child] is nan in this case. If we can fix this, I think we should have a working computational molecular dating method. Any ideas how to get the outside pass working, @awohns? Can we simply set the normalisation constant to 1 here?

hyanwong commented 2 years ago

> Can we simply set the normalisation constant to 1 here?

https://github.com/tskit-dev/tsdate/pull/214/commits/974038dd54e08b037d71167c563f7e2d32182298 sets the normalization constant to unity for non-fixed leaf nodes.

However, I'm having second thoughts about the sum_to_unity function. Since we have different width time bins, I suspect that we want the cumulative sum to be one, right? We can't simply sum up all the probabilities for the grid slices.
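The question turns on whether the grid values are per-slice probability masses or densities. A small sketch of the difference on unequal-width bins follows. The timepoints and distribution are illustrative assumptions, and it does not reproduce what tsdate's sum_to_unity does.

```python
import numpy as np
from scipy import stats

# Unequal-width time bins (illustrative values, not the tsdate grid)
timepoints = np.array([0.0, 100.0, 300.0, 1000.0, 5000.0, 50000.0])
widths = np.diff(timepoints)
dist = stats.lognorm(s=1.0, scale=1000.0)

cdf = dist.cdf(timepoints)
mass = np.diff(cdf)      # probability falling inside each bin
density = mass / widths  # average density over each bin

print(mass.sum())                # probability between first and last timepoint
print(density.sum())             # not a probability: depends on the bin widths
print((density * widths).sum())  # weighting by width recovers mass.sum()
```

If the stored values are masses (differences of the CDF), a plain sum is the right normalisation. If they are densities, each value has to be weighted by its bin width first.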

hyanwong commented 2 years ago

It's technically working but there's a bug, I think. I reckon the following should give a relatively flat prior for node 5:

import tsdate
import matplotlib.pyplot as plt
variance = 1e8  # a big number
prior = tsdate.build_prior_grid(inferred_ts_w_times, Ne=10000, allow_historical_samples=True, truncate_priors=False, node_var_override={5:variance})
prior.force_probability_space("linear")
print(prior[5])
plt.stairs(prior[5][:-1], prior.timepoints)

It doesn't for me. The variance logic must be wrong, I think.
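One possible explanation, offered as an assumption rather than a diagnosis of the tsdate code: for a lognormal with a fixed mean, increasing the variance does not flatten the distribution. It pulls the median towards zero and stretches the right tail, so most of the mass may end up in the first grid slice. The sketch below uses the closed-form median, mean / sqrt(1 + var / mean^2), with an assumed mean of 100 chosen only for illustration.

```python
import numpy as np

def lognormal_median(mean, var):
    """Median of a lognormal distribution with the given mean and variance."""
    return mean / np.sqrt(1.0 + var / mean**2)

mean = 100.0  # assumed prior mean, for illustration only
for var in (10.0, 1e3, 1e5, 1e8):
    print(f"var={var:g}  median={lognormal_median(mean, var):.3f}")
```

If a flat prior is the goal, a uniform distribution over the time grid may be closer to what is wanted than a lognormal with inflated variance.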

hyanwong commented 2 years ago

> I reckon the following should give a relatively flat prior for node 5:

Here's some code to discuss:

import scipy.stats
import numpy as np

def lognorm_approx(mean, var):
    """
    alpha is mean of underlying normal distribution
    beta is variance of underlying normal distribution
    """
    beta = np.log(var / (mean ** 2) + 1)
    alpha = np.log(mean) - 0.5 * beta
    return alpha, beta

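def shape_scale_from_mean_var(mean, var):
    # scipy.stats.lognorm takes the std dev of the underlying normal as its
    # shape and exp(mean of the underlying normal) as its scale
    a, b = lognorm_approx(mean, var)
    return np.sqrt(b), np.exp(a)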

timepoints = np.array(
    [    0.        ,   422.2655105 ,   596.44205246,   752.38907357,
         904.41565448,  1058.56279063,  1218.65683564,  1387.84069925,
        1569.19515032,  1766.1152649 ,  1982.64596477,  2223.88220297,
        2496.53641723,  2809.82966009,  3177.00084615,  3618.05796915,
        4165.24283193,  4875.1600594 ,  5860.22741522,  6725.01537605,
        7392.4550719 ,  8583.20477957, 10441.70306644, 11706.4971667 ,
       13503.29402783, 14535.1094196 , 15645.51778133, 17737.96380629,
       20558.5674799 , 23651.27728082, 27051.19988133, 29023.72411745,
       31155.88653047, 33497.40292709, 36116.46074264, 39112.19711119,
       42638.72817761, 46956.92775847, 52564.36877884, 60607.17771601,
       74895.98339996])
#timepoints = np.arange(16) * 5000
print(timepoints)

shape, scale = shape_scale_from_mean_var(10000, 1e8)
cdf_func = scipy.stats.lognorm.cdf
prior_node = cdf_func(timepoints, shape, scale=scale)
print("cdf", prior_node)
#prior_node = np.divide(prior_node, np.max(prior_node))
p = np.concatenate([np.array([0]), np.diff(prior_node)])
print("pdf (prior)", p)

import matplotlib.pyplot as plt
plt.stairs(p[:-1]/np.diff(timepoints), timepoints)
plt.xscale("log")

[image: plot output of the code above]