tskit-dev / tsdate

Infer the age of ancestral nodes in a tree sequence.
MIT License
19 stars 10 forks source link

Comparison to "Threads" #430

Open hyanwong opened 2 weeks ago

hyanwong commented 2 weeks ago

I thought it would be interesting to compare how tsinfer+tsdate do against the new "Threads" program (https://pypi.org/project/threads-arg/0.1.0/).

As a test, I have a stdpopsim model with 2 chimp species (3 populations) and some selective sweeps in 2 of the populations:

Screenshot 2024-08-25 at 17 20 40

Here are the edge plots for the original (true) tree sequence versus tsinfer+tsdate and Threads, on 5 Mb of genome. The two recent selective sweeps (in the western population, one-third of the way along the genome, and in bonobo, two-thirds of the way along) are obvious in the plot of the true edges. Threads seems a little better at picking up the recent sweeps, but perhaps picks up less of the demographic banding at the top (although to be fair, I told it there was a fixed population size of 100,000 haploid genomes). For this scale of data (120 genomes), Threads is much faster than tsinfer, and its advantage grows as the genome length increases. However, I suspect its relative advantage would shrink if the sample size increased to very large numbers.

Screenshot 2024-08-25 at 17 24 12
Click for code ```python import os import subprocess import tempfile import stdpopsim import msprime import tskit import numpy as np import pandas as pd import tszip import matplotlib.pyplot as plt import matplotlib as mpl import tsinfer import tsdate def sweep_and_demography_Pan( sequence_length, sample_sizes, # could be a single number, or a dict of pop_name: size sweep_params=None, chrom="chr3", random_seed=123, ): """ Make an example of a 3-population reasonably complex demography with a few more recent selective sweeps """ species = stdpopsim.get_species("PanTro") model = species.get_demographic_model("BonoboGhost_4K19") msprime_demography = model.model contig = species.get_contig("chr3", mutation_rate=model.mutation_rate) ratemap = contig.recombination_map.slice(right=L, trim=True) try: sample_sizes = { name: int(sample_sizes) for name, info in msprime_demography.items() # Don't sample from the ghost population if info.extra_metadata['sampling_time'] is not None } except TypeError: pass G = 4000 # A time ago in generations: we assume populations from time 0..G are isolated and of constant size # Make independent populations, some with selective sweeps independent_pop_ts = [] for name, pop in model.model.items(): if name in sample_sizes: Ne = pop.initial_size demog = msprime.Demography() demog.add_population(name=name, initial_size=Ne) if name in sweep_params: p = 1 / (2 * Ne) freqs = {"start_frequency": p, "end_frequency": 1 - p, "dt": 1 / (40 * Ne)} sweep_model = msprime.SweepGenicSelection(**freqs, **sweep_params[name]) models = (sweep_model, msprime.StandardCoalescent()) print(f"Adding {name} population to demographic model, sweep at {int(sweep_params[name]['position'])}bp, selection coefficient s={sweep_params[name]['s']}") else: models = msprime.StandardCoalescent() print(f"Adding {name} population to demographic model, neutral") independent_pop_ts.append(msprime.sim_ancestry( sample_sizes[name], model=models, demography=demog, 
recombination_rate=ratemap, sequence_length=sequence_length, end_time=G, random_seed=123, )) combined_ts = independent_pop_ts[0] for ts in independent_pop_ts[1:]: combined_ts = combined_ts.union(ts, node_mapping=np.full(ts.num_nodes, tskit.NULL)) # Now recapitate: initial_state uses the population names in the combined_ts to figure out which are which ts = msprime.sim_ancestry(initial_state=combined_ts, demography=msprime_demography, random_seed=random_seed).simplify() return msprime.sim_mutations(ts, rate=model.mutation_rate, random_seed=random_seed), model, ratemap L = 5e6 # Simulate 5 Mb sweep_params = { "western": {"position": L//3, "s": 0.1}, "bonobo": {"position": (2*L)//3, "s": 0.05}, } ts, model, ratemap = sweep_and_demography_Pan( sequence_length=L, sample_sizes=20, sweep_params=sweep_params ) print(f"Simulated {ts.num_sites} sites") def run_threads(input_ts, ratemap, demography): "Run Threads: currently 'demography' is hard-coded in as a hack" # remove multiallelics ts = input_ts.delete_sites([s.id for s in input_ts.sites() if len(s.alleles) != 2]) # Make .pgen & .pvar files n_dip_indv = int(ts.num_samples / 2) indv_names = [f"tsk_{i}indv" for i in range(n_dip_indv)] with tempfile.TemporaryDirectory() as tmpdirname: tmp_fn_prefix = os.path.join(tmpdirname, "tmp") with open(f"{tmp_fn_prefix}.vcf", "wt") as vcf_file: ts.write_vcf(vcf_file, individual_names=indv_names) subprocess.call(["./plink2", "--vcf", f"{tmp_fn_prefix}.vcf", "--out", tmp_fn_prefix]) # Make a map with a position for each SNP df = pd.DataFrame({ "chr": np.repeat("Chr1", ts.num_sites), "SNP": np.arange(ts.num_sites), "cM": ratemap.get_cumulative_mass(ts.sites_position) * 100, "bp": ts.sites_position.astype(int), }) df.to_csv(f"{tmp_fn_prefix}.map.gz", sep="\t", index=False, header=False) # Hack a demography file times = np.array([0, 1e6]) diploid_size = np.array([50_000, 50_000]) df = pd.DataFrame({"gens_ago": times.astype(int), "haploid_Ne": diploid_size.astype(int) * 2}) 
df.to_csv(f"{tmp_fn_prefix}.demo", sep="\t", index=False, header=False) subprocess.call([ "threads", "infer", "--pgen", f"{tmp_fn_prefix}.pgen", "--map_gz", f"{tmp_fn_prefix}.map.gz", "--demography", f"{tmp_fn_prefix}.demo", "--out", f"{tmp_fn_prefix}.threads", ]) subprocess.call([ "threads", "convert", "--threads", f"{tmp_fn_prefix}.threads", "--tsz", f"{tmp_fn_prefix}.tsz", ]) return tszip.decompress(f"{tmp_fn_prefix}.tsz") threads_ts = run_threads(ts, ratemap, None) tsinfer_ts = tsinfer.infer( tsinfer.SampleData.from_tree_sequence(ts), num_threads=6, progress_monitor=True, ) tsinfer_tsdate_ts = tsdate.date( tsdate.preprocess_ts(tsinfer_ts), mutation_rate=model.mutation_rate, rescaling_intervals=100, ) # reinfer tsinfer2_ts = tsinfer.infer( tsinfer.SampleData.from_tree_sequence(tsinfer_tsdate_ts, use_sites_time=True), num_threads=6, progress_monitor=True, ) tsinfer2_tsdate_ts = tsdate.date( tsdate.preprocess_ts(tsinfer2_ts), mutation_rate=model.mutation_rate, rescaling_intervals=100, ) def edge_plot(plot_ts, ax): tm = plot_ts.nodes_time[plot_ts.edges_parent] ax.add_collection( mpl.collections.LineCollection( np.array([[plot_ts.edges_left, plot_ts.edges_right], [tm, tm]]).T, alpha=0.2 ) ) ax.autoscale() ax.margins(0) ax.set_yscale("log") fig, (ax_orig, ax_tsinfer, ax_threads) = plt.subplots(3, 1, figsize=(15, 15), sharex=True) edge_plot(ts, ax_orig) ax_orig.set_title(f"True edges & times ({ts.num_edges} edges)") edge_plot(tsinfer2_tsdate_ts, ax_tsinfer) ax_tsinfer.set_title(f"Tsinfer + tsdate ({tsinfer2_tsdate_ts.num_edges} edges)") edge_plot(threads_ts.simplify(), ax_threads) ax_threads.set_title(f"Threads {threads_ts.simplify().num_edges} edges"); ax_threads.set_xlabel("Genome position") ```
hyanwong commented 2 weeks ago

Incidentally, here's how tsinfer+tsdate does if we feed in the true times:

Screenshot 2024-08-25 at 17 35 37