pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

Object oriented wrapper API #1832

Closed: tomwallis closed this issue 1 month ago

tomwallis commented 1 month ago

TL;DR: would the numpyro devs be interested in a light object-oriented (OO) API for numpyro?


Hello,

In the context of working with different models, I found it unwieldy to handle all the variables associated with a model (e.g. the model function itself, the mcmc object, predictive samples, etc.). Sampling, prediction, and so on led to a lot of code repetition. Therefore I drafted a light object-oriented wrapper so that a user can more easily manage numpyro model objects.

The API looks something like:

from jax import random
rng_key = random.PRNGKey(42)
m1 = ModelOne(
    data=df, rng_key=rng_key, ...
)

where the ModelOne class contains a model method that defines the model. Sampling can then be performed with, e.g.:

m1.sample()
m1.mcmc.print_summary()

Prediction can be done with

prior_samples = m1.predict(data=new_df, prior=True, sample_obs=False)

Model comparison can be done with

import arviz as az
comp_df = az.compare({"model 1": m1.arviz_data, "model 2": m2.arviz_data})

and so on. You can see a more complete demo of the API here.

Currently, the user would be expected to write a new class that inherits from BaseNumpyroModel and, at minimum, defines a new model method. An advantage of this approach is that the bulk of the repetitive sampling and prediction code is hidden in the base class. Additionally, all the variables and objects associated with a model are contained in the instance namespace (e.g. of m1), so variable naming becomes simpler and things are easier to find. Yet the flexibility of writing the model directly is more or less maintained.

I find this more convenient to work with. I was thinking about putting this helper wrapper into a new package, but if the numpyro devs are interested, maybe it makes sense to put it into numpyro directly?

fehiepsi commented 1 month ago

Hi @tomwallis, sorry, but we probably don't want to include that wrapper, mainly to keep numpyro at the right level of abstraction. But I agree that API will be very helpful for users.

tomwallis commented 1 month ago

Ok, thanks for the fast reply! I'll let you know if I end up creating a package for this purpose. Naming suggestions welcome.

tomwallis commented 1 month ago

Hi @fehiepsi, the initial release of numpyro-oop is now out. You can find it here.

fehiepsi commented 1 month ago

I just walked through the demo Colab. The API is awesome! I would recommend using a similar approach for other domains: https://github.com/pyro-ppl/numpyro/issues/1361

tomwallis commented 1 month ago

Thanks!