AlexAndorra opened 1 week ago
@AlexAndorra I think this is already handled? I have not looked deeply into the details, but have a look at this example:
https://bambinos.github.io/bambi/notebooks/alternative_samplers.html#blackjax
Yeah I saw that @tomicapretto , but it doesn't seem to work:
import bambi as bmb

data = bmb.load_data("sleepstudy")
model = bmb.Model('Reaction ~ Days', data)
kwargs = {
"draws": 40,
"chains": 2,
"cores": 3,
}
results = model.fit(inference_method="nutpie", **kwargs)
results.posterior
will still give you 8 chains and 1000 draws
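For reference, a minimal check on the result above (a sketch; the sizes are the ones reported in this comment):
# Requested 2 chains and 40 draws, but the defaults are used instead.
print(results.posterior.sizes)  # chain: 8, draw: 1000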
Interestingly, the blackjax nuts example from the NB errors out:
ValueError: not enough values to unpack (expected 2, got 1)
There is currently a problem with the dependencies. I just pinned them in a separate PR because they were being problematic. I think we need a new release of bayeux. Unfortunately, I'm not familiar enough with it to work on it.
These are the dependencies I've pinned
Interesting, I'll double check what's happening
Ooooh, I definitely need to do that on my branch then! Shouldn't we merge that into main while it's an issue?
Maybe @ColCarroll can help with the Bayeux release?
I think this is ignored silently by Bayeux because it's not passed explicitly. The changes I've made in this PR solve it, but they may not cover all the cases. They will, though, once I add Colin's suggestion from above.
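A rough sketch of the idea in plain Python (illustrative only, not the actual diff in this PR): merge the user-supplied values over the sampler defaults so they cannot be dropped silently.
# Illustrative only: override the sampler defaults with whatever the user passed,
# instead of relying on Bayeux to pick the values up implicitly.
sampler_defaults = {"draws": 1000, "tune": 300, "chains": 8, "cores": 8}
user_kwargs = {"draws": 40, "chains": 2, "cores": 3}
merged = {**sampler_defaults, **user_kwargs}
print(merged)  # {'draws': 40, 'tune': 300, 'chains': 2, 'cores': 3}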
@AlexAndorra @GStechschulte
Is it possible the difference is related to where those parameters are stored? See the difference between the parameters for BlackJAX NUTS and Nutpie
bmb.inference_methods.get_kwargs("blackjax_nuts")
{<function blackjax.adaptation.window_adaptation.window_adaptation(algorithm, logdensity_fn: Callable, is_mass_matrix_diagonal: bool = True, initial_step_size: float = 1.0, target_acceptance_rate: float = 0.8, progress_bar: bool = False, adaptation_info_fn: Callable = <function return_all_adapt_info at 0x7f6b3ea3c400>, integrator=<function generate_euclidean_integrator.<locals>.euclidean_integrator at 0x7f6b3ea044a0>, **extra_parameters) -> blackjax.base.AdaptationAlgorithm>: {'logdensity_fn': <function bayeux._src.shared.constrain.<locals>.wrap_log_density.<locals>.wrapped(args)>,
'is_mass_matrix_diagonal': True,
'initial_step_size': 1.0,
'target_acceptance_rate': 0.8,
'progress_bar': False,
'adaptation_info_fn': <function blackjax.adaptation.base.return_all_adapt_info(state, info, adaptation_state)>,
'algorithm': GenerateSamplingAPI(differentiable=<function as_top_level_api at 0x7f6b3ea1e5c0>, init=<function init at 0x7f6b3e9eb1a0>, build_kernel=<function build_kernel at 0x7f6b3ebaac00>)},
'adapt.run': {'num_steps': 500},
<function blackjax.mcmc.nuts.as_top_level_api(logdensity_fn: Callable, step_size: float, inverse_mass_matrix: Union[blackjax.mcmc.metrics.Metric, jax.Array, Callable[[Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex, Iterable[ForwardRef('ArrayLikeTree')], Mapping[Any, ForwardRef('ArrayLikeTree')]]], jax.Array]], *, max_num_doublings: int = 10, divergence_threshold: int = 1000, integrator: Callable = <function generate_euclidean_integrator.<locals>.euclidean_integrator at 0x7f6b3ea044a0>) -> blackjax.base.SamplingAlgorithm>: {'max_num_doublings': 10,
'divergence_threshold': 1000,
'integrator': <function blackjax.mcmc.integrators.generate_euclidean_integrator.<locals>.euclidean_integrator(logdensity_fn: Callable, kinetic_energy_fn: blackjax.mcmc.metrics.KineticEnergy) -> Callable[[blackjax.mcmc.integrators.IntegratorState, float], blackjax.mcmc.integrators.IntegratorState]>,
'logdensity_fn': <function bayeux._src.shared.constrain.<locals>.wrap_log_density.<locals>.wrapped(args)>,
'step_size': 0.5},
'extra_parameters': {'chain_method': 'vectorized',
'num_chains': 8,
'num_draws': 500,
'num_adapt_draws': 500,
'return_pytree': False}}
The parameters num_chains, num_draws, etc. are part of the "extra_parameters" element.
While in nutpie, see:
bmb.inference_methods.get_kwargs("nutpie")
{<function nutpie.compiled_pyfunc.from_pyfunc(ndim: int, make_logp_fn: Callable, make_expand_fn: Callable, expanded_dtypes: list[numpy.dtype], expanded_shapes: list[tuple[int, ...]], expanded_names: list[str], *, initial_mean: numpy.ndarray | None = None, coords: dict[str, typing.Any] | None = None, dims: dict[str, tuple[str, ...]] | None = None, shared_data: dict[str, typing.Any] | None = None)>: {'ndim': 1,
'make_logp_fn': <function bayeux._src.mcmc.nutpie._NutpieSampler._get_aux.<locals>.make_logp_fn()>,
'make_expand_fn': <function bayeux._src.mcmc.nutpie._NutpieSampler.get_kwargs.<locals>.make_expand_fn(*args, **kwargs)>,
'expanded_shapes': [(1,)],
'expanded_names': ['x'],
'expanded_dtypes': [numpy.float64]},
<function nutpie.sample.sample(compiled_model: nutpie.sample.CompiledModel, *, draws: int = 1000, tune: int = 300, chains: int = 6, cores: Optional[int] = None, seed: Optional[int] = None, save_warmup: bool = True, progress_bar: bool = True, low_rank_modified_mass_matrix: bool = False, init_mean: Optional[numpy.ndarray] = None, return_raw_trace: bool = False, blocking: bool = True, progress_template: Optional[str] = None, progress_style: Optional[str] = None, progress_rate: int = 100, **kwargs) -> arviz.data.inference_data.InferenceData>: {'draws': 1000,
'tune': 300,
'chains': 8,
'cores': 8,
'seed': None,
'save_warmup': True,
'progress_bar': True,
'low_rank_modified_mass_matrix': False,
'init_mean': None,
'return_raw_trace': False,
'blocking': True,
'progress_template': None,
'progress_style': None,
'progress_rate': 100},
'extra_parameters': {'flatten': <function bayeux._src.mcmc.nutpie._NutpieSampler._get_aux.<locals>.flatten(pytree)>,
'unflatten': <jax._src.util.HashablePartial at 0x7f1ee36cc1d0>,
'return_pytree': False}}
They are parameters of the nutpie.sample.sample() function.
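A quick way to see where each group of defaults lives, based on the output above (a sketch; the dictionary keys are function objects plus the "extra_parameters" string):
nutpie_kwargs = bmb.inference_methods.get_kwargs("nutpie")
for key, value in nutpie_kwargs.items():
    # Function keys have a __name__; the "extra_parameters" key is a plain string.
    print(getattr(key, "__name__", key), "->", list(value))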
I think so @tomicapretto, because then Bambi has to pass them explicitly to bx.sample when using nutpie. Do you confirm it works with BlackJAX once dependencies are pinned?
Just found a counterexample. Something similar happens with "numpyro_nuts", and it passes the kwargs correctly:
bmb.inference_methods.get_kwargs("numpyro_nuts")
{numpyro.infer.hmc.NUTS: {'model': None,
'kinetic_fn': None,
'step_size': 1.0,
'inverse_mass_matrix': None,
'adapt_step_size': True,
'adapt_mass_matrix': True,
'dense_mass': False,
'target_accept_prob': 0.8,
'trajectory_length': None,
'max_tree_depth': 10,
'init_strategy': <function numpyro.infer.initialization.init_to_uniform(site=None, radius=2)>,
'find_heuristic_step_size': False,
'forward_mode_differentiation': False,
'regularize_mass_matrix': True},
numpyro.infer.mcmc.MCMC: {'num_chains': 8,
'thinning': 1,
'postprocess_fn': None,
'chain_method': 'vectorized',
'progress_bar': True,
'jit_model_args': False,
'num_warmup': 500,
'num_samples': 1000},
'extra_parameters': {'return_pytree': False}}
kwargs = {
"num_chains": 2,
"num_samples": 250,
"num_adapt_draws": 250
}
blackjax_nuts_idata = model.fit(inference_method="numpyro_nuts", **kwargs)
blackjax_nuts_idata.posterior
I resolved my previous two comments because I realized they were not correct. The reason kwargs does not contain chains, draws, tune, and cores is that they are also keyword arguments of the Model.fit() method, so Python does not include them in the kwargs dictionary.
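A minimal illustration of that Python behavior, independent of Bambi:
def fit(draws=1000, chains=4, **kwargs):
    # Named parameters are bound directly and never end up in **kwargs.
    print(draws, chains, kwargs)

fit(draws=40, chains=2, cores=3)  # prints: 40 2 {'cores': 3}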
I see the following alternatives:

1. Explicitly pass draws, tune, chains, and cores to the sampling function. If bayeux just ignores them when they are not expected by the underlying sampler, then it's all good.
2. Modify Model.fit() to accept the arguments passed to PyMC and other samplers as kwargs. The downside I see here is lack of autocomplete.

I am late to the party @tomicapretto. Regarding nutpie, the kwargs returned are not "pretty":
bmb.inference_methods.get_kwargs("nutpie")
{<function nutpie.compiled_pyfunc.from_pyfunc(ndim: int, make_logp_fn: Callable, make_expand_fn: Callable, expanded_dtypes: list[numpy.dtype], expanded_shapes: list[tuple[int, ...]], expanded_names: list[str], *, initial_mean: numpy.ndarray | None = None, coords: dict[str, typing.Any] | None = None, dims: dict[str, tuple[str, ...]] | None = None, shared_data: dict[str, typing.Any] | None = None)>: {'ndim': 1,
'make_logp_fn': <function bayeux._src.mcmc.nutpie._NutpieSampler._get_aux.<locals>.make_logp_fn()>,
'make_expand_fn': <function bayeux._src.mcmc.nutpie._NutpieSampler.get_kwargs.<locals>.make_expand_fn(*args, **kwargs)>,
'expanded_shapes': [(1,)],
'expanded_names': ['x'],
'expanded_dtypes': [numpy.float64]},
<function nutpie.sample.sample(compiled_model: nutpie.sample.CompiledModel, *, draws: int = 1000, tune: int = 300, chains: int = 6, cores: Optional[int] = None, seed: Optional[int] = None, save_warmup: bool = True, progress_bar: bool = True, low_rank_modified_mass_matrix: bool = False, init_mean: Optional[numpy.ndarray] = None, return_raw_trace: bool = False, blocking: bool = True, progress_template: Optional[str] = None, progress_style: Optional[str] = None, progress_rate: int = 100, **kwargs) -> arviz.data.inference_data.InferenceData>: {'draws': 1000,
'tune': 300,
'chains': 8,
'cores': 8,
'seed': None,
'save_warmup': True,
'progress_bar': True,
'low_rank_modified_mass_matrix': False,
'init_mean': None,
'return_raw_trace': False,
'blocking': True,
'progress_template': None,
'progress_style': None,
'progress_rate': 100},
'extra_parameters': {'flatten': <function bayeux._src.mcmc.nutpie._NutpieSampler._get_aux.<locals>.flatten(pytree)>,
'unflatten': <jax._src.util.HashablePartial at 0x3299964e0>,
'return_pytree': False}}
It is a nested dictionary where the keys are objects. A nested dictionary isn't a problem per se, e.g. when passing nested args to Blackjax NUTS:
kwargs = {
"adapt.run": {"num_steps": 500},
"num_chains": 4,
"num_draws": 250,
"num_adapt_draws": 250
}
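For context, those kwargs would then be forwarded the same way as in the numpyro example above (a sketch, reusing the sleepstudy model defined earlier in the thread):
idata = model.fit(inference_method="blackjax_nuts", **kwargs)
idata.posterior  # should report 4 chains and 250 draws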
Not sure what you mean by "it works". If it is about passing arguments to BlackJAX, that has always worked. If it is about the tests, yes, they work now after pinning the deps.
@GStechschulte I think that is not the problem (see https://github.com/bambinos/bambi/pull/855#issuecomment-2466301370 and my resolved comments)
I meant passing arguments to Blackjax, because it currently errors out as I showed above. This is solved by pinning Bayeux, IIUC.
For nutpie, I think we can merge these changes (once I add those other cases), but of course I'll let you decide
@AlexAndorra I'm going to incorporate your changes, I just modified things a bit. The environment file goes under a conda-envs directory. I'm doing the same thing I saw in PyMC: https://github.com/pymc-devs/pymc/tree/main/conda-envs.
Could you install from your branch and try to run nutpie passing those kwargs? If tests are OK, and what you run works, then I'll merge.
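For what it's worth, the check could just be the earlier reproducer rerun from this branch (a sketch; after the change, the requested dimensions should be honored):
results = model.fit(inference_method="nutpie", draws=40, chains=2, cores=3)
assert results.posterior.sizes["chain"] == 2
assert results.posterior.sizes["draw"] == 40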
Currently, some kwargs are not passed to Bayeux when fitting the model. This PR makes sure they are. The only change is on line 270 of bambi/backend/pymc.py -- the rest is only formatting. Also added an environment file for creation with mamba. Ready for review!