alan-turing-institute / AIrsenal

Machine learning Fantasy Premier League team
MIT License

`NumpyroPlayerModel` doesn't work after updating `numpyro` #611

Open · jack89roberts opened this issue 1 year ago

jack89roberts commented 1 year ago

`test_get_fitted_player_model_numpyro` is marked as an xfail (expected failure).

The error is:

./airsenal/tests/test_score_predictions.py::test_get_fitted_player_model_numpyro Failed: RuntimeError: Cannot find valid initial parameters. Please check your model again.
def test_get_fitted_player_model_numpyro():
        pm = NumpyroPlayerModel()
        assert isinstance(pm, NumpyroPlayerModel)
        with test_past_data_session_scope() as ts:
>           fpm = fit_player_data("FWD", "1819", 12, model=pm, dbsession=ts)

airsenal/tests/test_score_predictions.py:269: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
airsenal/framework/prediction_utils.py:525: in fit_player_data
    fitted_model = model.fit(data)
airsenal/framework/player_model.py:177: in fit
    mcmc.run(
.venv/lib/python3.11/site-packages/numpyro/infer/mcmc.py:628: in run
    states_flat, last_state = partial_map_fn(map_args)
.venv/lib/python3.11/site-packages/numpyro/infer/mcmc.py:410: in _single_chain_mcmc
    new_init_state = self.sampler.init(
.venv/lib/python3.11/site-packages/numpyro/infer/hmc.py:713: in init
    init_params = self._init_state(
.venv/lib/python3.11/site-packages/numpyro/infer/hmc.py:657: in _init_state
    ) = initialize_model(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

rng_key = Array([3923418436, 1366451097], dtype=uint32)
model = <numpyro.handlers.substitute object at 0x2880dec50>

    def initialize_model(
        rng_key,
        model,
        *,
        init_strategy=init_to_uniform,
        dynamic_args=False,
        model_args=(),
        model_kwargs=None,
        forward_mode_differentiation=False,
        validate_grad=True,
    ):
        """
        (EXPERIMENTAL INTERFACE) Helper function that calls :func:`~numpyro.infer.util.get_potential_fn`
        and :func:`~numpyro.infer.util.find_valid_initial_params` under the hood
        to return a tuple of (`init_params_info`, `potential_fn`, `postprocess_fn`, `model_trace`).

        :param jax.random.PRNGKey rng_key: random number generator seed to
            sample from the prior. The returned `init_params` will have the
            batch shape ``rng_key.shape[:-1]``.
        :param model: Python callable containing Pyro primitives.
        :param callable init_strategy: a per-site initialization function.
            See :ref:`init_strategy` section for available functions.
        :param bool dynamic_args: if `True`, the `potential_fn` and
            `constraints_fn` are themselves dependent on model arguments.
            When provided a `*model_args, **model_kwargs`, they return
            `potential_fn` and `constraints_fn` callables, respectively.
        :param tuple model_args: args provided to the model.
        :param dict model_kwargs: kwargs provided to the model.
        :param bool forward_mode_differentiation: whether to use forward-mode differentiation
            or reverse-mode differentiation. By default, we use reverse mode but the forward
            mode can be useful in some cases to improve the performance. In addition, some
            control flow utility on JAX such as `jax.lax.while_loop` or `jax.lax.fori_loop`
            only supports forward-mode differentiation. See
            `JAX's The Autodiff Cookbook <https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html>`_
            for more information.
        :param bool validate_grad: whether to validate gradient of the initial params.
            Defaults to True.
        :return: a namedtupe `ModelInfo` which contains the fields
            (`param_info`, `potential_fn`, `postprocess_fn`, `model_trace`), where
            `param_info` is a namedtuple `ParamInfo` containing values from the prior
            used to initiate MCMC, their corresponding potential energy, and their gradients;
            `postprocess_fn` is a callable that uses inverse transforms
            to convert unconstrained HMC samples to constrained values that
            lie within the site's support, in addition to returning values
            at `deterministic` sites in the model.
        """
        model_kwargs = {} if model_kwargs is None else model_kwargs
        substituted_model = substitute(
            seed(model, rng_key if jnp.ndim(rng_key) == 1 else rng_key[0]),
            substitute_fn=init_strategy,
        )
        (
            inv_transforms,
            replay_model,
            has_enumerate_support,
            model_trace,
        ) = _get_model_transforms(substituted_model, model_args, model_kwargs)
        # substitute param sites from model_trace to model so
        # we don't need to generate again parameters of `numpyro.module`
        model = substitute(
            model,
            data={
                k: site["value"]
                for k, site in model_trace.items()
                if site["type"] in ["param"]
            },
        )
        constrained_values = {
            k: v["value"]
            for k, v in model_trace.items()
            if v["type"] == "sample"
            and not v["is_observed"]
            and not v["fn"].support.is_discrete
        }

        if has_enumerate_support:
            from numpyro.contrib.funsor import config_enumerate, enum

            if not isinstance(model, enum):
                max_plate_nesting = _guess_max_plate_nesting(model_trace)
                _validate_model(model_trace, plate_warning="error")
                model = enum(config_enumerate(model), -max_plate_nesting - 1)
        else:
            _validate_model(model_trace, plate_warning="loose")

        potential_fn, postprocess_fn = get_potential_fn(
            model,
            inv_transforms,
            replay_model=replay_model,
            enum=has_enumerate_support,
            dynamic_args=dynamic_args,
            model_args=model_args,
            model_kwargs=model_kwargs,
        )

        init_strategy = (
            init_strategy if isinstance(init_strategy, partial) else init_strategy()
        )
        if (init_strategy.func is init_to_value) and not replay_model:
            init_values = init_strategy.keywords.get("values")
            unconstrained_values = transform_fn(inv_transforms, init_values, invert=True)
            init_strategy = _init_to_unconstrained_value(values=unconstrained_values)
        prototype_params = transform_fn(inv_transforms, constrained_values, invert=True)
        (init_params, pe, grad), is_valid = find_valid_initial_params(
            rng_key,
            substitute(
                model,
                data={
                    k: site["value"]
                    for k, site in model_trace.items()
                    if site["type"] in ["plate"]
                },
            ),
            init_strategy=init_strategy,
            enum=has_enumerate_support,
            model_args=model_args,
            model_kwargs=model_kwargs,
            prototype_params=prototype_params,
            forward_mode_differentiation=forward_mode_differentiation,
            validate_grad=validate_grad,
        )

        if not_jax_tracer(is_valid):
            if device_get(~jnp.all(is_valid)):
                with numpyro.validation_enabled(), trace() as tr:
                    # validate parameters
                    substituted_model(*model_args, **model_kwargs)
                    # validate values
                    for site in tr.values():
                        if site["type"] == "sample":
                            with warnings.catch_warnings(record=True) as ws:
                                site["fn"]._validate_sample(site["value"])
                            if len(ws) > 0:
                                for w in ws:
                                    # at site information to the warning message
                                    w.message.args = (
                                        "Site {}: {}".format(
                                            site["name"], w.message.args[0]
                                        ),
                                    ) + w.message.args[1:]
                                    warnings.showwarning(
                                        w.message,
                                        w.category,
                                        w.filename,
                                        w.lineno,
                                        file=w.file,
                                        line=w.line,
                                    )
>               raise RuntimeError(
                    "Cannot find valid initial parameters. Please check your model again."
                )
E               RuntimeError: Cannot find valid initial parameters. Please check your model again.

.venv/lib/python3.11/site-packages/numpyro/infer/util.py:745: RuntimeError
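The retry-and-reject logic behind this error can be sketched in a few lines of plain Python. This is a toy, not numpyro's implementation: it assumes a single Poisson rate parameter, and `find_valid_initial_param`, `potential_energy` and `poisson_log_prob` are hypothetical helpers. It only mirrors the idea visible in the traceback, where `find_valid_initial_params` has to produce an initial point whose potential energy is finite (numpyro additionally checks the gradient when `validate_grad=True`, which the toy skips) and `initialize_model` raises when no attempt qualifies. Initial values are drawn uniformly in `(-2, 2)` on the unconstrained (log) scale, loosely following the default `init_to_uniform` strategy. The point it illustrates: if any observed value lies outside the likelihood's support, the log-density is `-inf` for every parameter value, so re-drawing the initial point can never succeed.

```python
import math
import random


def poisson_log_prob(k, rate):
    """Log pmf of Poisson(rate); -inf when k is outside the support {0, 1, 2, ...}."""
    if k < 0 or k != int(k):
        return -math.inf
    return k * math.log(rate) - rate - math.lgamma(k + 1)


def potential_energy(log_rate, data):
    """Negative log-likelihood, written in terms of the unconstrained (log) parameter."""
    rate = math.exp(log_rate)
    return -sum(poisson_log_prob(k, rate) for k in data)


def find_valid_initial_param(data, max_tries=100, radius=2.0, seed=0):
    """Toy analogue of the init search: draw candidates uniformly in
    (-radius, radius) on the log scale and accept the first one whose
    potential energy is finite; give up after max_tries attempts."""
    rng = random.Random(seed)
    for _ in range(max_tries):
        log_rate = rng.uniform(-radius, radius)
        if math.isfinite(potential_energy(log_rate, data)):
            return log_rate
    raise RuntimeError("Cannot find valid initial parameters.")
```

Called with `[0, 1, 3, 2]`, the search accepts its first candidate. Called with `[0, 1, -1]`, every candidate has infinite potential energy because of the single negative count, so it raises `RuntimeError` whatever the seed or number of tries. The fix in that case is to the data or the likelihood, not to the initialization strategy.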
jack89roberts commented 1 year ago

The test also emits this warning:

airsenal/tests/test_score_predictions.py::test_get_fitted_player_model_numpyro
  /Users/jroberts/GitHub/AIrsenal/airsenal/framework/player_model.py:177: UserWarning: Site obs: Out-of-support values provided to log prob method. The value argument should be within the support.
    mcmc.run(
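If, as the warning suggests, the problem lies in the values passed to the `obs` site, it can help to locate the offending rows before calling `fit`. The sketch below is hypothetical: `out_of_support_rows` is not part of AIrsenal or numpyro, and it assumes the observations are per-player count vectors with a per-row total that a multinomial-style likelihood requires them to sum to. That data layout is an assumption for illustration only, so the check should be adapted to whatever distribution `NumpyroPlayerModel` actually uses for `obs`.

```python
import math
from typing import List, Sequence


def out_of_support_rows(
    counts: Sequence[Sequence[float]], totals: Sequence[float]
) -> List[int]:
    """Return the indices of rows a multinomial-style likelihood would reject.

    A row is treated as in-support only if every entry is a finite, non-negative
    whole number and the entries sum to that row's total count.
    """
    if len(counts) != len(totals):
        raise ValueError("counts and totals must have the same length")
    bad = []
    for i, (row, total) in enumerate(zip(counts, totals)):
        valid_entries = all(
            math.isfinite(c) and c >= 0 and c == int(c) for c in row
        )
        if not valid_entries or sum(row) != total:
            bad.append(i)
    return bad


# Hypothetical example data: row 2 sums to 4 instead of 3,
# and row 3 contains a negative entry.
y = [[1, 0, 2], [0, 1, 1], [2, 2, 0], [0, -1, 1]]
n = [3, 2, 3, 0]
print(out_of_support_rows(y, n))  # [2, 3]
```

An empty result means the observed data satisfy this assumed support, in which case the warning would more likely come from how the model builds the distribution's parameters (for example, the total count or probability vector it passes to the likelihood).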