Status: Closed (stephen-huan closed this pull request 2 months ago)
poetry run pre-commit run --all-files --show-diff-on-failure
Fix some incorrect type annotations causing errors when running with the most recent version of beartype (0.18.5).
ImportError while loading conftest '/.../GPJax/tests/conftest.py'. tests/conftest.py:8: in <module> import gpjax # noqa: F401 /nix/store/kf64ch0bfisffrcy9w44dd2z3c6b7cq9-python3.11-jaxtyping-0.2.28/lib/python3.11/site-packages/jaxtyping/_import_hook.py:223: in exec_module return super().exec_module(module) gpjax/__init__.py:15: in <module> from gpjax import ( /nix/store/kf64ch0bfisffrcy9w44dd2z3c6b7cq9-python3.11-jaxtyping-0.2.28/lib/python3.11/site-packages/jaxtyping/_import_hook.py:223: in exec_module return super().exec_module(module) gpjax/base/__init__.py:16: in <module> from gpjax.base.module import ( /nix/store/kf64ch0bfisffrcy9w44dd2z3c6b7cq9-python3.11-jaxtyping-0.2.28/lib/python3.11/site-packages/jaxtyping/_import_hook.py:223: in exec_module return super().exec_module(module) gpjax/base/module.py:1: in <module> ??? /nix/store/kf64ch0bfisffrcy9w44dd2z3c6b7cq9-python3.11-jaxtyping-0.2.28/lib/python3.11/site-packages/jaxtyping/_decorator.py:397: in jaxtyped full_fn = typechecker(full_fn) /nix/store/anb9by7z6p5fmpf2wd47yc5wjv8yf1l9-python3.11-beartype-0.18.5/lib/python3.11/site-packages/beartype/_decor/decorcache.py:77: in beartype return beartype_object(obj, conf) /nix/store/anb9by7z6p5fmpf2wd47yc5wjv8yf1l9-python3.11-beartype-0.18.5/lib/python3.11/site-packages/beartype/_decor/decorcore.py:87: in beartype_object _beartype_object_fatal(obj, conf=conf, **kwargs) /nix/store/anb9by7z6p5fmpf2wd47yc5wjv8yf1l9-python3.11-beartype-0.18.5/lib/python3.11/site-packages/beartype/_decor/decorcore.py:136: in _beartype_object_fatal beartype_nontype(obj, **kwargs) # type: ignore[return-value] /nix/store/anb9by7z6p5fmpf2wd47yc5wjv8yf1l9-python3.11-beartype-0.18.5/lib/python3.11/site-packages/beartype/_decor/_decornontype.py:182: in beartype_nontype return beartype_func(obj, **kwargs) # type: ignore[return-value] /nix/store/anb9by7z6p5fmpf2wd47yc5wjv8yf1l9-python3.11-beartype-0.18.5/lib/python3.11/site-packages/beartype/_decor/_decornontype.py:247: in beartype_func 
func_wrapper_code = generate_code(bear_call) /nix/store/anb9by7z6p5fmpf2wd47yc5wjv8yf1l9-python3.11-beartype-0.18.5/lib/python3.11/site-packages/beartype/_decor/wrap/wrapmain.py:122: in generate_code code_check_return = _code_check_return(bear_call) /nix/store/anb9by7z6p5fmpf2wd47yc5wjv8yf1l9-python3.11-beartype-0.18.5/lib/python3.11/site-packages/beartype/_decor/wrap/_wrapreturn.py:237: in code_check_return reraise_exception_placeholder( /nix/store/anb9by7z6p5fmpf2wd47yc5wjv8yf1l9-python3.11-beartype-0.18.5/lib/python3.11/site-packages/beartype/_util/error/utilerrraise.py:138: in reraise_exception_placeholder raise exception.with_traceback(exception.__traceback__) /nix/store/anb9by7z6p5fmpf2wd47yc5wjv8yf1l9-python3.11-beartype-0.18.5/lib/python3.11/site-packages/beartype/_decor/wrap/_wrapreturn.py:174: in code_check_return ) = make_code_raiser_func_pith_check( # type: ignore[assignment] /nix/store/anb9by7z6p5fmpf2wd47yc5wjv8yf1l9-python3.11-beartype-0.18.5/lib/python3.11/site-packages/beartype/_util/cache/utilcachecall.py:250: in _callable_cached raise exception /nix/store/anb9by7z6p5fmpf2wd47yc5wjv8yf1l9-python3.11-beartype-0.18.5/lib/python3.11/site-packages/beartype/_util/cache/utilcachecall.py:242: in _callable_cached return_value = args_flat_to_return_value[args_flat] = func( /nix/store/anb9by7z6p5fmpf2wd47yc5wjv8yf1l9-python3.11-beartype-0.18.5/lib/python3.11/site-packages/beartype/_check/checkmake.py:311: in make_code_raiser_func_pith_check ) = make_check_expr(hint, conf, cls_stack) /nix/store/anb9by7z6p5fmpf2wd47yc5wjv8yf1l9-python3.11-beartype-0.18.5/lib/python3.11/site-packages/beartype/_util/cache/utilcachecall.py:250: in _callable_cached raise exception /nix/store/anb9by7z6p5fmpf2wd47yc5wjv8yf1l9-python3.11-beartype-0.18.5/lib/python3.11/site-packages/beartype/_util/cache/utilcachecall.py:242: in _callable_cached return_value = args_flat_to_return_value[args_flat] = func( 
/nix/store/anb9by7z6p5fmpf2wd47yc5wjv8yf1l9-python3.11-beartype-0.18.5/lib/python3.11/site-packages/beartype/_check/code/codemake.py:1578: in make_check_expr hint_childs = get_hint_pep484585_args( # type: ignore[assignment] /nix/store/anb9by7z6p5fmpf2wd47yc5wjv8yf1l9-python3.11-beartype-0.18.5/lib/python3.11/site-packages/beartype/_util/hint/pep/proposal/pep484585/utilpep484585.py:158: in get_hint_pep484585_args raise BeartypeDecorHintPep585Exception( E beartype.roar.BeartypeDecorHintPep585Exception: Class method gpjax.base.module.check_return() return PEP 585 type hint dict[str] not subscripted (indexed) by 2 arguments (i.e., subscripted by 1 != 2 arguments).
=================================== FAILURES =================================== _________ test_thompson_sampling_non_conjugate_posterior_raises_error __________ args = (ThompsonSampling(num_features=100),) kwargs = {'datasets': {'OBJECTIVE': - Number of observations: 10 - Input dimension: 1}, 'key': Array((), dtype=key<fry>) overla... [-0.35197108], [-0.46512772], [ 0.11289125]], dtype=float64), key=Array([ 0, 42], dtype=uint32))}} bound = <BoundArguments (self=ThompsonSampling(num_features=100), posteriors={'OBJECTIVE': NonConjugatePosterior(prior=Prior(k...s={'OBJECTIVE': - Number of observations: 10 - Input dimension: 1}, key=Array((), dtype=key<fry>) overlaying: [ 0 42])> memos = ({}, {}, {}, {'datasets': {'OBJECTIVE': - Number of observations: 10 - Input dimension: 1}, 'key': Array((), dtype=key... [ 0.11289125]], dtype=float64), key=Array([ 0, 42], dtype=uint32))}, 'self': ThompsonSampling(num_features=100)}) argmsg = "\nThe problem arose whilst typechecking parameter 'posteriors'.\nActual value: { 'OBJECTIVE': NonConjugatePosterior(p... key=Array([ 0, 42], dtype=uint32))}\nExpected type: collections.abc.Mapping[str, gpjax.gps.ConjugatePosterior]." name = 'build_utility_function' param_values = "{ 'datasets': {'OBJECTIVE': - Number of observations: 10\n- Input dimension: 1},\n 'key': Array((), dtype=key<fry>) ... 
key=Array([ 0, 42], dtype=uint32))},\n 'self': ThompsonSampling(num_features=100)}" param_hints = "(self, posteriors: collections.abc.Mapping[str, gpjax.gps.ConjugatePosterior], datasets: collections.abc.Mapping[str, gpjax.dataset.Dataset], key: Union[UInt32[Array, '2'], Key[Array, '']])" msg = "Type-check error whilst checking the parameters of build_utility_function.\nThe problem arose whilst typechecking par...or], datasets: collections.abc.Mapping[str, gpjax.dataset.Dataset], key: Union[UInt32[Array, '2'], Key[Array, '']]).\n" @ft.wraps(fn) def wrapped_fn(*args, **kwargs): if config.jaxtyping_disable: return fn(*args, **kwargs) # Raise bind-time errors before we do any shape analysis. (I.e. skip # the pointless jaxtyping information for a non-typechecking failure.) bound = param_signature.bind(*args, **kwargs) memos = push_shape_memo(bound.arguments) try: # First type-check just the parameters before the function is # called. try: > param_fn(*args, **kwargs) E beartype.roar.BeartypeCallHintParamViolation: Method gpjax.decision_making.utility_functions.thompson_sampling.check_params() parameter posteriors={'OBJECTIVE': NonConjugatePosterior(prior=Prior(kernel=RBF(compute_engine=DenseKernelCompu...))} violates type hint collections.abc.Mapping[str, gpjax.gps.ConjugatePosterior], as dict key str 'OBJECTIVE' value <protocol "gpjax.gps.NonConjugatePosterior"> "NonConjugatePosterior(prior=Prior(kernel=RBF(compute_engine=DenseKernelComputation(), act...))" not instance of <protocol "gpjax.gps.ConjugatePosterior">. /nix/store/kf64ch0bfisffrcy9w44dd2z3c6b7cq9-python3.11-jaxtyping-0.2.28/lib/python3.11/site-packages/jaxtyping/_decorator.py:418: BeartypeCallHintParamViolation
Type of changes
Checklist
Ran `poetry run pre-commit run --all-files --show-diff-on-failure` before committing.
Description
Fix some incorrect type annotations causing errors when running with the most recent version of beartype (0.18.5).