tskit-dev / tsdate

Infer the age of ancestral nodes in a tree sequence.
MIT License

Investigate effect of polytomies / topological uncertainty on posteriors #359

Open hyanwong opened 5 months ago

hyanwong commented 5 months ago

Hannes had an interesting idea: how do polytomies affect the variation in posterior times for a node? We could test this by taking a known topology and collapsing some of the nodes into polytomies, then dating, and looking at how the posterior distribution of times of the component nodes compares to the posterior distribution estimated for the collapsed polytomy.

nspope commented 5 months ago

Another consideration is that a polytomy implies a greater total edge span than the original binary topologies, which I'd think would introduce bias. IIRC we don't see a relationship between arity and bias, however.

hyanwong commented 2 months ago

Now that we have a decent routine to create polytomies (https://github.com/tskit-dev/tskit/discussions/2926), we can test out the effect of polytomies on dating. Here's an example, using the true topologies, without and with induced polytomies (where edges without a mutation on them are removed to make a polytomy, see the trees below the plot). It appears as if making polytomies like this biases the mutation dates to younger times as the sample size increases:

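[Plot: inferred vs. true mutation times for the true topology (left column) and the topology with induced polytomies (right column), one row per sample size]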

Example first tree (original, then with induced polytomies):

[Screenshot: first tree, original and with induced polytomies]

FWIW, the pattern doesn't change much if we use the metadata-stored mutation times instead.

Code to reproduce the plot above:

```python
import itertools
import collections

import numpy as np


def remove_edges(ts, edge_id_remove_list):
    edges_to_remove_by_child = collections.defaultdict(list)
    edge_id_remove_list = set(edge_id_remove_list)
    for remove_edge in edge_id_remove_list:
        e = ts.edge(remove_edge)
        edges_to_remove_by_child[e.child].append(e)

    # sort left-to-right for each child
    for k, v in edges_to_remove_by_child.items():
        edges_to_remove_by_child[k] = sorted(v, key=lambda e: e.left)
        # check no overlaps
        for e1, e2 in zip(edges_to_remove_by_child[k], edges_to_remove_by_child[k][1:]):
            assert e1.right <= e2.left

    # Sanity check: this means the topmost node will deal with modified edges left at the end
    assert ts.edge(-1).parent not in edges_to_remove_by_child

    new_edges = collections.defaultdict(list)
    tables = ts.dump_tables()
    tables.edges.clear()
    samples = set(ts.samples())
    # Edges are sorted by parent time, youngest first, so we can iterate over
    # nodes-as-parents visiting children before parents by using itertools.groupby
    for parent_id, ts_edges in itertools.groupby(ts.edges(), lambda e: e.parent):
        # Iterate through the ts edges *plus* the polytomy edges we created in previous steps.
        # This allows us to re-edit polytomy edges when the edges_to_remove are stacked
        edges = list(ts_edges)
        if parent_id in new_edges:
            edges += new_edges.pop(parent_id)
        if parent_id in edges_to_remove_by_child:
            for e in edges:
                assert parent_id == e.parent
                l = -1
                if e.id in edge_id_remove_list:
                    continue
                # NB: we go left to right along the target edges, reducing edge e as required
                for target_edge in edges_to_remove_by_child[parent_id]:
                    # As we go along the target_edges, gradually split e into chunks.
                    # If edge e is in the target_edge region, change the edge parent
                    assert target_edge.left > l
                    l = target_edge.left
                    if e.left >= target_edge.right:
                        # This target edge is entirely to the LHS of edge e, with no overlap
                        continue
                    elif e.right <= target_edge.left:
                        # This target edge is entirely to the RHS of edge e with no overlap.
                        # Since target edges are sorted by left coord, all other target edges
                        # are to RHS too, and we are finished dealing with edge e
                        tables.edges.append(e)
                        e = None
                        break
                    else:
                        # Edge e must overlap with current target edge somehow
                        if e.left < target_edge.left:
                            # Edge had region to LHS of target
                            # Add the left hand section (change the edge right coord)
                            tables.edges.add_row(left=e.left, right=target_edge.left, parent=e.parent, child=e.child)
                            e = e.replace(left=target_edge.left)
                        if e.right > target_edge.right:
                            # Edge continues after RHS of target
                            assert e.left < target_edge.right
                            new_edges[target_edge.parent].append(
                                e.replace(right=target_edge.right, parent=target_edge.parent)
                            )
                            e = e.replace(left=target_edge.right)
                        else:
                            # No more of edge e to RHS
                            assert e.left < e.right
                            new_edges[target_edge.parent].append(e.replace(parent=target_edge.parent))
                            e = None
                            break
                if e is not None:
                    # Need to add any remaining regions of edge back in
                    tables.edges.append(e)
        else:
            # NB: sanity check at top means that the oldest node will have no edges above,
            # so the last iteration should hit this branch
            for e in edges:
                if e.id not in edge_id_remove_list:
                    tables.edges.append(e)
    assert len(new_edges) == 0
    tables.sort()
    return tables.tree_sequence()


def unsupported_edges(ts, per_interval=False):
    """
    Return the internal edges that are unsupported by a mutation.
    If ``per_interval`` is True, each interval needs to be supported,
    otherwise, a mutation on an edge (even if there are multiple intervals
    per edge) will result in all intervals on that edge being treated
    as supported.
    """
    edges_to_remove = np.ones(ts.num_edges, dtype="bool")
    edges_to_remove[[m.edge for m in ts.mutations()]] = False
    # We don't remove edges above samples
    edges_to_remove[np.isin(ts.edges_child, ts.samples())] = False

    if per_interval:
        return np.where(edges_to_remove)[0]
    else:
        keep = (edges_to_remove == False)
        for p, c in zip(ts.edges_parent[keep], ts.edges_child[keep]):
            edges_to_remove[np.logical_and(ts.edges_parent == p, ts.edges_child == c)] = False
        return np.where(edges_to_remove)[0]


###########
from matplotlib import pyplot as plt
import stdpopsim
import tsdate

print(f"Using tsdate {tsdate.__version__}")

species = stdpopsim.get_species("HomSap")
model = species.get_demographic_model("AmericanAdmixture_4B11")
contig = species.get_contig("chr20", mutation_rate=model.mutation_rate, length_multiplier=0.1)
engine = stdpopsim.get_engine("msprime")

sizes = (1, 10, 100, 1000, 10000)
fig, axes = plt.subplots(len(sizes), 2, figsize=(10, 5*len(sizes)))
axes[0][0].set_title("True topology")
axes[0][1].set_title("Topology with induced polytomies")
axes[-1][0].set_xlabel("True mutation times")
axes[-1][1].set_xlabel("True mutation times")

for ax, s in zip(axes, sizes):
    samples = {'AFR': s, 'EUR': s, 'ASIA': s, 'ADMIX': s}
    ts = engine.simulate(model, contig, samples, seed=123)
    print(ts.num_trees, "trees,", ts.num_sites, "sites,", ts.num_edges, "edges")
    poly_ts = remove_edges(ts, unsupported_edges(ts))
    print(poly_ts.num_edges, "edges in unresolved ts")  # Check it is doing the right thing
    dated_ts = tsdate.variational_gamma(ts.simplify(), mutation_rate=model.mutation_rate, normalisation_intervals=100)
    dated_poly_ts = tsdate.variational_gamma(poly_ts.simplify(), mutation_rate=model.mutation_rate, normalisation_intervals=100)

    x = [max(m.time for m in s.mutations) for s in ts.sites()]
    y = [max(m.time for m in s.mutations) for s in dated_ts.sites()]
    y_poly = [max(m.time for m in s.mutations) for s in dated_poly_ts.sites()]
    ax[0].set_ylabel(f"Inferred times ({ts.num_samples} samples)")
    ax[0].hexbin(x, y, bins="log", xscale="log", yscale="log")
    ax[0].plot(np.logspace(1, 5), np.logspace(1, 5), "-", c="red")
    ax[1].hexbin(x, y_poly, bins="log", xscale="log", yscale="log")
    ax[1].plot(np.logspace(1, 5), np.logspace(1, 5), "-", c="red")
```

resulting in the plot above, and outputting:

```
Using tsdate 0.1.dev885+g36d81f4
7367 trees, 13103 sites, 22725 edges
30148 edges in unresolved ts
19649 trees, 31256 sites, 71580 edges
131054 edges in unresolved ts
48806 trees, 72224 sites, 190253 edges
417203 edges in unresolved ts
107988 trees, 154211 sites, 444165 edges
1222924 edges in unresolved ts
201490 trees, 284048 sites, 971433 edges
```

jeromekelleher commented 2 months ago

Nice!

nspope commented 2 months ago

This is great @hyanwong, thanks. I can think of a few things to try that might reduce bias; will report back.

hyanwong commented 2 months ago

Thanks @nspope: from a few tests it appears as if the bias is less pronounced in tsinfer-inferred tree sequences. Plots below; the right-hand column is tsinfer on the same data:

[Plot: inferred vs. true mutation times; the right-hand column is tsinfer on the same data]

hyanwong commented 2 months ago

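As an aside, I wondered if reducing to the topology only present at each variable site would change the bias, but it doesn't seem to change it very much:

[Plot: inferred vs. true mutation times in four panels: true topology, topology with induced polytomies, topology reduced to variable sites and polytomies, and tsinferred]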

Code: using the code above, plus

```python
import json

import numpy as np
import scipy.stats
import tsinfer

# Warning - this takes a long time (e.g. 10 hours)
its = tsinfer.infer(tsinfer.SampleData.from_tree_sequence(ts), progress_monitor=True, num_threads=8)
dated_its = tsdate.variational_gamma(its.simplify(filter_sites=False), mutation_rate=model.mutation_rate, normalisation_intervals=100)
dated_reduced_ts = tsdate.variational_gamma(
    remove_edges(ts, unsupported_edges(ts)).simplify(reduce_to_site_topology=True),
    mutation_rate=model.mutation_rate,
    normalisation_intervals=100
)

fig, axes = plt.subplots(1, 4, figsize=(20, 5))
axes[0].set_title("True topology")
axes[1].set_title("Topology with induced polytomies")
axes[2].set_title("Topology reduced to variable_sites & polytomies")
axes[3].set_title("Tsinferred")
axes[0].set_xlabel("True mutation times")
axes[1].set_xlabel("True mutation times")
axes[2].set_xlabel("True mutation times")
axes[3].set_xlabel("True mutation times")

x = [max(m.time for m in s.mutations) for s in ts.sites()]
y = [max(m.time for m in s.mutations) for s in dated_ts.sites()]
y_poly = [max(m.time for m in s.mutations) for s in dated_poly_ts.sites()]
# Posterior mean times are read from the mutation metadata ("mn" key)
y_red = np.array([
    max((json.loads(m.metadata.decode())["mn"] for m in s.mutations), default = np.nan)
    for s in dated_reduced_ts.sites()
])
y_inf = np.array([
    max((json.loads(m.metadata.decode())["mn"] for m in s.mutations), default = np.nan)
    for s in dated_its.sites()
])

axes[0].set_ylabel(f"Inferred times ({ts.num_samples} samples)")
axes[0].hexbin(x, y, bins="log", xscale="log", yscale="log")
axes[0].plot(np.logspace(0, 5), np.logspace(0, 5), "-", c="red")
bias = np.mean(np.log(y) - np.log(x))
rho = scipy.stats.spearmanr(np.log(x), np.log(y)).statistic
axes[0].text(1e-4, 2e4, f"Rho: {rho:.5f}\nBias: {bias:.5f}")

axes[1].hexbin(x, y_poly, bins="log", xscale="log", yscale="log")
axes[1].plot(np.logspace(0, 5), np.logspace(0, 5), "-", c="red")
bias = np.mean(np.log(y_poly) - np.log(x))
rho = scipy.stats.spearmanr(np.log(x), np.log(y_poly)).statistic
axes[1].text(1e-4, 2e4, f"Rho: {rho:.5f}\nBias: {bias:.5f}")

axes[2].hexbin(x, y_red, bins="log", xscale="log", yscale="log")
axes[2].plot(np.logspace(0, 5), np.logspace(0, 5), "-", c="red")
bias = np.mean(np.log(y_red) - np.log(x))
rho = scipy.stats.spearmanr(np.log(x), np.log(y_red)).statistic
axes[2].text(1e-4, 2e4, f"Rho: {rho:.5f}\nBias: {bias:.5f}")

axes[3].hexbin(x, y_inf, bins="log", xscale="log", yscale="log")
axes[3].plot(np.logspace(0, 5), np.logspace(0, 5), "-", c="red")
bias = np.nanmean(np.log(y_inf) - np.log(x))
rho = scipy.stats.spearmanr(np.log(x), np.log(y_inf), nan_policy="omit").statistic
axes[3].text(1e-4, 2e4, f"Rho: {rho:.5f}\nBias: {bias:.5f}")
```

nspope commented 2 months ago

Looking first at node ages ... the reason there's bias in dating nodes after introducing polytomies is that there's more mutational area than there was in the original binary trees. I.e. we're increasing the total branch length, which means that when we match moments using segregating sites we end up shrinking the timescale.

To be a bit more precise: the current normalisation strategy calculates total edge area and total number of mutations, then rescales time such that the expected number of mutations matches the total number of mutations.
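A minimal sketch of that rescaling on a toy tree, to show why a polytomy shrinks the timescale. This is an illustration only, not tsdate's implementation: `rescaling_factor` is a hypothetical helper, it uses a single global factor (ignoring the `normalisation_intervals` argument used earlier in the thread), and the spans, times and mutation rate are made-up toy values.

```python
def rescaling_factor(edges, num_mutations, mutation_rate):
    """
    Return the factor by which node times are multiplied so that the expected
    number of mutations equals the observed number.

    edges: iterable of (span, parent_time, child_time) tuples.
    """
    # Total edge area: genomic span times branch length, summed over edges
    area = sum(span * (parent_time - child_time) for span, parent_time, child_time in edges)
    expected_mutations = mutation_rate * area
    return num_mutations / expected_mutations


# Toy tree over 4 samples (time 0) with internal nodes at times 1, 2 and 3
binary_edges = [
    (4.0, 1.0, 0.0), (4.0, 1.0, 0.0),  # two samples under the node at time 1
    (4.0, 2.0, 1.0), (4.0, 2.0, 0.0),  # node at time 1 and a sample under the node at time 2
    (4.0, 3.0, 2.0), (4.0, 3.0, 0.0),  # node at time 2 and a sample under the root
]
# Same tree after collapsing the node at time 1 into its parent (a trifurcation)
polytomy_edges = [
    (4.0, 2.0, 0.0), (4.0, 2.0, 0.0), (4.0, 2.0, 0.0),
    (4.0, 3.0, 2.0), (4.0, 3.0, 0.0),
]

# Same number of mutations in both cases; only the topology differs
binary_factor = rescaling_factor(binary_edges, num_mutations=9, mutation_rate=0.25)
polytomy_factor = rescaling_factor(polytomy_edges, num_mutations=9, mutation_rate=0.25)
print(binary_factor, polytomy_factor)
```

With these toy numbers the factor is 1.0 for the binary tree and 0.9 for the tree with the polytomy: the same mutations are spread over more edge area, so times are scaled down.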

Instead, consider doing the following: for each tree, sample a path from a randomly selected leaf to the root. Only accumulate edge area and mutations on the sampled paths. This should be unbiased, because the "path length" is the same regardless of the presence of polytomies. In fact, we can do this sampling deterministically, because the probability that a randomly selected path passes through a given edge is proportional to the number of samples subtended by that edge. I.e. we normalise as before but weight edges by the number of samples they subtend.
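The invariance claim can be checked on a toy tree. The sketch below is an assumption-laden illustration in plain Python (a single tree, unit span, hypothetical helper names), not the numba routine in tsdate: it weights each edge by the fraction of samples beneath it and compares the result before and after collapsing one internal node.

```python
def count_samples_below(edges, samples):
    """Map each node to the number of samples at or below it."""
    children = {}
    for parent, child in edges:
        children.setdefault(parent, []).append(child)

    def count(node):
        own = 1 if node in samples else 0
        return own + sum(count(c) for c in children.get(node, []))

    nodes = {n for edge in edges for n in edge}
    return {n: count(n) for n in nodes}


def edge_area(edges, time, samples, weight_by_samples):
    """Sum branch lengths, optionally weighting each edge by the
    fraction of samples it subtends (the "path" weighting)."""
    below = count_samples_below(edges, samples)
    total = 0.0
    for parent, child in edges:
        length = time[parent] - time[child]
        weight = below[child] / len(samples) if weight_by_samples else 1.0
        total += weight * length
    return total


samples = {0, 1, 2, 3}
time = {0: 0.0, 1: 0.0, 2: 0.0, 3: 0.0, 4: 1.0, 5: 2.0, 6: 3.0}
# Binary tree as (parent, child) pairs: ((0, 1), 2), 3
binary = [(4, 0), (4, 1), (5, 4), (5, 2), (6, 5), (6, 3)]
# Node 4 collapsed into node 5, giving a trifurcation
polytomy = [(5, 0), (5, 1), (5, 2), (6, 5), (6, 3)]

total_binary = edge_area(binary, time, samples, weight_by_samples=False)
total_polytomy = edge_area(polytomy, time, samples, weight_by_samples=False)
path_binary = edge_area(binary, time, samples, weight_by_samples=True)
path_polytomy = edge_area(polytomy, time, samples, weight_by_samples=True)
print(total_binary, total_polytomy, path_binary, path_polytomy)
```

Here the unweighted area grows from 9.0 to 10.0 when the polytomy is introduced, while the sample-weighted area stays at 3.0, the root height, in both cases.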

Using this alternative "path normalisation" strategy seems to greatly help with bias (1000 samples, 10 Mb):

[Plot: norm_poly-nodes]

This more-or-less carries over for mutations:

[Plot: norm_poly-muts]

hyanwong commented 2 months ago

Oh wow. This is amazing. Thanks Nate.

Does it cause any overcorrection problems for tsinferred tree sequences? I assume it shouldn't...

nspope commented 2 months ago

Another way to phrase this is that we're moment matching against a different summary statistic (rather than segregating sites), namely the expected number of differences between a single sample and the root. In my opinion this choice of summary statistic is a more conceptually straightforward way to measure time with mutational density.
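One way to write the two moment-matching conditions side by side. The notation is assumed here for illustration and is not taken from the tsdate source: $\mu$ is the mutation rate, $s_e$ and $\ell_e$ are the span and unscaled branch length of edge $e$, $m_e$ is the number of mutations on it, $n_e$ is the number of samples it subtends, $n$ is the total number of samples, and $c$ is a single global rescaling factor (the actual routine may rescale piecewise in time).

```latex
% Matching segregating sites: every edge counts equally
c_{\mathrm{seg}} \, \mu \sum_e s_e \, \ell_e = \sum_e m_e

% Matching sample-to-root differences: edge e is weighted by n_e / n,
% the probability that a randomly chosen leaf-to-root path passes through it
c_{\mathrm{path}} \, \mu \sum_e \frac{n_e}{n} \, s_e \, \ell_e = \sum_e \frac{n_e}{n} \, m_e
```

Under the second condition, the left-hand sum depends on the topology only through root heights, which is why collapsing nodes into polytomies leaves it unchanged.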

nspope commented 2 months ago

I did a quick check on inferred simulated tree sequences -- the original routine was more or less unbiased (as Yan observed above) and the new routine does about the same. Would be interesting to compare the two on real data. Regardless, this new routine seems like the right approach.

hyanwong commented 2 months ago

> the new routine does about the same

That's great.

> Regardless, this new routine seems like the right approach.

Absolutely. We should go with the new approach. I wonder how both approaches perform on reinference? I can check this once there are instructions for how to run the new version.

nspope commented 2 months ago

The API is exactly the same, with the new normalisation scheme used by default. The old normalisation scheme can be toggled if you pass `match_segregating_sites=True` to `date`.

hyanwong commented 2 months ago

Great, thanks for the info. Is it currently much slower than the old version? It seems maybe not?

nspope commented 2 months ago

It shouldn't be, but it would be good to check (if you enable logging it'll print out the time spent during normalisation). There's an additional pass over edges, but this is done in numba. It might add a few minutes on GEL- or UKBB-sized data, so it would be good to enable logging there to get a sense of the overhead.

nspope commented 2 months ago

> I wonder how both approaches perform on reinference

Actually, I don't think it'll change reinference at all -- ancestor building just uses the ordering of mutations, right? Normalisation won't change the order, just the inter-node time differences.
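A small illustration of that last point, under the assumption that normalisation acts as a strictly increasing map on node times. The piecewise-linear form, the breakpoints and the slopes below are made up for the example and are not tsdate's actual rescaling.

```python
def piecewise_rescale(t, breaks, slopes):
    """
    Map time t through a continuous piecewise-linear function that starts at 0
    and has slope slopes[i] > 0 on [breaks[i], breaks[i + 1]); the last slope
    applies to everything beyond the final breakpoint.
    """
    new_t = 0.0
    for i, slope in enumerate(slopes):
        lo = breaks[i]
        hi = breaks[i + 1] if i + 1 < len(breaks) else float("inf")
        if t <= lo:
            break
        new_t += slope * (min(t, hi) - lo)
    return new_t


breaks = [0.0, 10.0, 100.0]  # hypothetical interval boundaries
slopes = [0.5, 2.0, 0.25]    # hypothetical per-interval rescaling, all positive
node_times = [0.0, 3.0, 12.0, 50.0, 400.0]
rescaled = [piecewise_rescale(t, breaks, slopes) for t in node_times]

# Ranking of nodes by time before and after rescaling
order_before = sorted(range(len(node_times)), key=lambda i: node_times[i])
order_after = sorted(range(len(rescaled)), key=lambda i: rescaled[i])
print(rescaled, order_before == order_after)
```

The gaps between node times change, but because every slope is positive the ordering of the nodes is identical before and after.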