tskit-dev / tsinfer

Infer a tree sequence from genetic variation data.
GNU General Public License v3.0

Use imputation accuracy to determine the correct mismatch ratios #652

Open hyanwong opened 2 years ago

hyanwong commented 2 years ago

We currently have no simple objective way to determine what mismatch ratios (in match_samples and match_ancestors) to use in real data. This is quite limiting when building tree sequences from real data.

However, a natural metric is the accuracy of imputation (e.g. as currently being worked on by @szhan). This would be an excellent way to determine the correct mismatch parameters. Here's how I think we would do it:

Take a real dataset and mark (say) 1% of the genotypes as missing, at random. Run tsinfer under different mismatch ratios (e.g. as here), impute the missing data, and calculate how well we do at imputation. Repeat for different random selections of 1% missing.
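A minimal sketch of the masking and scoring steps, in plain NumPy and independent of tsinfer. It assumes a (sites x samples) array of signed integer genotypes and uses -1 as the missing-data code (the value of `tskit.MISSING_DATA`); `mask_genotypes` and `fraction_correctly_imputed` are hypothetical helper names. The point it illustrates is that accuracy is scored only on the masked entries, not on the whole matrix.

```python
import numpy as np

MISSING = -1  # same value as tskit.MISSING_DATA; hard-coded to avoid the dependency


def mask_genotypes(genotypes, frac_missing, rng):
    """Return a copy of `genotypes` with a random fraction of entries set to
    MISSING, plus a boolean array marking which entries were masked."""
    genotypes = np.asarray(genotypes)
    n_mask = int(round(frac_missing * genotypes.size))
    # Sample flat positions without replacement, then reshape into a 2D mask
    flat_idx = rng.choice(genotypes.size, size=n_mask, replace=False)
    mask = np.zeros(genotypes.size, dtype=bool)
    mask[flat_idx] = True
    mask = mask.reshape(genotypes.shape)
    masked = genotypes.copy()
    masked[mask] = MISSING
    return masked, mask


def fraction_correctly_imputed(true_genotypes, imputed_genotypes, mask):
    """Score only the entries that were masked, not the whole matrix."""
    true_genotypes = np.asarray(true_genotypes)
    imputed_genotypes = np.asarray(imputed_genotypes)
    return float(np.mean(true_genotypes[mask] == imputed_genotypes[mask]))


# Example: mask 1% of a toy 0/1 matrix (sites x samples)
rng = np.random.default_rng(42)
toy = rng.integers(0, 2, size=(200, 20))
masked, mask = mask_genotypes(toy, 0.01, rng)
# ... run inference on `masked`, extract the imputed matrix, then call
# fraction_correctly_imputed(toy, imputed, mask) ...
```

Repeating this with different seeds gives the replicate random selections described above.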

It would be helpful to do this both for different chromosomes, and for different datasets (e.g. separately on 1000G, HGDP, and SGDP). Hopefully the different chromosomes would show the same pattern.

It would also be interesting to see how we do for "weird" regions of the genome, such as the MHC.

jeromekelleher commented 2 years ago

We want to use several different metrics for imputation accuracy here, to avoid overfitting.

@szhan can you outline what the different metrics people use are, please?

hyanwong commented 2 years ago

Here's a simple example which I worked up with @percyfal. Is the result the right order of magnitude, @szhan?

import msprime
import tsinfer
import tskit
import numpy as np

rho = 1e-6
mu = 1e-6

proportion_missing = 0.01
results = []
for rep in range(1, 101):
    # Make a simulated TS
    true_ts = msprime.sim_mutations(
        msprime.sim_ancestry(
            10, recombination_rate=rho, sequence_length=10000, population_size=1e4, random_seed=rep),
        rate=mu,
        random_seed=rep,
    )
    sd = tsinfer.SampleData.from_tree_sequence(true_ts)

    # Make a copy in which a random ~1% of genotypes are marked as missing
    genos = sd.sites_genotypes[:]
    sd_missing = sd.copy()
    is_missing = np.random.random(genos.shape) < proportion_missing
    sd_missing.sites_genotypes[:] = np.where(is_missing, tskit.MISSING_DATA, genos)
    sd_missing.finalise()

    ancs = tsinfer.generate_ancestors(sd_missing)
    # To set an explicit mismatch ratio (and recombination rate), use these calls instead:
    # mm = 1e-10
    # a_ts = tsinfer.match_ancestors(sd_missing, ancs, recombination_rate=rho, mismatch_ratio=mm)
    # imputed_ts = tsinfer.match_samples(sd_missing, a_ts, recombination_rate=rho, mismatch_ratio=mm)
    a_ts = tsinfer.match_ancestors(sd_missing, ancs)
    imputed_ts = tsinfer.match_samples(sd_missing, a_ts)
    # Count genotypes that differ between the inferred and the true tree sequence
    true_G = true_ts.genotype_matrix(alleles=tskit.ALLELES_ACGT)
    imputed_G = imputed_ts.genotype_matrix(alleles=tskit.ALLELES_ACGT)
    num_misimputed = np.sum(imputed_G != true_G)
    num_missing = np.sum(sd_missing.sites_genotypes[:] == tskit.MISSING_DATA)
    print(num_misimputed, "badly imputed out of", num_missing, ":", true_ts.num_sites, "sites", true_ts.num_trees, "trees")
    results.append(num_misimputed / num_missing)
print("Percent misimputed:", np.mean(results) * 100)

giving:

> Percent misimputed: 6.330269413992731
szhan commented 2 years ago

Imputation accuracy turns out to be a poor measure for evaluating imputation performance, because it does not account for chance agreement at low minor allele frequency (MAF). The Imputation Quality Score (IQS) is a more popular measure. I'll give IQS a try and report back.

szhan commented 2 years ago

I have encountered three ways to measure imputation performance when the true genotypes are available.

  1. Concordance rate (CR)
  2. Imputation Quality Score (IQS)
  3. Squared correlation (SC)

CR is the simplest: the number of correctly imputed genotypes divided by the number of genotypes imputed. The problem with CR is that chance agreement can inflate it when MAF is low, especially below 0.01%, which may give the misleading impression that imputation accuracy is high.
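A minimal sketch of CR, showing how chance agreement inflates it. The toy site below has a MAF of 1%, and the "imputer" is a hypothetical baseline that ignores the data and always guesses the major allele; it still scores 0.99.

```python
import numpy as np


def concordance_rate(true_genotypes, imputed_genotypes):
    """Fraction of imputed genotypes that exactly match the true genotypes."""
    t = np.asarray(true_genotypes)
    m = np.asarray(imputed_genotypes)
    if t.shape != m.shape or t.size == 0:
        raise ValueError("expected two non-empty arrays of the same shape")
    return float(np.mean(t == m))


# A site with MAF 1%: 99 copies of allele 0 and one copy of allele 1
true_site = np.array([0] * 99 + [1])
# Baseline that always guesses the major allele, using no information at all
always_major = np.zeros_like(true_site)
print(concordance_rate(true_site, always_major))  # 0.99
```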

IQS, on the other hand, accounts for chance agreement. It has been demonstrated to be a better way to show imputation performance at low MAF than CR (https://www.ncbi.nlm.nih.gov/pmc/articles/PMC2837741/). This is one of the measures we should use.
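A minimal sketch of IQS for a single site, written as Cohen's kappa on the table of true versus imputed genotypes: IQS = (P0 - Pc) / (1 - Pc), where P0 is the observed agreement and Pc is the agreement expected by chance from the marginal genotype frequencies. The function name is hypothetical; returning NaN when Pc is 1 (e.g. a site that is monomorphic in both truth and imputation) is a choice made here, since the score is undefined there. On the same always-major-allele baseline used for CR, IQS is 0 rather than 0.99.

```python
import numpy as np


def imputation_quality_score(true_genotypes, imputed_genotypes):
    """IQS (Cohen's kappa) between true and imputed genotypes at one site.

    Returns NaN when the chance agreement Pc equals 1, where IQS is undefined.
    """
    t = np.asarray(true_genotypes)
    m = np.asarray(imputed_genotypes)
    if t.shape != m.shape or t.size == 0:
        raise ValueError("expected two non-empty arrays of the same shape")
    categories = np.union1d(t, m)
    # P0: observed proportion of agreement
    p_obs = float(np.mean(t == m))
    # Pc: chance agreement from the marginal frequency of each genotype class
    p_true = np.array([np.mean(t == c) for c in categories])
    p_imp = np.array([np.mean(m == c) for c in categories])
    p_chance = float(np.sum(p_true * p_imp))
    if 1.0 - p_chance < 1e-12:
        return float("nan")
    return (p_obs - p_chance) / (1.0 - p_chance)


true_site = np.array([0] * 99 + [1])  # MAF 1%
print(imputation_quality_score(true_site, np.zeros_like(true_site)))  # 0.0
```

A mean IQS over sites could then be taken over the sites where it is defined.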

There is also SC, which is the square of the correlation coefficient between the allele dose (AD) of the true genotypes and the AD of the imputed genotypes. AD is the number of derived alleles in a genotype; in a diploid organism, the AD of 0|0 is 0, 0|1 and 1|0 is 1, and 1|1 is 2. SC can be computed per imputed site, and then a mean SC can be obtained by taking the average across all the imputed sites. Mean SC was used in this pre-phasing paper (https://www.ncbi.nlm.nih.gov/pmc/articles/PMC3696580/).
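A minimal sketch of per-site SC and mean SC, with hypothetical function names. It assumes each site is passed as a pair of dosage vectors restricted to the imputed individuals, and that in a haplotype array the columns 2i and 2i+1 belong to the same diploid individual (an assumption about sample ordering, not something checked here). Sites where either vector is constant have an undefined correlation; they are returned as NaN and skipped when averaging.

```python
import numpy as np


def haplotypes_to_dosage(haplotypes):
    """Sum adjacent 0/1 haplotypes at one site into diploid dosages (0, 1, 2).

    Assumes haplotypes 2i and 2i+1 belong to the same individual.
    """
    h = np.asarray(haplotypes)
    if h.size % 2 != 0:
        raise ValueError("expected an even number of haplotypes")
    return h[0::2] + h[1::2]


def site_squared_correlation(true_dosage, imputed_dosage):
    """Squared Pearson correlation between true and imputed dosages at one site."""
    t = np.asarray(true_dosage, dtype=float)
    m = np.asarray(imputed_dosage, dtype=float)
    if t.std() == 0 or m.std() == 0:
        return float("nan")  # correlation undefined for a constant vector
    return float(np.corrcoef(t, m)[0, 1] ** 2)


def mean_squared_correlation(site_pairs):
    """Average SC over (true_dosage, imputed_dosage) pairs, ignoring undefined sites."""
    values = [site_squared_correlation(t, m) for t, m in site_pairs]
    return float(np.nanmean(values))
```

Note that SC is 1 for any perfectly linear relationship, so it rewards getting dosages in the right order even if they are systematically shifted.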

I'm looking into other ways to measure imputation performance when the true genotypes are not available, in particular, BEAGLE's estimated mean R2 and IMPUTE's INFO score. I have seen the INFO score used quite a bit.

szhan commented 2 years ago

Another note about squared correlation (SC): SC can be useful if we are interested in how well imputation predicts the allele dose, which is relevant to GWAS, where the effect of the allele dose of each genetic variant on a phenotype of interest is estimated.

szhan commented 2 years ago

Some studies assessing genotype imputation focus on genotypes with non-reference alleles. For example, Li et al. (2021) (see https://pubmed.ncbi.nlm.nih.gov/33536225/) compute non-reference concordance (NRC), which ignores genotypes homozygous for the reference allele. Also, Priit Palta's group has used the related metrics non-reference sensitivity and non-reference discordancy, both calculated by the GATK GenotypeConcordance tool, which I plan to try out.
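A minimal sketch of one common way to compute NRC: drop the genotype pairs where both the true and the imputed genotype are homozygous reference, then take the concordance of what remains. This is an assumption about the definition and may differ in detail from Li et al. or from GATK; it also assumes dosage 0 means homozygous reference, and the function name is hypothetical.

```python
import numpy as np


def non_reference_concordance(true_dosage, imputed_dosage):
    """Concordance after removing pairs that are hom-ref (dosage 0) in both.

    Returns NaN if every pair is hom-ref in both truth and imputation.
    """
    t = np.asarray(true_dosage)
    m = np.asarray(imputed_dosage)
    if t.shape != m.shape:
        raise ValueError("arrays must have the same shape")
    # Keep a pair unless both genotypes are homozygous reference
    keep = ~((t == 0) & (m == 0))
    if not keep.any():
        return float("nan")
    return float(np.mean(t[keep] == m[keep]))


true_dosage = np.array([0, 0, 1, 2, 1])
imputed_dosage = np.array([0, 1, 1, 2, 0])
print(non_reference_concordance(true_dosage, imputed_dosage))  # 0.5
```

On the same data plain concordance would be 0.6, because the easy hom-ref match is counted.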

hyanwong commented 1 year ago

@shing: here is some code that can be modified to parallelise over the various mismatch values and assess the quality of imputation. It would be better, however, to also calculate the IQS rather than simply the fraction of correctly imputed genotypes:

Run in parallel by giving the -p n_cpus flag, e.g. ./myscript.py -p 16 on a 16-core VM. The -b flag shows a progress bar. Change the number of replicates etc. as you see fit.

import argparse
import multiprocessing
import itertools
import operator
import tempfile
import functools
import collections
import os
import logging

import numpy as np
import stdpopsim
import tsinfer
import tskit
from tqdm import tqdm
import pandas as pd

Params = collections.namedtuple("Params", "mismatch_ancestors, mismatch_samples")

def mark_missing_for_imputation(sample_data, frac_missing, random_seed):
    """
    Take a sample data file and randomly mark a fraction of genotypes as missing.
    Returns the modified copy and a dict of {site_id: sample indexes marked missing}.
    """
    rng = np.random.default_rng(random_seed)
    sd = sample_data.copy()
    n_sites, n_samples = sd.sites_genotypes.shape
    n_genotypes = n_sites * n_samples
    n_to_impute = int(frac_missing * n_genotypes)
    assert n_genotypes > 0
    assert n_to_impute > 0
    random_indexes = rng.choice(n_genotypes, replace=False, size=n_to_impute)
    impute_indexes = np.unravel_index(random_indexes, sd.sites_genotypes.shape)

    genotypes = sd.sites_genotypes[:]  # warning - could be a huge array!
    # check no genotypes are missing in the original data: we could relax this
    n_missing = np.sum(genotypes == tskit.MISSING_DATA)
    assert n_missing == 0
    genotypes[impute_indexes] = tskit.MISSING_DATA
    sd.sites_genotypes[:] = genotypes
    sd.finalise()
    # create a dict of {site_id: imputed_samples}
    imputed = {
        site_id: impute_indexes[1][impute_indexes[0]==site_id]
        for site_id in np.unique(impute_indexes[0])
    }
    return sd, imputed

def worker(index_and_mismatch_params, rate_map, sd_path):
    logging.info(f"Worker started with {index_and_mismatch_params}")
    # carry out a new imputation for each parameter combination
    index, params = index_and_mismatch_params
    orig_sd = tsinfer.load(sd_path)
    sample_data, imputed = mark_missing_for_imputation(
        orig_sd, frac_missing=0.01, random_seed=index+123)
    logging.info(
        f"Set {sum([len(v) for v in imputed.values()])} random genotypes to missing")
    logging.info("Inference of missing data: generating ancestors")
    ancestor_data = tsinfer.generate_ancestors(sample_data)

    anc_ts = tsinfer.match_ancestors(
        sample_data,
        ancestor_data,
        mismatch_ratio=params.mismatch_ancestors,
        recombination_rate=rate_map,
    )
    if anc_ts.num_sites == 0:
        raise ValueError(
            f"No sites left for inference for params {index_and_mismatch_params}")
    ts = tsinfer.match_samples(
        sample_data,
        anc_ts,
        mismatch_ratio=params.mismatch_samples,
        recombination_rate=rate_map,
    )

    # maybe save intermediate results here? That way we can reload the tree
    # sequences to test different imputation metrics without having to re-infer

    # make a boolean array of whether the site was imputed
    site_imputed = np.zeros(ts.num_sites, dtype=bool)
    site_imputed[list(imputed.keys())] = True
    num_imputed = num_correctly_imputed = 0
    for var_orig, var_new in itertools.compress(
        zip(orig_sd.variants(), ts.variants()), site_imputed
    ):
        assert var_orig.site.id == var_new.site.id
        assert var_orig.site.position == var_new.site.position
        imputed_samples = imputed[var_orig.site.id]
        orig_alleles = np.array(var_orig.alleles)
        new_alleles = np.array(var_new.alleles)
        orig_decoded = orig_alleles[var_orig.genotypes[imputed_samples]]
        new_decoded = new_alleles[var_new.genotypes[imputed_samples]]
        num_imputed += len(imputed_samples)
        num_correctly_imputed += np.sum(orig_decoded == new_decoded)
        # Here we should instead use some sort of imputation metric

    return index_and_mismatch_params, num_imputed, num_correctly_imputed

def save_results(results_table, result, path=None):
    """
    Save the result in the specified results_table.
    Assumes that "result" is laid out as returned by the worker() function
    """
    (i, params), num_imputed, num_correctly_imputed = result
    # Assign via .loc[row, columns]: chained indexing (.loc[i][...]) may only modify a copy
    results_table.loc[i, list(Params._fields)] = list(params)
    results_table.loc[i, "frac_correct"] = num_correctly_imputed / num_imputed
    # Save current results
    if path:
        results_table.to_csv(path)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-v', '--verbose', action='count', default=0)
    parser.add_argument('-p', '--num_procs', type=int, default=0)
    parser.add_argument('-b', '--progress-bar', action="store_true")
    args = parser.parse_args()
    levels = [logging.WARNING, logging.INFO, logging.DEBUG]
    level = levels[min(args.verbose, len(levels) - 1)]  # cap to last level index
    logging.basicConfig(level=level)
    multiprocessing.log_to_stderr(level=level)

    n_cpus = args.num_procs  # number of worker processes, set with the -p flag
    n_reps = 2
    mismatch_params = [10 ** exponent for exponent in range(-5, 5)]
    mismatch_combos = [p.ravel() for p in np.meshgrid(mismatch_params, mismatch_params)]
    assert len(mismatch_combos) == 2
    assert len(mismatch_combos[0]) == len(mismatch_combos[1])
    # Now replicate those for the given number of replicates
    mismatch_combos = [np.tile(p, n_reps) for p in mismatch_combos]

    ## Simulate some data - this could be replaced with e.g. reading a VCF of real data
    species = stdpopsim.get_species("HomSap")
    contig = species.get_contig("chr20", length_multiplier=0.01)
    print(f"Simulating {contig.recombination_map.sequence_length} bp")
    model = species.get_demographic_model('OutOfAfrica_3G09')
    samples = model.get_samples(50, 50, 50)
    engine = stdpopsim.get_engine('msprime')
    ts = engine.simulate(model, contig, samples, seed=123)
    print(f"Simulated {ts.num_sites} sites for {ts.num_samples} haplotypes")

    save_path="results.csv"
    with tempfile.TemporaryDirectory() as tmp:
        # Create the file where all subprocesses can get at it.
        sample_data_path = os.path.join(tmp, 'data.samples')
        tsinfer.SampleData.from_tree_sequence(ts, path=sample_data_path)

        call_func = functools.partial(
            worker, rate_map=contig.recombination_map, sd_path=sample_data_path)
        func_params = enumerate(map(Params, *mismatch_combos))
        num_results = len(mismatch_combos[0])
        column_names = list(Params._fields) + ["frac_correct"]
        results_table = pd.DataFrame(index=np.arange(num_results), columns=column_names)

        if args.num_procs == 0:
            for result in tqdm(
                map(call_func, func_params),
                total=num_results,
                disable=not args.progress_bar,
            ):
                save_results(results_table, result, save_path)
        else:
            with multiprocessing.Pool(
                processes=n_cpus,
                maxtasksperchild=1,  # may not be needed?
            ) as pool:
                for result in tqdm(
                    pool.imap_unordered(call_func, func_params, chunksize=1),
                    total=num_results,
                    disable=not args.progress_bar,
                ):
                    save_results(results_table, result, save_path)

    print(results_table)
    print(f"Saved to {save_path}")

And here's how to plot the result e.g. in a Jupyter notebook. It should work even if not all the replicates have been run.

import pandas as pd
import matplotlib.pyplot as plt

df = pd.read_csv("results.csv", index_col=0)
mean_vals = df.groupby(["mismatch_ancestors", "mismatch_samples"], as_index=False).mean()
# DataFrame.pivot takes keyword-only arguments in pandas >= 2.0
grid_data = mean_vals.pivot(index="mismatch_ancestors", columns="mismatch_samples", values="frac_correct")
plt.contourf(grid_data.columns, grid_data.index, grid_data, 20, cmap='viridis')
plt.xscale("log")
plt.xlabel(grid_data.columns.name)
plt.yscale("log")
plt.ylabel(grid_data.index.name)
plt.show()
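To pick a single combination rather than reading it off the contour plot, one option is to average each (mismatch_ancestors, mismatch_samples) pair over replicates and take the maximum. This is a dependency-free sketch; `best_mismatch_combo` is a hypothetical helper that expects rows of (mismatch_ancestors, mismatch_samples, frac_correct), for example read from results.csv with the csv module, and it skips rows that have not been run yet (empty or NaN entries).

```python
import csv
import math
from collections import defaultdict


def _to_float(value):
    """Convert a CSV cell or number to float, mapping blanks to NaN."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return math.nan


def best_mismatch_combo(rows):
    """Return (best_pair, means), where means maps each pair to its mean score."""
    totals = defaultdict(lambda: [0.0, 0])
    for anc, samp, score in rows:
        anc, samp, score = _to_float(anc), _to_float(samp), _to_float(score)
        if math.isnan(score):
            continue  # replicate not yet run
        totals[(anc, samp)][0] += score
        totals[(anc, samp)][1] += 1
    if not totals:
        raise ValueError("no completed replicates")
    means = {pair: total / count for pair, (total, count) in totals.items()}
    return max(means, key=means.get), means


# Example usage with the file written by the script above:
# with open("results.csv", newline="") as f:
#     rows = [(r["mismatch_ancestors"], r["mismatch_samples"], r["frac_correct"])
#             for r in csv.DictReader(f)]
# best, means = best_mismatch_combo(rows)
```

With few replicates the maximum may be noisy, so the contour plot is still worth inspecting.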
hyanwong commented 1 year ago

To make sure that we are not biased by the choice of samples or genomic region when determining mismatch rates, I think that if we are subsetting the data, we should randomly choose both a subset of samples and a subset of contiguous sites when performing inference.

Since we will be recommending that people do this to establish the right mismatch rates, we could also consider implementing a utility function to do this (including in parallel) in eval_util.py.
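A minimal sketch of the random subsetting described above: a random set of sample indexes plus a random contiguous block of site indexes. `random_subset_indices` is a hypothetical helper, not an existing function in eval_util.py; for diploid data one would probably sample individuals rather than individual haplotypes, and the returned indexes would then be passed to whatever subsetting routine is used.

```python
import numpy as np


def random_subset_indices(num_samples, num_sites, n_keep_samples, n_keep_sites, rng):
    """Choose a random subset of samples and a random contiguous block of sites.

    Returns (sample_indexes, site_indexes), both sorted NumPy arrays.
    """
    if not (0 < n_keep_samples <= num_samples and 0 < n_keep_sites <= num_sites):
        raise ValueError("requested subset is larger than the data")
    # Samples: drawn without replacement, anywhere in the dataset
    samples = np.sort(rng.choice(num_samples, size=n_keep_samples, replace=False))
    # Sites: a contiguous window starting at a uniformly random position
    start = int(rng.integers(0, num_sites - n_keep_sites + 1))
    sites = np.arange(start, start + n_keep_sites)
    return samples, sites


rng = np.random.default_rng(1)
samples, sites = random_subset_indices(100, 1000, 10, 50, rng)
# e.g. genotypes[np.ix_(sites, samples)] for a (sites x samples) array
```

Drawing a fresh subset for each replicate would also spread the evaluation across different genomic regions.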

hyanwong commented 1 year ago

From Jin & Terhorst: https://www.biorxiv.org/content/10.1101/2022.08.03.502674v1.full

[Screenshot of an excerpt from the Jin & Terhorst preprint; image not reproduced]

We should do this for the imputation calculation.