tulerpetontidae / psmc-python

A reimplementation of the classical PSMC method in Python for educational purposes

Inputting from a tree sequence or vcf-zarr file #2

Closed: hyanwong closed this issue 1 month ago

hyanwong commented 1 month ago

I love that you have written a PSMC reimplementation in Python, thanks. It would be great for my teaching. I'm writing some material for teaching tree sequence / ARG-related concepts, and it would be helpful to also be able to read in data directly from an msprime simulation (msprime is a faster reimplementation of ms), i.e. in tree sequence format, or even from a VCF/Zarr file. This would also allow easy testing of PSMC against stdpopsim, the standard library of population genetic simulation models (https://popsim-consortium.github.io/stdpopsim-docs/stable/introduction.html), which outputs tree sequences by default whether it simulates with SLiM or msprime.

If I make a PR to allow the following to work, would you consider it for incorporation?

```python
import msprime
from psmc.utils import process_ts

# simulate a single 200Mb chromosome from one diploid individual
ts = msprime.sim_ancestry(1, sequence_length=200e6, population_size=1e4, recombination_rate=1e-8)
ts = msprime.sim_mutations(ts, rate=1e-8)

xs = process_ts(ts)

# ... as before
```
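Ignoring missing data for a moment, the core of such a conversion is simple: find the positions where the individual's two genomes differ and flag every 100 bp window that contains at least one of them. A minimal numpy sketch, assuming the heterozygous positions have already been extracted from the tree sequence (the helper name `het_windows` is hypothetical and is not part of psmc-python or tskit):

```python
import numpy as np


def het_windows(het_positions, sequence_length, window_size=100):
    """
    Hypothetical helper: return an int8 array with one entry per complete window,
    1 if the window contains at least one heterozygous position, else 0.
    Missing data is ignored and a trailing partial window is dropped.
    """
    num_windows = int(sequence_length) // window_size
    pos = np.asarray(het_positions, dtype=np.int64)
    # Keep only positions that fall inside a complete window
    pos = pos[(pos >= 0) & (pos < num_windows * window_size)]
    # Count heterozygous positions per window, then reduce to presence/absence
    counts = np.bincount(pos // window_size, minlength=num_windows)
    return (counts > 0).astype(np.int8)


print(het_windows([5, 150, 180, 420], sequence_length=500))  # [1 1 0 0 1]
```

Handling missing data is what makes the real conversion harder, because a window with no heterozygotes might simply be a window with no data.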
hyanwong commented 1 month ago

Below is some outline code for a `process_ts` function. It's a bit complicated because I wanted to handle missing data. You should be able to simulate a tree sequence (and plot the demography) with, e.g.:

```python
import demes
import demesdraw
import msprime

# `cmds` is a list of ms-style command strings (defined elsewhere, not shown here)
n0 = 1e4
L = 90e6
demes_obj = demes.from_ms(cmds[0], N0=n0)
demography = msprime.Demography.from_demes(demes_obj)
ts = msprime.sim_ancestry(1, demography=demography, sequence_length=L, recombination_rate=0.0002 / (4 * n0))
ts = msprime.sim_mutations(ts, rate=0.001 / (4 * n0))

# Plot it!
demesdraw.size_history(demes_obj, log_time=True)

# get an `xs` object for input into psmc-python
xs = process_ts(ts)
```
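The rates passed to msprime above look like ms-style population-scaled values divided by 4N0: a scaled mutation rate of 0.001 and a scaled recombination rate of 0.0002 per bp (this reading is inferred from the code, not stated in the thread). A quick check of the per-generation, per-bp rates they imply:

```python
n0 = 1e4          # reference population size N0
theta0 = 0.001    # scaled mutation rate per bp, 4 * N0 * mu
rho0 = 0.0002     # scaled recombination rate per bp, 4 * N0 * r

mu = theta0 / (4 * n0)  # per-bp, per-generation mutation rate
r = rho0 / (4 * n0)     # per-bp, per-generation recombination rate

print(f"mu = {mu:.2e}, r = {r:.2e}")  # mu = 2.50e-08, r = 5.00e-09
```

So the simulation uses a mutation rate five times the recombination rate, the same ratio as the scaled values.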

Here's the function:

```python
import warnings

import numpy as np
from tqdm import tqdm


def process_ts(ts, individual=None, start=None, end=None, window_size=100, progress=False):
    """
    Turn the variation data from a specific ``individual`` in a tree sequence into a
    numpy array indicating the presence or absence of heterozygotes in windows of
    ``window_size`` bp (default 100) from ``start`` to ``end``. If ``individual`` is
    None, use the individual associated with the first sample node. If ``progress``
    is True, show a progress bar.

    Returns an int8 array of shape (1, num_windows): 0 = "T" (all homozygous),
    1 = "K" (at least one heterozygote), 2 = "N" (at least 90% missing data).
    """

    def is_connected(tree, node1, node2):
        """
        Check that neither node is isolated in the tree (i.e. neither has missing data).
        """
        return not (tree.is_isolated(node1) or tree.is_isolated(node2))

    if not ts.discrete_genome:
        raise ValueError("Tree sequence must use discrete genome coordinates")
    if individual is None:
        individual = ts.node(ts.samples()[0]).individual
        if individual < 0:
            raise ValueError("No individual associated with the first sample node")
    try:
        nodes = ts.individual(individual).nodes
    except IndexError:
        raise ValueError(f"Individual {individual} not found in tree sequence")
    # Quickest to simplify to 2 genomes (also gets rid of non-variable sites, etc.).
    # After simplification the two genomes are sample nodes 0 and 1.
    ts = ts.simplify(samples=nodes)
    if ts.num_samples != 2:
        raise ValueError(f"Individual {individual} did not have 2 genomes")
    if start is None:
        start = 0
    if end is None:
        end = int(ts.sequence_length)
    if (end - start) % window_size != 0:
        warnings.warn(
            f"The region length is not a multiple of {window_size}; "
            "the last partial window will be skipped."
        )
    num_windows = (end - start) // window_size
    result = np.empty((1, num_windows), dtype=np.int8)

    # Processing is complicated because we want to look at windows even if they contain
    # no variable sites. We check for missing data by looking at the tree at each position.
    tree_iter = ts.trees()
    tree = next(tree_iter)

    # Place the tree iterator at the start
    use_trees = True
    while tree.interval.right < start and use_trees:
        if tree.index < ts.num_trees - 1:
            tree = next(tree_iter)
        else:
            use_trees = False

    # Place the variant decoder at the first site at or after the start
    use_variants = ts.num_sites > 0
    if use_variants:
        variant = next(ts.variants(copy=False))  # get a reusable Variant object
        assert variant.site.id == 0
    while use_variants and variant.site.position < start:
        # could probably jump straight to the right site here
        if variant.site.id < ts.num_sites - 1:
            variant.decode(variant.site.id + 1)
        else:
            use_variants = False

    # Now iterate through the windows
    seq = np.zeros(window_size, dtype=np.int8)
    # num_windows + 1 boundaries give exactly num_windows complete windows
    wins = start + window_size * np.arange(num_windows + 1)
    for i, (left, right) in tqdm(
        enumerate(zip(wins[:-1], wins[1:])),
        total=num_windows,
        desc=f"Calc {window_size}bp windows",
        disable=not progress,
    ):
        # Per-position codes: 0=missing, 1=homozygous, 2=heterozygous
        if not use_trees:
            seq[:] = 0
        else:
            while tree.interval.right < right:
                tree_left = int(tree.interval.left)
                tree_right = int(tree.interval.right)
                # Fill the part of this window covered by the current tree
                seq[max(tree_left - left, 0): tree_right - left] = is_connected(tree, 0, 1)
                if tree.index == ts.num_trees - 1:
                    # Past the end of the last tree: mark the rest of the window missing
                    use_trees = False
                    seq[tree_right - left: window_size] = 0
                    break
                tree = next(tree_iter)
            if use_trees:
                l_pos = max(int(tree.interval.left) - left, 0)
                seq[l_pos:window_size] = is_connected(tree, 0, 1)
        # Overwrite the positions of any sites in this window with their genotype status
        while use_variants and variant.site.position < right:
            pos = int(variant.site.position) - left
            if variant.has_missing_data:
                seq[pos] = 0
            elif variant.genotypes[0] == variant.genotypes[1]:
                seq[pos] = 1
            else:
                seq[pos] = 2  # heterozygous
            if variant.site.id == ts.num_sites - 1:
                use_variants = False
            else:
                variant.decode(variant.site.id + 1)
        # Summarise the window with PSMC's T/K/N codes
        if np.count_nonzero(seq == 0) >= int(window_size * 0.9):
            result[0, i] = 2  # "N" = mostly missing data
        elif np.count_nonzero(seq == 2) > 0:
            result[0, i] = 1  # "K" = at least one heterozygote
        else:
            result[0, i] = 0  # "T" = all homozygous
    return result
```
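The per-window summary at the end of `process_ts` (mostly missing gives "N", any heterozygote gives "K", otherwise "T") can also be written without a Python loop once the per-position codes for the whole region sit in one array. A minimal numpy sketch, assuming such an array of 0 (missing), 1 (homozygous) and 2 (heterozygous) codes is available; `summarise_windows` is a hypothetical name, and the 90% missing-data threshold is copied from the function above:

```python
import numpy as np


def summarise_windows(codes, window_size=100, missing_frac=0.9):
    """
    Hypothetical helper. ``codes`` holds one value per bp: 0 = missing,
    1 = homozygous, 2 = heterozygous. Returns one value per complete window:
    2 ("N") if at least ``missing_frac`` of positions are missing, otherwise
    1 ("K") if the window has a heterozygote, otherwise 0 ("T").
    """
    codes = np.asarray(codes, dtype=np.int8)
    num_windows = len(codes) // window_size
    # One row per window; a trailing partial window is dropped
    wins = codes[: num_windows * window_size].reshape(num_windows, window_size)
    n_missing = np.count_nonzero(wins == 0, axis=1)
    has_het = np.any(wins == 2, axis=1)
    result = np.zeros(num_windows, dtype=np.int8)  # "T" by default
    result[has_het] = 1  # "K"
    # "N" is assigned last so that it takes precedence, as in process_ts
    result[n_missing >= int(window_size * missing_frac)] = 2
    return result


codes = np.ones(400, dtype=np.int8)  # all homozygous
codes[150] = 2                       # window 1: one heterozygote
codes[200:295] = 0                   # window 2: 95% missing
codes[300:389] = 0                   # window 3: 89% missing (below the threshold) ...
codes[395] = 2                       # ... plus one heterozygote
print("".join("TKN"[c] for c in summarise_windows(codes).tolist()))  # TKNK
```

The trade-off is memory: this needs the whole region's codes at once (one byte per bp, so roughly 200 MB for a 200 Mb chromosome), whereas `process_ts` only ever holds one window.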
tulerpetontidae commented 1 month ago

Hi @hyanwong. When I initially worked on this repo, my goal was to understand the PSMC method by reimplementing it. I didn't anticipate that it might be useful to others, so the repo lacks polish and annotations. It is also noticeably slower than the original C implementation, and although it produced similar results in the tests I ran, I cannot guarantee that runs will be identical to those of the original PSMC.

However, I'm very glad you found this work useful and will be delighted to accept your suggestions for improving its usability as teaching material.

hyanwong commented 1 month ago

Thanks, I completely understand about the lack of polish! It's very handy to have a Python implementation, though. I'll make a PR with the tskit code (and it shouldn't require an extra import, which is nice).

I'm playing around to see what the results are like on some real data (again, for teaching). I'll post code and results in another issue.