JaxGaussianProcesses / GPJax

Gaussian processes in JAX.
https://docs.jaxgaussianprocesses.com/
Apache License 2.0
436 stars, 51 forks (source link)

Fix some type annotations causing failing tests #456

Closed: stephen-huan closed this pull request 2 months ago.

stephen-huan commented 2 months ago

Type of changes

Checklist

Description

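Fix incorrect type annotations that cause errors when running with the most recent version of beartype (0.18.5). The two failures are shown below.

1. **Import-time error.** A return type in `gpjax.base.module` is annotated as `dict[str]`. Beartype rejects this because a PEP 585 `dict` hint must be subscripted by exactly two arguments (key and value types).
2. **Test failure.** `ThompsonSampling.build_utility_function` annotates `posteriors` as `Mapping[str, ConjugatePosterior]`. The test `test_thompson_sampling_non_conjugate_posterior_raises_error` passes a `NonConjugatePosterior`. As a result, beartype raises `BeartypeCallHintParamViolation` during the parameter check, before the method body runs.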

ImportError while loading conftest '/.../GPJax/tests/conftest.py'.
tests/conftest.py:8: in <module>
    import gpjax  # noqa: F401
/nix/store/kf64ch0bfisffrcy9w44dd2z3c6b7cq9-python3.11-jaxtyping-0.2.28/lib/python3.11/site-packages/jaxtyping/_import_hook.py:223: in exec_module
    return super().exec_module(module)
gpjax/__init__.py:15: in <module>
    from gpjax import (
/nix/store/kf64ch0bfisffrcy9w44dd2z3c6b7cq9-python3.11-jaxtyping-0.2.28/lib/python3.11/site-packages/jaxtyping/_import_hook.py:223: in exec_module
    return super().exec_module(module)
gpjax/base/__init__.py:16: in <module>
    from gpjax.base.module import (
/nix/store/kf64ch0bfisffrcy9w44dd2z3c6b7cq9-python3.11-jaxtyping-0.2.28/lib/python3.11/site-packages/jaxtyping/_import_hook.py:223: in exec_module
    return super().exec_module(module)
gpjax/base/module.py:1: in <module>
    ???
/nix/store/kf64ch0bfisffrcy9w44dd2z3c6b7cq9-python3.11-jaxtyping-0.2.28/lib/python3.11/site-packages/jaxtyping/_decorator.py:397: in jaxtyped
    full_fn = typechecker(full_fn)
/nix/store/anb9by7z6p5fmpf2wd47yc5wjv8yf1l9-python3.11-beartype-0.18.5/lib/python3.11/site-packages/beartype/_decor/decorcache.py:77: in beartype
    return beartype_object(obj, conf)
/nix/store/anb9by7z6p5fmpf2wd47yc5wjv8yf1l9-python3.11-beartype-0.18.5/lib/python3.11/site-packages/beartype/_decor/decorcore.py:87: in beartype_object
    _beartype_object_fatal(obj, conf=conf, **kwargs)
/nix/store/anb9by7z6p5fmpf2wd47yc5wjv8yf1l9-python3.11-beartype-0.18.5/lib/python3.11/site-packages/beartype/_decor/decorcore.py:136: in _beartype_object_fatal
    beartype_nontype(obj, **kwargs)  # type: ignore[return-value]
/nix/store/anb9by7z6p5fmpf2wd47yc5wjv8yf1l9-python3.11-beartype-0.18.5/lib/python3.11/site-packages/beartype/_decor/_decornontype.py:182: in beartype_nontype
    return beartype_func(obj, **kwargs)  # type: ignore[return-value]
/nix/store/anb9by7z6p5fmpf2wd47yc5wjv8yf1l9-python3.11-beartype-0.18.5/lib/python3.11/site-packages/beartype/_decor/_decornontype.py:247: in beartype_func
    func_wrapper_code = generate_code(bear_call)
/nix/store/anb9by7z6p5fmpf2wd47yc5wjv8yf1l9-python3.11-beartype-0.18.5/lib/python3.11/site-packages/beartype/_decor/wrap/wrapmain.py:122: in generate_code
    code_check_return = _code_check_return(bear_call)
/nix/store/anb9by7z6p5fmpf2wd47yc5wjv8yf1l9-python3.11-beartype-0.18.5/lib/python3.11/site-packages/beartype/_decor/wrap/_wrapreturn.py:237: in code_check_return
    reraise_exception_placeholder(
/nix/store/anb9by7z6p5fmpf2wd47yc5wjv8yf1l9-python3.11-beartype-0.18.5/lib/python3.11/site-packages/beartype/_util/error/utilerrraise.py:138: in reraise_exception_placeholder
    raise exception.with_traceback(exception.__traceback__)
/nix/store/anb9by7z6p5fmpf2wd47yc5wjv8yf1l9-python3.11-beartype-0.18.5/lib/python3.11/site-packages/beartype/_decor/wrap/_wrapreturn.py:174: in code_check_return
    ) = make_code_raiser_func_pith_check(  # type: ignore[assignment]
/nix/store/anb9by7z6p5fmpf2wd47yc5wjv8yf1l9-python3.11-beartype-0.18.5/lib/python3.11/site-packages/beartype/_util/cache/utilcachecall.py:250: in _callable_cached
    raise exception
/nix/store/anb9by7z6p5fmpf2wd47yc5wjv8yf1l9-python3.11-beartype-0.18.5/lib/python3.11/site-packages/beartype/_util/cache/utilcachecall.py:242: in _callable_cached
    return_value = args_flat_to_return_value[args_flat] = func(
/nix/store/anb9by7z6p5fmpf2wd47yc5wjv8yf1l9-python3.11-beartype-0.18.5/lib/python3.11/site-packages/beartype/_check/checkmake.py:311: in make_code_raiser_func_pith_check
    ) = make_check_expr(hint, conf, cls_stack)
/nix/store/anb9by7z6p5fmpf2wd47yc5wjv8yf1l9-python3.11-beartype-0.18.5/lib/python3.11/site-packages/beartype/_util/cache/utilcachecall.py:250: in _callable_cached
    raise exception
/nix/store/anb9by7z6p5fmpf2wd47yc5wjv8yf1l9-python3.11-beartype-0.18.5/lib/python3.11/site-packages/beartype/_util/cache/utilcachecall.py:242: in _callable_cached
    return_value = args_flat_to_return_value[args_flat] = func(
/nix/store/anb9by7z6p5fmpf2wd47yc5wjv8yf1l9-python3.11-beartype-0.18.5/lib/python3.11/site-packages/beartype/_check/code/codemake.py:1578: in make_check_expr
    hint_childs = get_hint_pep484585_args(  # type: ignore[assignment]
/nix/store/anb9by7z6p5fmpf2wd47yc5wjv8yf1l9-python3.11-beartype-0.18.5/lib/python3.11/site-packages/beartype/_util/hint/pep/proposal/pep484585/utilpep484585.py:158: in get_hint_pep484585_args
    raise BeartypeDecorHintPep585Exception(
E   beartype.roar.BeartypeDecorHintPep585Exception: Class method gpjax.base.module.check_return() return PEP 585 type hint dict[str] not subscripted (indexed) by 2 arguments (i.e., subscripted by 1 != 2 arguments).
=================================== FAILURES ===================================
_________ test_thompson_sampling_non_conjugate_posterior_raises_error __________

args = (ThompsonSampling(num_features=100),)
kwargs = {'datasets': {'OBJECTIVE': - Number of observations: 10
- Input dimension: 1}, 'key': Array((), dtype=key<fry>) overla...       [-0.35197108],
       [-0.46512772],
       [ 0.11289125]], dtype=float64), key=Array([ 0, 42], dtype=uint32))}}
bound = <BoundArguments (self=ThompsonSampling(num_features=100), posteriors={'OBJECTIVE': NonConjugatePosterior(prior=Prior(k...s={'OBJECTIVE': - Number of observations: 10
- Input dimension: 1}, key=Array((), dtype=key<fry>) overlaying:
[ 0 42])>
memos = ({}, {}, {}, {'datasets': {'OBJECTIVE': - Number of observations: 10
- Input dimension: 1}, 'key': Array((), dtype=key...      [ 0.11289125]], dtype=float64), key=Array([ 0, 42], dtype=uint32))}, 'self': ThompsonSampling(num_features=100)})
argmsg = "\nThe problem arose whilst typechecking parameter 'posteriors'.\nActual value: { 'OBJECTIVE': NonConjugatePosterior(p...       key=Array([ 0, 42], dtype=uint32))}\nExpected type: collections.abc.Mapping[str, gpjax.gps.ConjugatePosterior]."
name = 'build_utility_function'
param_values = "{ 'datasets': {'OBJECTIVE': - Number of observations: 10\n- Input dimension: 1},\n  'key': Array((), dtype=key<fry>) ...                                   key=Array([ 0, 42], dtype=uint32))},\n  'self': ThompsonSampling(num_features=100)}"
param_hints = "(self, posteriors: collections.abc.Mapping[str, gpjax.gps.ConjugatePosterior], datasets: collections.abc.Mapping[str, gpjax.dataset.Dataset], key: Union[UInt32[Array, '2'], Key[Array, '']])"
msg = "Type-check error whilst checking the parameters of build_utility_function.\nThe problem arose whilst typechecking par...or], datasets: collections.abc.Mapping[str, gpjax.dataset.Dataset], key: Union[UInt32[Array, '2'], Key[Array, '']]).\n"

    @ft.wraps(fn)
    def wrapped_fn(*args, **kwargs):
        if config.jaxtyping_disable:
            return fn(*args, **kwargs)

        # Raise bind-time errors before we do any shape analysis. (I.e. skip
        # the pointless jaxtyping information for a non-typechecking failure.)
        bound = param_signature.bind(*args, **kwargs)

        memos = push_shape_memo(bound.arguments)
        try:
            # First type-check just the parameters before the function is
            # called.
            try:
>               param_fn(*args, **kwargs)
E               beartype.roar.BeartypeCallHintParamViolation: Method gpjax.decision_making.utility_functions.thompson_sampling.check_params() parameter posteriors={'OBJECTIVE': NonConjugatePosterior(prior=Prior(kernel=RBF(compute_engine=DenseKernelCompu...))} violates type hint collections.abc.Mapping[str, gpjax.gps.ConjugatePosterior], as dict key str 'OBJECTIVE' value <protocol "gpjax.gps.NonConjugatePosterior"> "NonConjugatePosterior(prior=Prior(kernel=RBF(compute_engine=DenseKernelComputation(), act...))" not instance of <protocol "gpjax.gps.ConjugatePosterior">.

/nix/store/kf64ch0bfisffrcy9w44dd2z3c6b7cq9-python3.11-jaxtyping-0.2.28/lib/python3.11/site-packages/jaxtyping/_decorator.py:418: BeartypeCallHintParamViolation