tskit-dev / tskit

Population-scale genomics
MIT License

Method to position mutations to explain genotypes #99

Closed: jeromekelleher closed 5 years ago

jeromekelleher commented 5 years ago

We want to make tskit usable for a wide variety of applications. One potentially large group of users is the viral and bacterial genomics community. In these applications there will often be a good estimate of a tree for the samples. We want to be able to represent the data by combining such a tree (converted to tree sequence form using tsconvert.from_newick) with the observed sequences.

The key algorithm that we need is a method to place mutations on the tree so that they explain the observed genotypes at a given site. Luckily, this is a well-studied question in classical phylogenetics. Here is a rough prototype of the Sankoff algorithm:

import msprime
import tskit
import numpy as np

class Site(object):
    def __init__(self):
        self.ancestral_state = None
        self.mutations = []

class Mutation(object):
    def __init__(self, node=None, derived_state=None):
        self.node = node
        self.derived_state = derived_state

def sankoff(tree, genotypes, alleles, distance_matrix=None):
    """
    Returns a Site object explaining the specified set of genotypes
    on the specified tree.

    Based on treatment from Clemente et al., https://doi.org/10.1186/1471-2105-10-51
    """
    num_alleles = np.max(genotypes) + 1
    if distance_matrix is None:
        distance_matrix = np.ones((num_alleles, num_alleles))
        np.fill_diagonal(distance_matrix, 0)
    S = np.zeros((num_alleles, tree.num_nodes))
    infinity = 1e7  # Arbitrary big value
    # Initialise the weights
    for allele, u in zip(genotypes, tree.tree_sequence.samples()):
        S[:, u] = infinity
        S[allele, u] = 0
    for p in tree.nodes(order="postorder"):
        for i in range(num_alleles):
            for child in tree.children(p):
                min_w = infinity
                for j in range(num_alleles):
                    min_w = min(min_w, distance_matrix[i, j] + S[j, child])
                S[i, p] += min_w

    site = Site()
    S_anc = [None for _ in range(tree.num_nodes)]
    for x in tree.nodes(order="preorder"):
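        # Default to this node's own cheapest state (used as-is for the root).
        S_anc[x] = np.argmin(S[:, x])
        # Transition costs must be measured from the state chosen for the parent,
        # which preorder traversal guarantees has already been assigned.
        i = S_anc[x] if x == tree.root else S_anc[tree.parent(x)]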
        min_cost = infinity
        for j in range(num_alleles):
            if x == tree.root:
                trans_cost = S[j, x]
            else:
                trans_cost = distance_matrix[i, j] + S[j, x]
            if trans_cost < min_cost:
                min_cost = trans_cost
                S_anc[x] = j
        if x == tree.root:
            site.ancestral_state = alleles[S_anc[x]]
        elif S_anc[x] != S_anc[tree.parent(x)]:
            # TODO track the mutation parent. Should be straightforward enough.
            site.mutations.append(Mutation(node=x, derived_state=alleles[S_anc[x]]))
    return site

# Try it out.
ts = msprime.simulate(10, mutation_rate=1, random_seed=1)
tree = ts.first()

print("num_sites = ", ts.num_sites)
tables = ts.dump_tables()
tables.sites.clear()
tables.mutations.clear()

# print(tree.draw(format="unicode"))
for variant in ts.variants():
    site = sankoff(tree, variant.genotypes, variant.alleles)
    site_id = tables.sites.add_row(variant.site.position, ancestral_state=site.ancestral_state)
    for mutation in site.mutations:
        tables.mutations.add_row(
            site=site_id, node=mutation.node, derived_state=mutation.derived_state)

ts2 = tables.tree_sequence()
assert np.array_equal(ts.genotype_matrix(), ts2.genotype_matrix())
# For simple infinite-sites mutations we should recover the original tables exactly.
assert tables.sites == ts.tables.sites
assert tables.mutations == ts.tables.mutations
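
To make the two passes easier to check by hand, here is a minimal sketch of the same idea with no tskit or numpy dependency. Everything in it is an assumption for illustration: `sankoff_small`, the `{parent: [children]}` tree encoding and the node numbering are hypothetical and not part of any tskit API. The bottom-up pass computes, for each node and allele, the cheapest cost of the subtree below; the top-down pass then picks each node's state given the state already chosen for its parent. The small tree is chosen so that node 4 has a tie between the two alleles, which is the case where conditioning on the parent matters.

```python
INF = float("inf")


def sankoff_small(children, root, leaf_state, cost):
    """
    Hypothetical helper (not a tskit function): Sankoff parsimony on a tree
    given as a {parent: [children]} dict. leaf_state maps each leaf to an
    allele index and cost[i][j] is the cost of changing allele i into j.
    Returns (root_state, mutations, total_cost), where mutations is a list
    of (node, derived_allele_index) pairs.
    """
    num_alleles = len(cost)
    S = {}

    def fill(u):
        # Bottom-up pass: S[u][i] is the cheapest cost of the subtree below u
        # when u is assigned allele i.
        kids = children.get(u, [])
        if not kids:
            S[u] = [0 if a == leaf_state[u] else INF for a in range(num_alleles)]
            return
        for c in kids:
            fill(c)
        S[u] = [
            sum(min(cost[i][j] + S[c][j] for j in range(num_alleles)) for c in kids)
            for i in range(num_alleles)
        ]

    fill(root)

    # Top-down pass: the root takes its cheapest state; every other node takes
    # the state that is cheapest given the state chosen for its parent.
    state = {root: min(range(num_alleles), key=lambda j: S[root][j])}
    mutations = []
    stack = [root]
    while stack:
        u = stack.pop()
        i = state[u]
        for c in children.get(u, []):
            state[c] = min(range(num_alleles), key=lambda j: cost[i][j] + S[c][j])
            if state[c] != i:
                mutations.append((c, state[c]))
            stack.append(c)
    return state[root], mutations, min(S[root])


# Leaves 0..3 carry alleles 0, 1, 1, 1; node 4 joins leaves 0 and 1,
# node 5 joins leaves 2 and 3, and node 6 is the root.
children = {6: [4, 5], 4: [0, 1], 5: [2, 3]}
leaf_state = {0: 0, 1: 1, 2: 1, 3: 1}
unit_cost = [[0, 1], [1, 0]]
root_state, mutations, total_cost = sankoff_small(children, 6, leaf_state, unit_cost)
print(root_state, mutations, total_cost)
```

In this example node 4 costs the same under either allele, so picking its state without looking at the root would be arbitrary. Conditioning on the root's state keeps node 4 at allele 1 and leaves a single mutation, on the branch above leaf 0.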

Some questions:

  1. Does this belong in tskit, or should it go somewhere else?
  2. Assuming it does belong in tskit, what should we call it? (Presumably a method of the Tree class?)

Any thoughts much appreciated!

petrelharp commented 5 years ago

This is awesome and extremely useful.