bambinos / bambi

BAyesian Model-Building Interface (Bambi) in Python.
https://bambinos.github.io/bambi/

Pass kwargs to nutpie + create env.yml file #855

Open · AlexAndorra opened 1 week ago

AlexAndorra commented 1 week ago

Currently, some kwargs are not passed to Bayeux when fitting the model. This PR makes sure they are. The only change is on line 270 of bambi/backend/pymc.py -- the rest is only formatting.

Also added an env file for creating the environment with mamba. Ready for review!

tomicapretto commented 1 week ago

@AlexAndorra I think this is already handled? I have not looked deeply into the details, but have a look at this example:

https://bambinos.github.io/bambi/notebooks/alternative_samplers.html#blackjax

AlexAndorra commented 1 week ago

Yeah, I saw that @tomicapretto, but it doesn't seem to work:

import bambi as bmb

data = bmb.load_data("sleepstudy")
model = bmb.Model('Reaction ~ Days', data)
kwargs = {
    "draws": 40,
    "chains": 2,
    "cores": 3,
}
results = model.fit(inference_method="nutpie", **kwargs)
results.posterior

will still give you 8 chains and 1000 draws instead of the requested 2 chains and 40 draws.
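A quick way to confirm this (a sketch; results.posterior is an xarray Dataset, so its sizes attribute reports the dimensions):

# With the current behaviour this reports chain=8 and draw=1000,
# not the chain=2 and draw=40 requested above.
print(results.posterior.sizes)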

AlexAndorra commented 1 week ago

Interestingly, the BlackJAX NUTS example from the notebook errors out:

ValueError: not enough values to unpack (expected 2, got 1)

tomicapretto commented 1 week ago

> Interestingly, the BlackJAX NUTS example from the notebook errors out:
>
> ValueError: not enough values to unpack (expected 2, got 1)

There is currently a problem with the dependencies. I just pinned them in a separate PR because they were causing problems. I think we need a new release of bayeux. Unfortunately, I'm not familiar enough with it to work on it.

These are the dependencies I've pinned:

https://github.com/bambinos/bambi/blob/7a18fb9afc5b485dcd95f1a421bbd77586106a2f/pyproject.toml#L41-L46

tomicapretto commented 1 week ago

> Yeah, I saw that @tomicapretto, but it doesn't seem to work:
>
> data = bmb.load_data("sleepstudy")
> model = bmb.Model('Reaction ~ Days', data)
> kwargs = {
>     "draws": 40,
>     "chains": 2,
>     "cores": 3,
> }
> results = model.fit(inference_method="nutpie", **kwargs)
> results.posterior
>
> will still give you 8 chains and 1000 draws instead of the requested 2 chains and 40 draws.

Interesting, I'll double check what's happening

AlexAndorra commented 1 week ago

> There is currently a problem with the dependencies. I just pinned them in a separate PR because they were causing problems. I think we need a new release of bayeux. Unfortunately, I'm not familiar enough with it to work on it.

Ooooh, I definitely need to do that on my branch then! Shouldn't we merge that into main while this is still an issue? Maybe @ColCarroll can help with the Bayeux release?

> Interesting, I'll double check what's happening

I think this is silently ignored by Bayeux because it's not passed explicitly. The changes I've made in this PR solve it, but they may not cover all the cases. They will, though, once I add Colin's suggestion from above.

tomicapretto commented 1 week ago

@AlexAndorra @GStechschulte

Is it possible the difference is related to where those parameters are stored? See the difference between the parameters for BlackJAX NUTS and Nutpie

bmb.inference_methods.get_kwargs("blackjax_nuts")
{<function blackjax.adaptation.window_adaptation.window_adaptation(algorithm, logdensity_fn: Callable, is_mass_matrix_diagonal: bool = True, initial_step_size: float = 1.0, target_acceptance_rate: float = 0.8, progress_bar: bool = False, adaptation_info_fn: Callable = <function return_all_adapt_info at 0x7f6b3ea3c400>, integrator=<function generate_euclidean_integrator.<locals>.euclidean_integrator at 0x7f6b3ea044a0>, **extra_parameters) -> blackjax.base.AdaptationAlgorithm>: {'logdensity_fn': <function bayeux._src.shared.constrain.<locals>.wrap_log_density.<locals>.wrapped(args)>,
  'is_mass_matrix_diagonal': True,
  'initial_step_size': 1.0,
  'target_acceptance_rate': 0.8,
  'progress_bar': False,
  'adaptation_info_fn': <function blackjax.adaptation.base.return_all_adapt_info(state, info, adaptation_state)>,
  'algorithm': GenerateSamplingAPI(differentiable=<function as_top_level_api at 0x7f6b3ea1e5c0>, init=<function init at 0x7f6b3e9eb1a0>, build_kernel=<function build_kernel at 0x7f6b3ebaac00>)},
 'adapt.run': {'num_steps': 500},
 <function blackjax.mcmc.nuts.as_top_level_api(logdensity_fn: Callable, step_size: float, inverse_mass_matrix: Union[blackjax.mcmc.metrics.Metric, jax.Array, Callable[[Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex, Iterable[ForwardRef('ArrayLikeTree')], Mapping[Any, ForwardRef('ArrayLikeTree')]]], jax.Array]], *, max_num_doublings: int = 10, divergence_threshold: int = 1000, integrator: Callable = <function generate_euclidean_integrator.<locals>.euclidean_integrator at 0x7f6b3ea044a0>) -> blackjax.base.SamplingAlgorithm>: {'max_num_doublings': 10,
  'divergence_threshold': 1000,
  'integrator': <function blackjax.mcmc.integrators.generate_euclidean_integrator.<locals>.euclidean_integrator(logdensity_fn: Callable, kinetic_energy_fn: blackjax.mcmc.metrics.KineticEnergy) -> Callable[[blackjax.mcmc.integrators.IntegratorState, float], blackjax.mcmc.integrators.IntegratorState]>,
  'logdensity_fn': <function bayeux._src.shared.constrain.<locals>.wrap_log_density.<locals>.wrapped(args)>,
  'step_size': 0.5},
 'extra_parameters': {'chain_method': 'vectorized',
  'num_chains': 8,
  'num_draws': 500,
  'num_adapt_draws': 500,
  'return_pytree': False}}

The parameters num_chains, num_draws, etc. are part of the "extra_parameters" element.

While in nutpie, they look like this:

bmb.inference_methods.get_kwargs("nutpie")
{<function nutpie.compiled_pyfunc.from_pyfunc(ndim: int, make_logp_fn: Callable, make_expand_fn: Callable, expanded_dtypes: list[numpy.dtype], expanded_shapes: list[tuple[int, ...]], expanded_names: list[str], *, initial_mean: numpy.ndarray | None = None, coords: dict[str, typing.Any] | None = None, dims: dict[str, tuple[str, ...]] | None = None, shared_data: dict[str, typing.Any] | None = None)>: {'ndim': 1,
  'make_logp_fn': <function bayeux._src.mcmc.nutpie._NutpieSampler._get_aux.<locals>.make_logp_fn()>,
  'make_expand_fn': <function bayeux._src.mcmc.nutpie._NutpieSampler.get_kwargs.<locals>.make_expand_fn(*args, **kwargs)>,
  'expanded_shapes': [(1,)],
  'expanded_names': ['x'],
  'expanded_dtypes': [numpy.float64]},
 <function nutpie.sample.sample(compiled_model: nutpie.sample.CompiledModel, *, draws: int = 1000, tune: int = 300, chains: int = 6, cores: Optional[int] = None, seed: Optional[int] = None, save_warmup: bool = True, progress_bar: bool = True, low_rank_modified_mass_matrix: bool = False, init_mean: Optional[numpy.ndarray] = None, return_raw_trace: bool = False, blocking: bool = True, progress_template: Optional[str] = None, progress_style: Optional[str] = None, progress_rate: int = 100, **kwargs) -> arviz.data.inference_data.InferenceData>: {'draws': 1000,
  'tune': 300,
  'chains': 8,
  'cores': 8,
  'seed': None,
  'save_warmup': True,
  'progress_bar': True,
  'low_rank_modified_mass_matrix': False,
  'init_mean': None,
  'return_raw_trace': False,
  'blocking': True,
  'progress_template': None,
  'progress_style': None,
  'progress_rate': 100},
 'extra_parameters': {'flatten': <function bayeux._src.mcmc.nutpie._NutpieSampler._get_aux.<locals>.flatten(pytree)>,
  'unflatten': <jax._src.util.HashablePartial at 0x7f1ee36cc1d0>,
  'return_pytree': False}}

They are parameters of the nutpie.sample.sample() function.

AlexAndorra commented 1 week ago

I think so @tomicapretto, because then Bambi has to pass them explicitly to bx.sample when using nutpie. Can you confirm it works with BlackJAX once the dependencies are pinned?

tomicapretto commented 1 week ago

> (Quoting the previous comment comparing the blackjax_nuts and nutpie kwargs.)

Just found a counterexample. Something similar happens with "numpyro_nuts", and yet it passes the kwargs correctly:

bmb.inference_methods.get_kwargs("numpyro_nuts")
{numpyro.infer.hmc.NUTS: {'model': None,
  'kinetic_fn': None,
  'step_size': 1.0,
  'inverse_mass_matrix': None,
  'adapt_step_size': True,
  'adapt_mass_matrix': True,
  'dense_mass': False,
  'target_accept_prob': 0.8,
  'trajectory_length': None,
  'max_tree_depth': 10,
  'init_strategy': <function numpyro.infer.initialization.init_to_uniform(site=None, radius=2)>,
  'find_heuristic_step_size': False,
  'forward_mode_differentiation': False,
  'regularize_mass_matrix': True},
 numpyro.infer.mcmc.MCMC: {'num_chains': 8,
  'thinning': 1,
  'postprocess_fn': None,
  'chain_method': 'vectorized',
  'progress_bar': True,
  'jit_model_args': False,
  'num_warmup': 500,
  'num_samples': 1000},
 'extra_parameters': {'return_pytree': False}}
kwargs = {
    "num_chains": 2,
    "num_samples": 250,
    "num_adapt_draws": 250
}

numpyro_nuts_idata = model.fit(inference_method="numpyro_nuts", **kwargs)
numpyro_nuts_idata.posterior

[screenshot of the resulting posterior, showing the requested number of chains and draws]

tomicapretto commented 1 week ago

I resolved my previous two comments because I realized they were not correct. The reason for kwargs in

https://github.com/bambinos/bambi/blob/7a18fb9afc5b485dcd95f1a421bbd77586106a2f/bambi/backend/pymc.py#L261

not containing chains, draws, tune, and cores is that they are also keyword arguments of the Model.fit() method, so Python does not include them in the kwargs dictionary.
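A minimal illustration of that Python behaviour, with a simplified, hypothetical signature standing in for Model.fit():

def fit(draws=1000, tune=1000, chains=None, cores=None, inference_method="mcmc", **kwargs):
    # draws, tune, chains and cores bind to the named parameters,
    # so they never appear in the catch-all kwargs dict.
    return kwargs

print(fit(draws=40, chains=2, cores=3, inference_method="nutpie"))
# {} -- the sampler-related arguments were captured by the named parameters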

tomicapretto commented 1 week ago

I see the following alternatives:

  1. Accept the changes proposed by @AlexAndorra, which always pass draws, tune, chains, and cores to the sampling function (a rough sketch follows below). If bayeux just ignores them when they are not expected by the underlying sampler, then it's all good.
  2. Change the signature of Model.fit() to accept the arguments passed to PyMC and other samplers as kwargs. The downside I see here is the lack of autocomplete.
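A rough sketch of alternative 1 (not the actual PR diff; the helper name is made up for illustration): merge the named Model.fit() arguments back into the dictionary handed to the bayeux sampler.

def _merge_fit_kwargs(draws, tune, chains, cores, kwargs):
    # Re-attach the named fit() arguments so the bayeux sampler receives them
    # explicitly; samplers that do not accept a given key would have to ignore it.
    merged = dict(kwargs)
    merged.setdefault("draws", draws)
    merged.setdefault("tune", tune)
    if chains is not None:
        merged.setdefault("chains", chains)
    if cores is not None:
        merged.setdefault("cores", cores)
    return merged
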
GStechschulte commented 1 week ago

I am late to the party @tomicapretto. Regarding nutpie, the kwargs returned are not "pretty":

bmb.inference_methods.get_kwargs("nutpie")
{<function nutpie.compiled_pyfunc.from_pyfunc(ndim: int, make_logp_fn: Callable, make_expand_fn: Callable, expanded_dtypes: list[numpy.dtype], expanded_shapes: list[tuple[int, ...]], expanded_names: list[str], *, initial_mean: numpy.ndarray | None = None, coords: dict[str, typing.Any] | None = None, dims: dict[str, tuple[str, ...]] | None = None, shared_data: dict[str, typing.Any] | None = None)>: {'ndim': 1,
  'make_logp_fn': <function bayeux._src.mcmc.nutpie._NutpieSampler._get_aux.<locals>.make_logp_fn()>,
  'make_expand_fn': <function bayeux._src.mcmc.nutpie._NutpieSampler.get_kwargs.<locals>.make_expand_fn(*args, **kwargs)>,
  'expanded_shapes': [(1,)],
  'expanded_names': ['x'],
  'expanded_dtypes': [numpy.float64]},
 <function nutpie.sample.sample(compiled_model: nutpie.sample.CompiledModel, *, draws: int = 1000, tune: int = 300, chains: int = 6, cores: Optional[int] = None, seed: Optional[int] = None, save_warmup: bool = True, progress_bar: bool = True, low_rank_modified_mass_matrix: bool = False, init_mean: Optional[numpy.ndarray] = None, return_raw_trace: bool = False, blocking: bool = True, progress_template: Optional[str] = None, progress_style: Optional[str] = None, progress_rate: int = 100, **kwargs) -> arviz.data.inference_data.InferenceData>: {'draws': 1000,
  'tune': 300,
  'chains': 8,
  'cores': 8,
  'seed': None,
  'save_warmup': True,
  'progress_bar': True,
  'low_rank_modified_mass_matrix': False,
  'init_mean': None,
  'return_raw_trace': False,
  'blocking': True,
  'progress_template': None,
  'progress_style': None,
  'progress_rate': 100},
 'extra_parameters': {'flatten': <function bayeux._src.mcmc.nutpie._NutpieSampler._get_aux.<locals>.flatten(pytree)>,
  'unflatten': <jax._src.util.HashablePartial at 0x3299964e0>,
  'return_pytree': False}}

It is a nested dictionary whose keys are objects. A nested dictionary isn't a problem per se, e.g. passing nested args to BlackJAX NUTS:

kwargs = {
    "adapt.run": {"num_steps": 500},
    "num_chains": 4,
    "num_draws": 250,
    "num_adapt_draws": 250
}
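Presumably these would then be passed the same way as in the numpyro example above, e.g.:

blackjax_nuts_idata = model.fit(inference_method="blackjax_nuts", **kwargs)
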
tomicapretto commented 1 week ago

> I think so @tomicapretto, because then Bambi has to pass them explicitly to bx.sample when using nutpie. Can you confirm it works with BlackJAX once the dependencies are pinned?

Not sure what you mean by "it works". If you mean passing arguments to BlackJAX, that has always worked. If you mean the tests, yes, they pass now after pinning the deps.

tomicapretto commented 1 week ago

@GStechschulte I think that is not the problem (see https://github.com/bambinos/bambi/pull/855#issuecomment-2466301370 and my resolved comments)

AlexAndorra commented 1 week ago

I meant passing arguments to BlackJAX, because it currently errors out, as I showed above. This is solved by pinning bayeux, IIUC.

For nutpie, I think we can merge these changes (once I add those other cases), but of course I'll let you decide.

tomicapretto commented 1 week ago

@AlexAndorra I'm going to incorporate your changes; I just modified things a bit. The environment file goes under a conda-envs directory, following what I saw in PyMC: https://github.com/pymc-devs/pymc/tree/main/conda-envs.

Could you install from your branch and try running nutpie with those kwargs? If the tests are OK and what you run works, then I'll merge.
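For reference, a possible way to verify (a sketch, assuming the PR branch is installed; the expected numbers mirror the kwargs from the example earlier in the thread):

import bambi as bmb

data = bmb.load_data("sleepstudy")
model = bmb.Model("Reaction ~ Days", data)
results = model.fit(inference_method="nutpie", draws=40, chains=2, cores=3)

# With the fix, the posterior should reflect the requested sizes.
assert results.posterior.sizes["chain"] == 2
assert results.posterior.sizes["draw"] == 40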