petrelharp / num_edges

Metrics to compare effectiveness of edge extending #1

Open hyanwong opened 2 years ago

hyanwong commented 2 years ago

As discussed, it would be really useful to be able to compare (say) a ground-truth tree sequence, with unary-when-coalescent nodes left in, against the same data either inferred using tsinfer or with edges extended using edge_extend(). We can throw around some ideas in this issue.

hyanwong commented 2 years ago

If we approach this by comparing the ancestral haplotypes that have been extended or inferred, I think there are 2 issues.

  1. How to identify corresponding ancestors in 2 tree sequences (say A and B), given that the tree sequences may be topologically different
  2. Once we have 2 putatively comparable ancestors, what is the metric we use to compare them (presumably something to do with the left and right extent of each edge)?

I don't have immediate suggestions for 2., but I do have a few ideas for 1.

To identify comparable ancestors, I don't think we can use times, so we are limited to using the topology. Therefore I think we need to identify them on the basis of the points in the genome at which they are coalescent nodes, and then look right and left from the coalescent regions at parts of the haplotype that are unary.

The comprehensive method would be to take a tree at a specific point in the genome, take a pair of samples (say 2 and 3), and assume that the MRCA of 2 and 3 is the same ancestor in both tree sequences A and B. The problem with this is that you would want to do this for every pair of samples in the tree. If the topologies differ, there would not be a 1-to-1 mapping between ancestors at that point in the genome.
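
To make the 1-to-1 problem concrete, here is a toy sketch in plain Python (no tskit needed). The two trees, the node IDs, and the helper names (`ancestors_of`, `mrca`, `pairwise_mrca_matches`) are all made up for illustration; with real tree sequences one would presumably call `tree.mrca(u, v)` on each tskit tree instead.

```python
from itertools import combinations


def ancestors_of(parent, u):
    """Return the path from u up to the root, using a child -> parent dict."""
    path = [u]
    while u in parent:
        u = parent[u]
        path.append(u)
    return path


def mrca(parent, u, v):
    """Most recent common ancestor of u and v in a single tree (None if no shared root)."""
    seen = set(ancestors_of(parent, u))
    for node in ancestors_of(parent, v):
        if node in seen:
            return node
    return None


def pairwise_mrca_matches(parent_a, parent_b, samples):
    """Map each pairwise-MRCA node in tree A to the set of tree-B nodes it gets paired with."""
    matches = {}
    for u, v in combinations(samples, 2):
        node_a = mrca(parent_a, u, v)
        node_b = mrca(parent_b, u, v)
        matches.setdefault(node_a, set()).add(node_b)
    return matches


# Hypothetical toy trees on samples 0..3, stored as child -> parent dicts.
# Tree A: ((0,1)4,(2,3)5)6      Tree B: (((0,1)4,2)5,3)6
tree_a = {0: 4, 1: 4, 2: 5, 3: 5, 4: 6, 5: 6}
tree_b = {0: 4, 1: 4, 2: 5, 4: 5, 5: 6, 3: 6}
matches = pairwise_mrca_matches(tree_a, tree_b, samples=[0, 1, 2, 3])
```

In this toy case node 6 in tree A ends up paired with both node 5 and node 6 in tree B, which is exactly the ambiguity described above.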

A simple way around this would be to look at each (biallelic) site in turn, and identify a single ancestor for each site: the MRCA of all the samples with the derived mutation. Under infinite sites we could assume that this ancestor was the same between tree sequences A and B. If we are allowing multiple mutations in tsinfer, instead of the ancestor being the MRCA of all the samples with the derived mutation, we could simply e.g. take the oldest mutation at that site and assume that the node below it is the ancestor-to-be-compared.
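
As a rough illustration of the "MRCA of all samples with the derived mutation" idea, here is a plain-Python sketch on a single toy tree. The child -> parent dict, the genotype dict, and the function names are hypothetical stand-ins; they are not tskit or tsinfer API.

```python
def path_to_root(parent, u):
    """Return [u, parent(u), ..., root] using a child -> parent dict."""
    path = [u]
    while u in parent:
        u = parent[u]
        path.append(u)
    return path


def mrca_of_carriers(parent, genotypes):
    """
    MRCA of all samples carrying the derived allele (genotype 1).
    `genotypes` maps sample ID -> 0 or 1. Returns None if there are no
    carriers or the carriers share no ancestor in this tree.
    """
    carriers = [u for u, g in genotypes.items() if g == 1]
    if not carriers:
        return None
    # Start from one carrier's lineage (ordered youngest -> root) and keep
    # only the ancestors shared with every other carrier.
    common = path_to_root(parent, carriers[0])
    for u in carriers[1:]:
        lineage = set(path_to_root(parent, u))
        common = [node for node in common if node in lineage]
    return common[0] if common else None


# Hypothetical toy trees: A is ((0,1)4,(2,3)5)6 and B is (((0,1)4,2)5,3)6
tree_a = {0: 4, 1: 4, 2: 5, 3: 5, 4: 6, 5: 6}
tree_b = {0: 4, 1: 4, 2: 5, 4: 5, 5: 6, 3: 6}
site_genotypes = {0: 1, 1: 1, 2: 1, 3: 0}  # samples 0, 1 and 2 carry the derived allele
node_a = mrca_of_carriers(tree_a, site_genotypes)
node_b = mrca_of_carriers(tree_b, site_genotypes)
```

Each site yields exactly one node per tree, so the pair (`node_a`, `node_b`) is unambiguous even though the two topologies differ. Note that for a singleton the "MRCA" is just the sample itself.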

I'll work up some example code below.

hyanwong commented 2 years ago

Example of finding one pair of comparable nodes at each site

import operator

import numpy as np
import msprime
import tskit
import tsinfer

def simplify_keeping_unary_in_coal(ts, map_nodes=False):
    """
    Keep the unary regions of nodes that are coalescent at least somewhere in the tree seq
    Temporary hack until https://github.com/tskit-dev/tskit/issues/2127 is addressed
    """
    tables = ts.dump_tables()
    # remove existing individuals. We will reinstate them later
    tables.individuals.clear()
    tables.nodes.individual = np.full_like(tables.nodes.individual, tskit.NULL)

    _, node_map = ts.simplify(map_nodes=True)
    keep_nodes = np.where(node_map != tskit.NULL)[0]
    # Add an individual for each coalescent node, so we can run
    # simplify(keep_unary_in_individuals=True) to leave the unary portions in.
    for u in keep_nodes:
        i = tables.individuals.add_row()
        tables.nodes[u] = tables.nodes[u].replace(individual=i)
    node_map = tables.simplify(keep_unary_in_individuals=True)

    # Reinstate individuals
    tables.individuals.clear()
    for i in ts.individuals():
        tables.individuals.append(i)
    val, inverted_map = np.unique(node_map, return_index=True)
    inverted_map = inverted_map[val != tskit.NULL]
    tables.nodes.individual = ts.tables.nodes.individual[inverted_map]
    if map_nodes:
        return tables.tree_sequence(), node_map
    else:
        return tables.tree_sequence()

num_samples = 10
ts = msprime.sim_ancestry(
    num_samples,
    ploidy=1,
    sequence_length=5e8,
    recombination_rate=1e-8,
    record_full_arg=True,
    random_seed=123
)

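def oldest_mutation_node(site, ts_in):
    # NB: despite the name, this returns the oldest *mutation* at the site
    # (ranked by the time of its node); callers take `.node` from the result
    return max(site.mutations, key=lambda m: ts_in.node(m.node).time)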

ts = simplify_keeping_unary_in_coal(ts)
tsA = msprime.sim_mutations(ts, rate=1e-8, random_seed=123)
tsB = tsinfer.infer(tsinfer.SampleData.from_tree_sequence(tsA))

for variantA, variantB in zip(tsA.variants(), tsB.variants()):
    assert variantA.site.position == variantB.site.position
    allelesA = set(variantA.alleles)
    allelesB = set(variantB.alleles)
    assert allelesA == allelesB
    if (
        len(variantA.site.mutations) == 1 and
        np.sum(variantA.genotypes == 1) > 1  # ignore singletons
    ):
        oldest_mutation_A = oldest_mutation_node(variantA.site, tsA)
        oldest_mutation_B = oldest_mutation_node(variantB.site, tsB)
        node_from_A = oldest_mutation_A.node
        node_from_B = oldest_mutation_B.node
        # do something to compare the edge extent of node_from_A with node_from_B
        # Note that we might have multiple counting of the same node, however
        print(
            "comparable_ancestors: node",
            node_from_A,
            "from tsA, against node",
            node_from_B,
            "from tsB."
        )
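
For the "do something to compare the edge extent" placeholder in the loop above, one possible starting point (just a sketch, not a settled metric) is to treat each node's span as a set of genomic intervals and measure their overlap with a Jaccard index. The function names here are hypothetical, and the choice of Jaccard is an assumption. With tskit, the intervals for a node `u` could perhaps be gathered from the edges in which `u` appears, e.g. `[(e.left, e.right) for e in ts.edges() if e.child == u]`.

```python
def merge_intervals(intervals):
    """Merge overlapping or abutting (left, right) intervals into a sorted list."""
    merged = []
    for left, right in sorted(intervals):
        if merged and left <= merged[-1][1]:
            merged[-1] = (merged[-1][0], max(merged[-1][1], right))
        else:
            merged.append((left, right))
    return merged


def total_length(intervals):
    """Summed length of a list of non-overlapping intervals."""
    return sum(right - left for left, right in intervals)


def intersect_length(a, b):
    """Total length shared by two merged, sorted interval lists."""
    i = j = 0
    shared = 0.0
    while i < len(a) and j < len(b):
        left = max(a[i][0], b[j][0])
        right = min(a[i][1], b[j][1])
        if right > left:
            shared += right - left
        # Advance whichever interval finishes first
        if a[i][1] < b[j][1]:
            i += 1
        else:
            j += 1
    return shared


def span_jaccard(intervals_a, intervals_b):
    """Jaccard index of two genomic spans: 1.0 = identical extent, 0.0 = disjoint."""
    a = merge_intervals(intervals_a)
    b = merge_intervals(intervals_b)
    shared = intersect_length(a, b)
    union = total_length(a) + total_length(b) - shared
    return shared / union if union > 0 else 1.0


# Hypothetical spans: node_from_A covers [0, 20) in two edges, node_from_B covers [5, 25)
score = span_jaccard([(0, 10), (10, 20)], [(5, 25)])
```

A score like this ignores which parents or children the node has, so it only captures how far left and right each ancestor extends.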

hyanwong commented 2 years ago

Hmm, there's an issue here because in the code above I'm identifying the nodes-to-be-compared using the node below the mutation. But that could be (and often is) a unary node. I think we might want to compare the coalescent nodes below each mutation instead.

Here's some code that does that:

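def alleles(site):
    # Set of allelic states at a site: the ancestral state plus every derived state
    return {site.ancestral_state} | {m.derived_state for m in site.mutations}


def compare_ancestors(ts_orig, ts_cmp):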
    assert ts_orig.num_sites == ts_cmp.num_sites
    # Simplify so that the nodes for comparison are never unary
    ts_orig, orig_node_map = ts_orig.simplify(map_nodes=True)
    ts_cmp, cmp_node_map = ts_cmp.simplify(map_nodes=True)
    # we need to map new IDs -> old ones, so reverse the map
    orig_node_map = {j: i for i, j in enumerate(orig_node_map) if j >= 0}
    cmp_node_map = {j: i for i, j in enumerate(cmp_node_map) if j >= 0}

    comparable_nodes = []
    site_id = 0
    for interval, t_orig, t_cmp in ts_orig.coiterate(ts_cmp):
        while ts_orig.site(site_id).position < interval.right:
            site_orig = ts_orig.site(site_id)
            site_cmp = ts_cmp.site(site_id)
            assert site_orig.position == site_cmp.position
            assert alleles(site_orig) == alleles(site_cmp)
            # Only compare cases where there is one mutation which is not a singleton
            if len(site_orig.mutations) == 1 and t_orig.num_samples(site_orig.mutations[0].node) > 1:
                oldest_mutation_orig = oldest_mutation_node(site_orig, ts_orig)
                oldest_mutation_cmp = oldest_mutation_node(site_cmp, ts_cmp)
                node_from_orig = oldest_mutation_orig.node
                node_from_cmp = oldest_mutation_cmp.node
                assert t_orig.num_children(node_from_orig) > 1
                assert t_cmp.num_children(node_from_cmp) > 1
                comparable_nodes.append([
                    orig_node_map[node_from_orig],
                    cmp_node_map[node_from_cmp],
                    site_orig.position,
                ])
            site_id += 1
            if site_id >= ts_orig.num_sites:
                return comparable_nodes
    raise ValueError("did not inspect all sites")
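
Since the same pair of nodes can turn up at many sites (the "multiple counting" noted earlier), it may help to collapse the output of `compare_ancestors` before computing any metric. A minimal sketch, assuming the `[orig_node, cmp_node, position]` rows returned above; the helper name and the example rows are invented for illustration:

```python
from collections import defaultdict


def group_sites_by_node_pair(comparable_nodes):
    """Collapse [orig_node, cmp_node, position] rows into {(orig_node, cmp_node): [positions]}."""
    groups = defaultdict(list)
    for node_orig, node_cmp, position in comparable_nodes:
        groups[(node_orig, node_cmp)].append(position)
    return dict(groups)


# Hypothetical rows in the format produced by compare_ancestors()
rows = [[5, 7, 10.0], [5, 7, 42.0], [6, 9, 50.0]]
pairs = group_sites_by_node_pair(rows)
```

Each node pair then gets compared once, and the list of positions records how many sites support the match. A node in one tree sequence that shows up in several different pairs would be a sign that the site-based matching is not 1-to-1 for that ancestor.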