bambinos / bambi

BAyesian Model-Building Interface (Bambi) in Python.
https://bambinos.github.io/bambi/
MIT License

Convenient function to access inference methods and kwargs #795

Closed: GStechschulte closed this 7 months ago

GStechschulte commented 7 months ago

Closes #791. This PR adds a convenience class, InferenceMethods, that lets users access the available inference methods and their default kwargs.

For example, to list the available inference methods:

bmb.inference_methods.names
{'pymc': {'mcmc': ['mcmc'], 'vi': ['vi']},
 'bayeux': {'mcmc': ['tfp_hmc',
   'tfp_nuts',
   'tfp_snaper_hmc',
   'blackjax_hmc',
   'blackjax_chees_hmc',
   'blackjax_meads_hmc',
   'blackjax_nuts',
   'blackjax_hmc_pathfinder',
   'blackjax_nuts_pathfinder',
   'flowmc_rqspline_hmc',
   'flowmc_rqspline_mala',
   'flowmc_realnvp_hmc',
   'flowmc_realnvp_mala',
   'numpyro_hmc',
   'numpyro_nuts']}}

and to get the default kwargs for a given inference method:

bmb.inference_methods.get_kwargs("tfp_nuts")
{'extra_parameters': {'num_draws': 1000,
  'num_chains': 8,
  'num_adaptation_steps': 500,
  'return_pytree': False},
 'dual_averaging_kwargs': {'target_accept_prob': 0.8,
  'exploration_shrinkage': 0.05,
  'shrinkage_target': None,
  'step_count_smoothing': 10,
  'decay_rate': 0.75,
  'step_size_setter_fn': <function tensorflow_probability.substrates.jax.mcmc.simple_step_size_adaptation.hmc_like_step_size_setter_fn(kernel_results, new_step_size)>,
  'step_size_getter_fn': <function tensorflow_probability.substrates.jax.mcmc.simple_step_size_adaptation.hmc_like_step_size_getter_fn(kernel_results)>,
  'log_accept_prob_getter_fn': <function tensorflow_probability.substrates.jax.mcmc.simple_step_size_adaptation.hmc_like_log_accept_prob_getter_fn(kernel_results)>,
  'reduce_fn': <function tensorflow_probability.substrates.jax.math.generic.reduce_log_harmonic_mean_exp(input_tensor, axis=None, keepdims=False, experimental_named_axis=None, experimental_allow_all_gather=False, name=None)>,
  'experimental_reduce_chain_axis_names': None,
  'validate_args': False,
  'name': None,
  'num_adaptation_steps': 500},
 'proposal_kernel_kwargs': {'max_tree_depth': 10,
  'max_energy_diff': 1000.0,
  'unrolled_leapfrog_steps': 1,
  'parallel_iterations': 10,
  'experimental_shard_axis_names': None,
  'name': None,
  'step_size': 0.5}}

Additionally, this convenience class is now imported in backend/pymc.py and used there to obtain the bayeux and PyMC inference methods. I have also updated the relevant docstrings and the alternative samplers notebook.
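To make the shape of the API concrete, here is a minimal sketch of how a class like InferenceMethods could expose `names` and `get_kwargs`. It is not Bambi's implementation: the registry layout, the stand-in sampler functions, and the class name `InferenceMethodsSketch` are hypothetical, chosen so the example runs without PyMC or bayeux installed. For bayeux methods the real `get_kwargs` output is nested by component (as shown above); this sketch returns a flat dict for simplicity.

```python
import inspect
from typing import Any, Callable, Dict, List, Optional


def get_default_signature(fn: Callable[..., Any]) -> Dict[str, Any]:
    # {parameter name: default} for every parameter that declares a default
    return {
        key: val.default
        for key, val in inspect.signature(fn).parameters.items()
        if val.default is not inspect.Parameter.empty
    }


# Hypothetical stand-ins for real samplers, so the sketch runs without PyMC/bayeux.
def _stand_in_mcmc(draws: int = 1000, tune: int = 1000, chains: Optional[int] = None) -> None:
    ...


def _stand_in_tfp_nuts(num_draws: int = 1000, num_chains: int = 8) -> None:
    ...


Registry = Dict[str, Dict[str, Dict[str, Callable[..., Any]]]]


class InferenceMethodsSketch:
    """Hypothetical registry: backend -> kind ("mcmc"/"vi") -> method name -> callable."""

    def __init__(self, registry: Registry) -> None:
        self._registry = registry

    @property
    def names(self) -> Dict[str, Dict[str, List[str]]]:
        # Same nested shape as the `names` output above, without the callables
        return {
            backend: {kind: list(methods) for kind, methods in kinds.items()}
            for backend, kinds in self._registry.items()
        }

    def get_kwargs(self, method: str) -> Dict[str, Any]:
        for kinds in self._registry.values():
            for methods in kinds.values():
                if method in methods:
                    return get_default_signature(methods[method])
        # Unknown names fail loudly instead of returning an empty dict
        raise ValueError(f"'{method}' is not a supported inference method.")


methods = InferenceMethodsSketch(
    {
        "pymc": {"mcmc": {"mcmc": _stand_in_mcmc}},
        "bayeux": {"mcmc": {"tfp_nuts": _stand_in_tfp_nuts}},
    }
)
print(methods.names)  # {'pymc': {'mcmc': ['mcmc']}, 'bayeux': {'mcmc': ['tfp_nuts']}}
print(methods.get_kwargs("tfp_nuts"))  # {'num_draws': 1000, 'num_chains': 8}
```

Keeping the callables in one registry means `names` and `get_kwargs` cannot drift apart: any method listed is, by construction, one whose defaults can be looked up.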

ColCarroll commented 7 months ago

Note that https://github.com/jax-ml/bayeux/blob/main/bayeux/_src/shared.py#L115 is what I use in bayeux, though I do a fair amount of manual work there cleaning things up and removing arguments that are supplied elsewhere.

If you run this slightly modified version

import inspect

def get_default_signature(fn):
  # Collect {parameter name: default value} for every parameter that has a default
  defaults = {}
  for key, val in inspect.signature(fn).parameters.items():
    if val.default is not inspect.Signature.empty:
      defaults[key] = val.default
  return defaults

on pm.sample, you get the pleasant result:

{'draws': 1000,
 'tune': 1000,
 'chains': None,
 'cores': None,
 'random_seed': None,
 'progressbar': True,
 'step': None,
 'nuts_sampler': 'pymc',
 'initvals': None,
 'init': 'auto',
 'jitter_max_retries': 10,
 'n_init': 200000,
 'trace': None,
 'discard_tuned_samples': True,
 'compute_convergence_checks': True,
 'keep_warning_stat': False,
 'return_inferencedata': True,
 'idata_kwargs': None,
 'nuts_sampler_kwargs': None,
 'callback': None,
 'mp_ctx': None,
 'model': None}

tomicapretto commented 7 months ago

As far as I know, the signature of pm.sample() mixes arguments for many different purposes. Maybe we could hard-code the subset of parameters we want to query from it and report only those?
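The hard-coded subset idea can be sketched by combining the `get_default_signature` helper above with an allowlist. The allowlist `SAMPLER_KWARGS` and the function `sample_stand_in` are hypothetical: the stand-in copies a few of the defaults shown for pm.sample above so the example runs without PyMC. With PyMC installed, you would pass `pm.sample` itself instead.

```python
import inspect
from typing import Any, Callable, Dict, Iterable


def get_default_signature(fn: Callable[..., Any]) -> Dict[str, Any]:
    """Return {parameter name: default} for every parameter that has a default."""
    defaults = {}
    for key, val in inspect.signature(fn).parameters.items():
        if val.default is not inspect.Parameter.empty:
            defaults[key] = val.default
    return defaults


def get_selected_defaults(fn: Callable[..., Any], keep: Iterable[str]) -> Dict[str, Any]:
    """Report only the allowlisted defaults, in allowlist order; unknown names are skipped."""
    defaults = get_default_signature(fn)
    return {name: defaults[name] for name in keep if name in defaults}


# Hypothetical stand-in mirroring part of the pm.sample defaults listed above.
def sample_stand_in(
    draws=1000,
    tune=1000,
    chains=None,
    cores=None,
    random_seed=None,
    progressbar=True,
    callback=None,
    mp_ctx=None,
    model=None,
):
    ...


# Hypothetical allowlist of the parameters worth surfacing to users
SAMPLER_KWARGS = ("draws", "tune", "chains", "cores", "random_seed")

print(get_selected_defaults(sample_stand_in, SAMPLER_KWARGS))
# {'draws': 1000, 'tune': 1000, 'chains': None, 'cores': None, 'random_seed': None}
```

Skipping names that are not in the signature keeps the allowlist from raising if a parameter is later renamed or removed upstream, at the cost of silently dropping it from the report.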

GStechschulte commented 7 months ago

I am not sure tests should be added for this. The .get_kwargs method already raises an error if the user passes an inference method that is not among the available methods.

For bmb.inference_methods.names, I suppose a test could be added to assert that specific keys (mcmc, vi) exist in the dict?

tomicapretto commented 7 months ago

@GStechschulte I see what you mean. I don't have a strong opinion here. The only thing I can add is that leaving it untested will decrease coverage. I know high coverage does not mean our test suite is perfect, but I do think that, in general, lower coverage is worse. We could exclude the inference_methods.py module from coverage, but I'm not sure whether that is a good idea.

Another option would be to merge as is and open an issue so someone adds tests for this in the future (as it's not critical).

GStechschulte commented 7 months ago

Many thanks for the review, @ColCarroll and @tomicapretto.

I added a small test that checks the keys (mcmc, vi) exist when calling bmb.inference_methods.names, as well as a test that ensures a ValueError is raised if a user passes an unsupported inference method name.
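The two tests described above might look roughly like the following pytest-style functions. This is a sketch, not the tests in the PR: Bambi is not imported here, so `inference_methods` is a hypothetical stand-in whose `names` has the same shape as the output shown at the top of the thread. With Bambi installed, you would `import bambi as bmb` and test `bmb.inference_methods` directly. The ValueError check uses `try`/`except` so it runs with or without pytest; `pytest.raises` would be the usual alternative.

```python
from typing import Any, Dict, List


class _StandInInferenceMethods:
    """Hypothetical stand-in for bmb.inference_methods (same `names` shape as above)."""

    def __init__(self) -> None:
        self.names: Dict[str, Dict[str, List[str]]] = {
            "pymc": {"mcmc": ["mcmc"], "vi": ["vi"]},
            "bayeux": {"mcmc": ["tfp_nuts", "blackjax_nuts", "numpyro_nuts"]},
        }

    def get_kwargs(self, method: str) -> Dict[str, Any]:
        available = {m for kinds in self.names.values() for ms in kinds.values() for m in ms}
        if method not in available:
            raise ValueError(f"'{method}' is not a supported inference method.")
        return {}  # real default kwargs omitted in this stand-in


inference_methods = _StandInInferenceMethods()


def test_inference_method_names():
    # The pymc backend exposes both MCMC and VI; bayeux exposes MCMC
    assert {"mcmc", "vi"} <= set(inference_methods.names["pymc"])
    assert "mcmc" in inference_methods.names["bayeux"]


def test_get_kwargs_unsupported_method_raises():
    try:
        inference_methods.get_kwargs("not_a_real_sampler")
    except ValueError:
        return
    raise AssertionError("expected ValueError for an unsupported inference method")
```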