bambinos / bambi

BAyesian Model-Building Interface (Bambi) in Python.
https://bambinos.github.io/bambi/
MIT License

Sampler parameters not used when calling JAX samplers #813

Open: GStechschulte opened this issue 1 month ago

GStechschulte commented 1 month ago

To change the parameter values of bayeux-based samplers, we need to pass kwargs such as:

kwargs = {
    "adapt.run": {"num_steps": 500},
    "num_chains": 4,
    "num_draws": 250,
    "num_adapt_draws": 250,
}

blackjax_nuts_idata = model.fit(inference_method="blackjax_nuts", **kwargs)

However, if a user also passes draws and tune:

kwargs = {
    "adapt.run": {"num_steps": 500},
    "num_chains": 4,
    "num_draws": 500,
    "num_adapt_draws": 500,
}

blackjax_nuts_idata = model.fit(draws=250, tune=250, inference_method="blackjax_nuts", **kwargs)

the values of draws and tune are silently ignored. Should we print a message to stdout saying these values are not used, or make the user aware by some other means?
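One option is a warning rather than a plain print to stdout, since Python's warnings module lets users filter or escalate it. The sketch below is a hypothetical check, not part of Bambi's API: the function name check_ignored_args, the BAYEUX_METHODS set, and the use of None as a "not passed" sentinel for draws and tune are all assumptions made for illustration.

```python
import warnings

# Assumption: the set of bayeux-backed method names. The issue only mentions
# "blackjax_nuts"; other JAX samplers would need to be added here.
BAYEUX_METHODS = {"blackjax_nuts"}


def check_ignored_args(inference_method, draws=None, tune=None):
    """Hypothetical helper: warn when draws/tune will be ignored.

    None is used as a sentinel so that "not passed" can be told apart
    from "passed with the default value".
    """
    if inference_method not in BAYEUX_METHODS:
        return []
    ignored = [
        name for name, value in (("draws", draws), ("tune", tune)) if value is not None
    ]
    if ignored:
        warnings.warn(
            f"{', '.join(ignored)} ignored for inference_method="
            f"{inference_method!r}; set num_draws / num_adapt_draws via kwargs instead.",
            UserWarning,
            stacklevel=2,
        )
    return ignored
```

A warning also shows up once per call site by default, which is less noisy than an unconditional print when fitting many models in a loop.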

Of course, after sampling, one can infer the number of chains, draws, etc. from the returned InferenceData. Just a thought.

ColCarroll commented 1 month ago

Maybe just an adapter to map draws -> num_steps and tune -> num_adapt_draws?
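The adapter idea can be sketched as a small translation step applied before the kwargs reach bayeux. This is a hypothetical helper (to_bayeux_kwargs and ARG_MAP are illustrative names, not Bambi code). As an assumption, it maps draws to num_draws, the key the first example uses for posterior draws, since num_steps appears there nested under adapt.run; if num_steps is the intended target, only ARG_MAP needs to change. Raising on conflicting values is also a design assumption.

```python
# Assumed mapping from PyMC-style argument names to bayeux kwarg names.
ARG_MAP = {"draws": "num_draws", "tune": "num_adapt_draws"}


def to_bayeux_kwargs(draws=None, tune=None, **kwargs):
    """Hypothetical adapter: translate draws/tune into bayeux-style kwargs."""
    out = dict(kwargs)  # copy so the caller's dict is not mutated
    for name, value in (("draws", draws), ("tune", tune)):
        if value is None:
            continue  # argument not passed; leave kwargs untouched
        key = ARG_MAP[name]
        # Refuse to guess when both spellings are given with different values.
        if key in out and out[key] != value:
            raise ValueError(f"{name}={value} conflicts with {key}={out[key]}")
        out[key] = value
    return out
```

Used as model.fit(inference_method="blackjax_nuts", **to_bayeux_kwargs(draws=250, tune=250, num_chains=4)), this would forward num_draws=250 and num_adapt_draws=250, while the second example from the issue (draws=250 alongside num_draws=500) would raise instead of silently dropping draws.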

I think bayeux will ignore the adapt.run kwarg in the example above. This is undocumented, but to make switching between methods quicker, bayeux does not currently warn about unused kwargs.
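For context, here is a generic sketch of what flagging unused kwargs could look like for a single callable, using inspect.signature from the standard library. It is hypothetical and not how bayeux works internally; unused_kwargs and run_sampler are illustrative names. It also hints at why such a check is awkward when kwargs are routed to several functions: a callable that takes **kwargs accepts every key, and a dotted key such as "adapt.run" can never match a Python parameter name.

```python
import inspect


def unused_kwargs(func, kwargs):
    """Hypothetical helper: return the keys in kwargs that func cannot accept."""
    params = inspect.signature(func).parameters.values()
    # A **kwargs parameter swallows every key, so nothing is reported as unused.
    if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params):
        return []
    accepted = {
        p.name
        for p in params
        if p.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
    }
    return sorted(key for key in kwargs if key not in accepted)


# Example with a hypothetical sampler function.
def run_sampler(num_chains=4, num_draws=1000, num_adapt_draws=1000):
    ...


print(unused_kwargs(run_sampler, {"num_chains": 4, "draws": 250, "adapt.run": {}}))
# ['adapt.run', 'draws']
```

A caller could pass a non-empty result to warnings.warn; the trade-off the comment describes is that strict checking makes it harder to reuse one kwargs dict across methods.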