astheeggeggs / lshmm

code to run Li and Stephens
MIT License

Refactor and improve `check_inputs` #45

Closed: szhan closed this issue 3 weeks ago.

szhan commented 1 month ago

I think the docstring should spell out the current requirements for the input reference panel and query more clearly. We should also add checks that (1) the input reference panel contains only the values 0, 1, and 2 (the possible allele dosages at biallelic sites); (2) the reference panel does not contain any MISSING values; and (3) the input query does not contain any NONCOPY values. The error messages should also be more informative when raised. There are likely more easy improvements to make.
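A minimal sketch of the three value checks proposed above, written as a standalone helper. The name `check_allele_values` is hypothetical, and `MISSING = -1` and `NONCOPY = -2` are assumed stand-ins for `core.MISSING` and `core.NONCOPY`; the package's actual constants may differ.

```python
import numpy as np

# Assumed sentinel values standing in for core.MISSING and core.NONCOPY.
MISSING = -1
NONCOPY = -2


def check_allele_values(reference_panel, query):
    """Raise ValueError if the panel or query holds disallowed values (hypothetical helper)."""
    # (2) The reference panel cannot contain MISSING values.
    if np.any(reference_panel == MISSING):
        raise ValueError("Reference panel cannot contain MISSING values.")
    # (1) Otherwise, only the allele dosages 0, 1, and 2 are allowed.
    if not np.all(np.isin(reference_panel, [0, 1, 2])):
        raise ValueError("Reference panel can only contain the values 0, 1, and 2.")
    # (3) The query cannot contain NONCOPY values (MISSING is still allowed).
    if np.any(query == NONCOPY):
        raise ValueError("Query cannot contain NONCOPY values.")
```

Checking for MISSING before the general value check means the more specific error message is the one a user sees when the panel has missing data.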

szhan commented 1 month ago

@astheeggeggs suggested that the queries should also be checked, and that an error should be raised if a query containing only MISSING values is passed as input.
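Because the query is stored as a (k, m) matrix, this check can be read two ways. Testing `np.all(query == MISSING)` over the whole matrix only fails when every query is entirely MISSING, whereas reducing along the site axis flags any single query that is entirely MISSING. A small sketch contrasting the two, again assuming `MISSING = -1`; both function names are hypothetical.

```python
import numpy as np

MISSING = -1  # assumed stand-in for core.MISSING


def all_queries_missing(query):
    # True only if every entry of the (k, m) query matrix is MISSING.
    return bool(np.all(query == MISSING))


def any_query_all_missing(query):
    # Reduce over sites (axis=1) to get one flag per query, then ask if any is set.
    return bool(np.any(np.all(query == MISSING, axis=1)))


query = np.array([
    [0, 1, MISSING],
    [MISSING, MISSING, MISSING],
])
print(all_queries_missing(query))    # False
print(any_query_all_missing(query))  # True
```

With a single query (k = 1) the two checks agree; they only diverge once several queries are passed together.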

szhan commented 1 month ago

This should do:

import warnings

import numpy as np

from lshmm import core  # assumed location of the MISSING constant


def check_inputs(
    reference_panel,
    query,
    prob_recombination,
    prob_mutation=None,
    scale_mutation_rate=None,
):
    """
    Check that the input data and parameters are valid.

    The reference panel must be an array of size (m, n) or (m, n, n), and
    the query must be an array of size (k, m), where

        m = number of sites.
        n = number of samples in the reference panel (haplotypes, not individuals).
        k = number of samples in the query (haplotypes, not individuals).

    The mutation rate can be scaled at each site based on the number of
    distinct alleles that can be mutated to.

    :param numpy.ndarray reference_panel: Array of size (m, n) or (m, n, n).
    :param numpy.ndarray query: Array of size (k, m).
    :param prob_recombination: Recombination probability, as a scalar or an array of length m.
    :param prob_mutation: Mutation probability, as a scalar or an array of length m. If None (default), set as per Li & Stephens (2003).
    :param bool scale_mutation_rate: Scale mutation rate if True (default).
    :return: Number of reference haplotypes, number of sites, ploidy.
    :rtype: tuple
    """
    if scale_mutation_rate is None:
        scale_mutation_rate = True

    # Check the reference panel.
    if reference_panel.ndim not in (2, 3):
        err_msg = "Reference panel array must have 2 or 3 dimensions."
        raise ValueError(err_msg)

    if reference_panel.ndim == 2:
        num_sites, num_ref_haps = reference_panel.shape
        ploidy = 1
    else:
        num_sites, num_ref_haps, _num_samples = reference_panel.shape
        if num_ref_haps != _num_samples:
            err_msg = (
                "Reference panel dimensions are incorrect. "
                "An array of size (m, n, n) is expected, "
                "where m = number of sites and n = number of samples."
            )
            raise ValueError(err_msg)
        ploidy = 2

    # Check the query.
    if query.ndim != 2:
        err_msg = "Query array must have 2 dimensions, i.e. size (k, m)."
        raise ValueError(err_msg)

    if query.shape[1] != num_sites:
        err_msg = (
            "Number of sites in the query does not match that in the reference panel."
        )
        raise ValueError(err_msg)

    # Reject any single query (row) that consists entirely of MISSING.
    if np.any(np.all(query == core.MISSING, axis=1)):
        err_msg = "A query cannot contain only MISSING values."
        raise ValueError(err_msg)

    # Check mutation probability.
    # Test for None first; otherwise len(None) in the array branch raises TypeError.
    if prob_mutation is None:
        warn_msg = (
            "No mutation probability is passed. "
            "Setting it based on Li & Stephens (2003) equations A2 and A3."
        )
        warnings.warn(warn_msg)
    elif isinstance(prob_mutation, (int, float)):
        if not scale_mutation_rate:
            warn_msg = "Scalar mutation probability is passed, but not rescaling it."
            warnings.warn(warn_msg)
    elif np.ndim(prob_mutation) == 1 and len(prob_mutation) == num_sites:
        if scale_mutation_rate:
            warn_msg = "An array of mutation probabilities is passed. Rescaling them."
            warnings.warn(warn_msg)
    else:
        err_msg = (
            "Mutation probability is not None, a scalar, or "
            "an array of length equal to the number of sites."
        )
        raise ValueError(err_msg)

    # Check recombination probability.
    if not (
        isinstance(prob_recombination, (int, float))
        or (
            isinstance(prob_recombination, np.ndarray)
            and prob_recombination.ndim == 1
            and prob_recombination.shape[0] == num_sites
        )
    ):
        err_msg = (
            "Recombination probability is not a scalar or "
            "an array of length equal to the number of sites."
        )
        raise ValueError(err_msg)

    return (num_ref_haps, num_sites, ploidy)

szhan commented 3 weeks ago

When running diploid LS, I think the input should be the implied phased genotypes for the reference panel and the unphased genotypes for a single query. So, the expected size of the reference panel is (m, n, n) and that of the query is (1, m).

Or do we intend to have the function run over multiple queries?
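To illustrate these shapes: one way to form an (m, n, n) panel of implied phased genotypes is to sum every ordered pair of reference haplotypes at each site, so that entry [l, i, j] is the allele dosage (0, 1, or 2) of the genotype formed from haplotypes i and j. This is a sketch assuming biallelic 0/1 haplotypes; the helper name is hypothetical.

```python
import numpy as np


def haplotypes_to_genotype_panel(haplotypes):
    """Turn an (m, n) array of 0/1 haplotypes into an (m, n, n) array of dosages."""
    # Broadcasting (m, n, 1) + (m, 1, n) gives (m, n, n), with
    # result[l, i, j] = haplotypes[l, i] + haplotypes[l, j].
    return haplotypes[:, :, np.newaxis] + haplotypes[:, np.newaxis, :]


haplotypes = np.array([
    [0, 1, 1],  # site 0
    [1, 0, 1],  # site 1
])  # m = 2 sites, n = 3 reference haplotypes
reference_panel = haplotypes_to_genotype_panel(haplotypes)
print(reference_panel.shape)  # (2, 3, 3)

# One unphased query genotype per site, stored as a (1, m) matrix.
query = np.array([[1, 2]])
print(query.shape)  # (1, 2)
```

Each per-site slice is symmetric, since pairing haplotype i with j gives the same dosage as pairing j with i.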

astheeggeggs commented 3 weeks ago

The function should definitely eventually run over multiple queries, yep.

szhan commented 3 weeks ago

But the forward function, for example, does not actually run over multiple queries once they are passed to it. We would have to modify the main API functions to do the looping.

szhan commented 3 weeks ago

Also, do we want the argument `scale_mutation_rate` here? In this function, it only determines which warnings are issued.

astheeggeggs commented 3 weeks ago

Yeah, it's not, but it's trivial to do so. That was also the reasoning for why the query sequence was stored as a matrix rather than a vector.
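A minimal sketch of the looping discussed above: a wrapper that slices the (k, m) query matrix one row at a time and hands each (1, m) slice to a single-query routine. `run_over_queries` and `run_single` are hypothetical placeholders, not functions in lshmm; in practice `run_single` would be something like a forward or Viterbi pass.

```python
import numpy as np


def run_over_queries(reference_panel, query, run_single):
    """Apply a single-query function to each row of a (k, m) query matrix."""
    results = []
    for i in range(query.shape[0]):
        # Slice with i:i + 1 rather than i so each query keeps shape (1, m).
        results.append(run_single(reference_panel, query[i:i + 1, :]))
    return results


reference_panel = np.zeros((4, 3), dtype=int)   # m = 4 sites, n = 3 haplotypes
query = np.array([[0, 1, 0, 1], [1, 1, 1, 1]])  # k = 2 queries
shapes = run_over_queries(reference_panel, query, lambda ref, q: q.shape)
print(shapes)  # [(1, 4), (1, 4)]
```

Keeping the (1, m) shape for each slice means the single-query routine can reuse the same (k, m) input checks unchanged.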

szhan commented 3 weeks ago

Yes, it should be addressed in a separate issue though.