stefanradev93 / BayesFlow

A Python library for amortized Bayesian workflows using generative neural networks.
https://bayesflow.org/
MIT License

Automatic Batching using `vmap` #124

Closed: LarsKue closed this issue 2 months ago.

LarsKue commented 6 months ago

Classes such as `bayesflow.simulation.ContextGenerator` or `bayesflow.simulation.Prior` let the user pass in either a batched or an unbatched generation function; passing both is typically disallowed.

For quality of life (and possibly for compatibility with #123), it would make sense to replace this with a single function argument plus a boolean flag, and to automatically parallelize unbatched functions for the user with a variant of JAX's or PyTorch's `vmap` function.
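To make "automatically parallelize unbatched functions" concrete, here is a minimal sketch of the intended contract: a per-sample function is lifted to one that maps over the leading (batch) axis. `loop_vmap` is a hypothetical helper that simply loops and stacks with NumPy; `jax.vmap` and `torch.func.vmap` vectorize the computation instead of looping, but for a single array argument mapped over axis 0 they have the same input/output contract. `summarize` is an illustrative placeholder, not a BayesFlow function.

```python
import numpy as np


def loop_vmap(fn):
    """Hypothetical stand-in for jax.vmap / torch.func.vmap.

    Maps `fn` over the leading axis of its single array argument by looping
    and stacking the results. Real vmap implementations vectorize instead of
    looping, but the input/output shapes are the same in this simple case.
    """
    def batched_fn(xs):
        return np.stack([fn(x) for x in xs], axis=0)
    return batched_fn


# Unbatched: one input vector of shape (3,) -> one output of shape (2,)
def summarize(x):
    return np.array([x.mean(), x.std()])


batched_summarize = loop_vmap(summarize)
xs = np.arange(12.0).reshape(4, 3)  # a batch of 4 inputs
out = batched_summarize(xs)
print(out.shape)  # (4, 2)
```

The user only writes the per-sample `summarize`; the batch dimension is added by the wrapper, which is what the proposed `is_batched=False` path would do.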

Example pseudo-code:

Replace this

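```python
class Prior:
    def __init__(self, batch_prior_fun=None, prior_fun=None):
        # Exactly one of the two functions must be provided.
        if (batch_prior_fun is None) == (prior_fun is None):
            raise ValueError("Provide exactly one of `batch_prior_fun` or `prior_fun`.")

        self.prior = prior_fun
        self.batched_prior = batch_prior_fun
```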

with

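```python
from jax import vmap  # or: from torch.func import vmap

class Prior:
    def __init__(self, prior_fun, is_batched=True):
        self.prior = prior_fun if is_batched else vmap(prior_fun)  # vectorize unbatched functions
```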

Since this is a breaking change, it would need a soft introduction: the old arguments should keep working for now, with deprecation messages pointing users to the new interface.
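A soft introduction could look roughly like the sketch below. It assumes (hypothetically) that an unbatched prior takes no arguments and a batched one takes `batch_size`. The old `batch_prior_fun` keyword keeps working but emits a `DeprecationWarning`, and omitting `is_batched` falls back to the old meaning of `prior_fun` (unbatched) with a `FutureWarning`. Because an argument-free sampler has no input for `vmap` to map over, `_loop_batch` is a hypothetical loop-and-stack fallback rather than a real `vmap`. This is an illustration only, not the code from the implementing commit.

```python
import warnings

import numpy as np


def _loop_batch(sample_fn):
    """Hypothetical fallback: turn an argument-free, unbatched sampler into a
    function of `batch_size` that stacks independent draws along axis 0."""
    def batched(batch_size):
        return np.stack([np.asarray(sample_fn()) for _ in range(batch_size)], axis=0)
    return batched


class Prior:
    """Sketch of the proposed interface with a deprecation path (illustrative only)."""

    def __init__(self, prior_fun=None, is_batched=None, *, batch_prior_fun=None):
        if batch_prior_fun is not None:
            # Old keyword: still accepted, mapped onto the new interface with a warning.
            if prior_fun is not None:
                raise ValueError("Pass either `prior_fun` or `batch_prior_fun`, not both.")
            warnings.warn(
                "`batch_prior_fun` is deprecated; use `prior_fun=...` with `is_batched=True`.",
                DeprecationWarning,
                stacklevel=2,
            )
            prior_fun, is_batched = batch_prior_fun, True
        if prior_fun is None:
            raise ValueError("A prior function is required.")
        if is_batched is None:
            # Flag omitted: keep the old meaning of `prior_fun` (unbatched), but warn
            # that the default is expected to change to is_batched=True.
            warnings.warn(
                "Pass `is_batched` explicitly; its default will become True.",
                FutureWarning,
                stacklevel=2,
            )
            is_batched = False
        # Batched functions are used as-is; unbatched ones are wrapped.
        self.prior = prior_fun if is_batched else _loop_batch(prior_fun)

    def __call__(self, batch_size):
        return self.prior(batch_size)


rng = np.random.default_rng(0)
old_style = Prior(batch_prior_fun=lambda n: rng.normal(size=(n, 2)))  # DeprecationWarning
new_style = Prior(lambda: rng.normal(size=2), is_batched=False)      # no warning
print(old_style(5).shape, new_style(5).shape)  # (5, 2) (5, 2)
```

In JAX, one option for getting a true `vmap` over a random sampler is to make the unbatched prior take a PRNG key and map over `jax.random.split(key, batch_size)`; the loop fallback above avoids that requirement at the cost of speed.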

Possible concerns:

LarsKue commented 2 months ago

Implemented in commit 6565a51a9011554d3374543bdcd454f86878c5d8.