PlantandFoodResearch / MCHap

Polyploid micro-haplotype assembly using Markov chain Monte Carlo simulation.
MIT License

Simplify PMF for calling Gibbs sampler #125

Closed timothymillar closed 2 years ago

timothymillar commented 2 years ago

Related to #124

`mchap call` currently uses the full form of the Dirichlet-multinomial (DirMul) PMF. Because the Gibbs sampler updates only a single parameter (i.e., allele) at a time, the PMF can be simplified. Simplifying should result in a moderate performance improvement.

The naive simplification is:

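from math import lgamma

import numba
import numpy as np

# inbreeding_as_dispersion, allelic_dosage and count_allele are helper
# functions assumed to be defined elsewhere in MCHap


@numba.njit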
def log_genotype_allele_prior(
    genotype, variable_allele, unique_haplotypes, inbreeding=0
):
    """Log probability that a genotype contains a specified allele
    given its other alleles are treated as constants.

    This prior function is designed to be used in a Gibbs sampler
    in which a single allele is resampled at a time.

    Parameters
    ----------
    genotype : ndarray, int, shape (ploidy, )
        Integer encoded alleles in the proposed genotype.
    variable_allele : int
        Index of the allele that is not being held constant ( < ploidy).
    unique_haplotypes : int
        Number of possible haplotype alleles at this locus.
    inbreeding : float
        Expected inbreeding coefficient of the sample.

    Returns
    -------
    lprior : float
        Log-probability of the variable allele given the observed alleles.
    """
    # base alpha for flat prior
    alpha = inbreeding_as_dispersion(inbreeding, unique_haplotypes)

    # sum of alpha parameters accounting for constant alleles
    constant = np.delete(genotype, variable_allele)
    counts = allelic_dosage(constant)
    sum_alpha = np.sum(counts + alpha) + alpha * (unique_haplotypes - len(counts))

    # alpha parameter for variable allele
    count = count_allele(genotype, genotype[variable_allele]) - 1
    variable_alpha = alpha + count

    # dirichlet-multinomial PMF
    left = lgamma(sum_alpha) - lgamma(1 + sum_alpha)
    right = lgamma(1 + variable_alpha) - lgamma(variable_alpha)
    return left + right
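
A note on why the gamma terms above are removable: each pair of `lgamma` calls differs by exactly one in its argument, so the recurrence Γ(x + 1) = xΓ(x) applies. Writing S for `sum_alpha` and a for `variable_alpha` (symbols introduced here only for brevity), a sketch of the reduction is:

```latex
\begin{aligned}
\log\Gamma(S) - \log\Gamma(S + 1) &= -\log S \\
\log\Gamma(a + 1) - \log\Gamma(a) &= \log a \\
\texttt{left} + \texttt{right} &= \log a - \log S = \log\frac{a}{S}
\end{aligned}
```

So the log-prior is just the log of a ratio of two sums that need no special functions.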

Based on the Wikipedia article, this can be simplified further to remove the gamma function entirely:

@numba.njit
def genotype_allele_prior(
    genotype, variable_allele, unique_haplotypes, inbreeding=0
):
    """Probability that a genotype contains a specified allele
    given its other alleles are treated as constants.

    This prior function is designed to be used in a Gibbs sampler
    in which a single allele is resampled at a time.

    Parameters
    ----------
    genotype : ndarray, int, shape (ploidy, )
        Integer encoded alleles in the proposed genotype.
    variable_allele : int
        Index of the allele that is not being held constant ( < ploidy).
    unique_haplotypes : int
        Number of possible haplotype alleles at this locus.
    inbreeding : float
        Expected inbreeding coefficient of the sample.

    Returns
    -------
    prior : float
        Probability of the variable allele given the observed alleles.
    """
    alpha = inbreeding_as_dispersion(inbreeding, unique_haplotypes)

    ploidy = len(genotype)
    count = count_allele(genotype, genotype[variable_allele]) - 1
    num = count + alpha

    sum_alpha = alpha * unique_haplotypes
    denom = sum_alpha + ploidy - 1
    return num / denom
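
As a sanity check on the two forms above, here is a minimal pure-Python sketch that compares them numerically. It is not MCHap code: the function names `log_prior_gamma_form`, `prior_ratio_form` and `conditional_distribution` are hypothetical, plain lists stand in for the numba/NumPy arrays, and `alpha` is passed in directly rather than computed by `inbreeding_as_dispersion`.

```python
from math import exp, isclose, lgamma


def _held_count(genotype, variable_allele):
    """Copies of the variable allele among the alleles held constant."""
    held = [a for i, a in enumerate(genotype) if i != variable_allele]
    return held.count(genotype[variable_allele])


def log_prior_gamma_form(genotype, variable_allele, unique_haplotypes, alpha):
    """Log-prior written with lgamma terms (mirrors the naive version)."""
    count = _held_count(genotype, variable_allele)
    # total alpha over all alleles plus the (ploidy - 1) constant alleles
    sum_alpha = alpha * unique_haplotypes + (len(genotype) - 1)
    variable_alpha = alpha + count
    left = lgamma(sum_alpha) - lgamma(1 + sum_alpha)
    right = lgamma(1 + variable_alpha) - lgamma(variable_alpha)
    return left + right


def prior_ratio_form(genotype, variable_allele, unique_haplotypes, alpha):
    """Prior written as a plain ratio (mirrors the gamma-free version)."""
    count = _held_count(genotype, variable_allele)
    return (count + alpha) / (alpha * unique_haplotypes + len(genotype) - 1)


def conditional_distribution(genotype, variable_allele, unique_haplotypes, alpha):
    """Prior of every candidate allele at the resampled position."""
    probs = []
    for allele in range(unique_haplotypes):
        proposal = list(genotype)
        proposal[variable_allele] = allele
        probs.append(
            prior_ratio_form(proposal, variable_allele, unique_haplotypes, alpha)
        )
    return probs


if __name__ == "__main__":
    genotype = [0, 0, 1, 2]  # tetraploid with two copies of allele 0
    for i in range(len(genotype)):
        log_form = exp(log_prior_gamma_form(genotype, i, 4, 0.5))
        ratio_form = prior_ratio_form(genotype, i, 4, 0.5)
        print(i, log_form, ratio_form, isclose(log_form, ratio_form))
```

The candidate priors returned by `conditional_distribution` sum to one, which is what a Gibbs step needs when it combines them with the likelihoods of each candidate allele.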

timothymillar commented 2 years ago

Naive version added in #129. Can reopen later if the further simplification is deemed worth updating the tests for.