blackjax-devs / blackjax

BlackJAX is a Bayesian Inference library designed for ease of use, speed and modularity.
https://blackjax-devs.github.io/blackjax/
Apache License 2.0
806 stars, 105 forks

Removal of Algorithm classes. #657

Closed: ciguaran closed this 5 months ago

ciguaran commented 5 months ago

This PR removes the algorithm classes. These are static classes that we never instantiate directly; instead, calling one invokes its `__new__`, which returns a `SamplingAlgorithm`. This PR:

With this change we can still call algorithms directly, as in `blackjax.hmc()`. What we do lose are the (light) type annotations we currently use, for example in `window_adaptation`. I have been thinking about these annotations, and I think we should remove them.
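The change described above can be sketched without JAX. This is a minimal illustration under stated assumptions, not blackjax's actual code: `hmc_class`, `hmc_fn`, `init`, `build_kernel` and the toy kernel are hypothetical stand-ins, and `SamplingAlgorithm` here is a local NamedTuple of `init` and `step` shaped like blackjax's. The "before" class is never instantiated because its `__new__` returns a `SamplingAlgorithm`, so a plain function with the same signature can replace it without changing how callers use it.

```python
from typing import Callable, NamedTuple


class SamplingAlgorithm(NamedTuple):
    """An (init, step) pair, shaped like blackjax's SamplingAlgorithm."""
    init: Callable
    step: Callable


def init(position, logdensity_fn):
    return {"position": position, "logdensity": logdensity_fn(position)}


def build_kernel():
    def kernel(rng_key, state, logdensity_fn, step_size, inverse_mass_matrix):
        # Toy deterministic move, NOT a real HMC transition; rng_key is unused.
        position = state["position"] + step_size * inverse_mass_matrix
        return {"position": position, "logdensity": logdensity_fn(position)}

    return kernel


# Before: a static class that is never instantiated. Calling it runs
# __new__, which returns a SamplingAlgorithm instead of a class instance.
class hmc_class:
    init = staticmethod(init)
    build_kernel = staticmethod(build_kernel)

    def __new__(cls, logdensity_fn, step_size, inverse_mass_matrix):
        kernel = cls.build_kernel()

        def init_fn(position):
            return cls.init(position, logdensity_fn)

        def step_fn(rng_key, state):
            return kernel(rng_key, state, logdensity_fn, step_size, inverse_mass_matrix)

        return SamplingAlgorithm(init_fn, step_fn)


# After: a plain function with the same call signature and return value.
def hmc_fn(logdensity_fn, step_size, inverse_mass_matrix):
    kernel = build_kernel()

    def init_fn(position):
        return init(position, logdensity_fn)

    def step_fn(rng_key, state):
        return kernel(rng_key, state, logdensity_fn, step_size, inverse_mass_matrix)

    return SamplingAlgorithm(init_fn, step_fn)


def logdensity_fn(x):
    return -0.5 * x * x


old = hmc_class(logdensity_fn, step_size=0.1, inverse_mass_matrix=2.0)
new = hmc_fn(logdensity_fn, step_size=0.1, inverse_mass_matrix=2.0)
assert not isinstance(old, hmc_class)  # __new__ returned a SamplingAlgorithm
assert old.step(None, old.init(1.0)) == new.step(None, new.init(1.0))
```

Because `__new__` returns an object that is not an instance of the class, Python skips `__init__` and `isinstance(old, hmc_class)` is `False`. At runtime the class contributes little beyond a name, which is what a nominal annotation such as one on `window_adaptation` relies on.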

The reasoning is as follows. Python fosters duck and structural typing, in contrast to the nominal typing you find in, say, Java. In the case of window adaptation, we want the type to mean "HMC family": algorithms that take an `inverse_mass_matrix` and a step size. But as currently implemented, it actually means "whatever the `hmc` class or the `nuts` class does"! So the classes exist mostly so that we can name them, i.e. to use nominal typing.

Since most of our codebase is functional, from a typing perspective most samplers are callables that take in matrices, floats and pytrees and return something of the same flavour. There is no way to say "this is the type of a callable that takes in an `inverse_mass_matrix` and a step size and uses them in some consistent way", because that is neither duck typing nor structural typing; in other words, we are trying to type statically with tools that are not Pythonic. I'd suggest we replace this kind of "algorithm-level" type annotation with docs and tests; see, for example, `smc_compatibility_test`.
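To make the "docs and tests instead of annotations" idea concrete, here is a hedged, JAX-free sketch in the spirit of `smc_compatibility_test` (whose real contents are not reproduced here). Every name below is hypothetical: `toy_hmc`, `toy_nuts` and `toy_rmh` are toy deterministic samplers, `run_with_params` stands in for adaptation code like `window_adaptation`, and `satisfies_hmc_contract` is the test-time check replacing a nominal annotation. The adaptation-like function relies only on the call shape `algorithm(logdensity_fn, step_size=..., inverse_mass_matrix=...)`, which is duck typing.

```python
from typing import Callable, NamedTuple


class SamplingAlgorithm(NamedTuple):
    """An (init, step) pair, shaped like blackjax's SamplingAlgorithm."""
    init: Callable
    step: Callable


# Toy, deterministic samplers (NOT real HMC/NUTS/RMH transitions).
# rng_key is unused by these toy kernels.
def toy_hmc(logdensity_fn, step_size, inverse_mass_matrix, num_integration_steps=10):
    def init(position):
        return position

    def step(rng_key, position):
        return position + num_integration_steps * step_size * inverse_mass_matrix

    return SamplingAlgorithm(init, step)


def toy_nuts(logdensity_fn, step_size, inverse_mass_matrix, max_num_doublings=10):
    def init(position):
        return position

    def step(rng_key, position):
        return position + step_size * inverse_mass_matrix

    return SamplingAlgorithm(init, step)


def toy_rmh(logdensity_fn, sigma):
    # Not in the "HMC family": no step size or inverse mass matrix.
    def init(position):
        return position

    def step(rng_key, position):
        return position + sigma

    return SamplingAlgorithm(init, step)


def run_with_params(algorithm, logdensity_fn, position, step_size, inverse_mass_matrix):
    """Adaptation-like code with no nominal annotation on `algorithm`:
    it only relies on the keyword-argument call shape used below."""
    kernel = algorithm(
        logdensity_fn, step_size=step_size, inverse_mass_matrix=inverse_mass_matrix
    )
    return kernel.step(None, kernel.init(position))


def satisfies_hmc_contract(algorithm) -> bool:
    """Test-time check standing in for the removed type annotation."""
    try:
        run_with_params(algorithm, lambda x: -0.5 * x * x, 0.0, 0.5, 1.0)
    except TypeError:  # e.g. an unexpected keyword argument
        return False
    return True


for candidate in (toy_hmc, toy_nuts):
    assert satisfies_hmc_contract(candidate)
assert not satisfies_hmc_contract(toy_rmh)
```

The contract lives in one documented, tested place: any callable with the right call shape passes, whatever its name, and a sampler with a different parametrisation fails at test time rather than being excluded by a `Union` of class names.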

albcab commented 5 months ago

I have some suggestions around naming. @albcab, thoughts?

I'd suggest "algorithm" rather than "API"; that way we could use `as_algorithm` for all of them.
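A hypothetical sketch of why a single shared name helps: if every sampler module exposes the same `as_algorithm` entry point, the top-level aliases can be generated generically. The module objects (`SimpleNamespace` stand-ins), `_hmc_as_algorithm`, `_rmh_as_algorithm` and the toy kernels are assumptions for illustration, not blackjax code.

```python
from types import SimpleNamespace
from typing import Callable, NamedTuple


class SamplingAlgorithm(NamedTuple):
    """An (init, step) pair, shaped like blackjax's SamplingAlgorithm."""
    init: Callable
    step: Callable


def _hmc_as_algorithm(logdensity_fn, step_size, inverse_mass_matrix):
    # Toy kernel, not a real HMC transition; rng_key is unused.
    def init(position):
        return position

    def step(rng_key, position):
        return position + step_size * inverse_mass_matrix

    return SamplingAlgorithm(init, step)


def _rmh_as_algorithm(logdensity_fn, sigma):
    # Toy kernel, not a real random-walk Metropolis transition.
    def init(position):
        return position

    def step(rng_key, position):
        return position + sigma

    return SamplingAlgorithm(init, step)


# Hypothetical stand-ins for sampler modules, each exposing the same entry point.
SAMPLER_MODULES = {
    "hmc": SimpleNamespace(as_algorithm=_hmc_as_algorithm),
    "rmh": SimpleNamespace(as_algorithm=_rmh_as_algorithm),
}

# Because the name is shared, top-level aliases can be built in one loop.
top_level = {name: module.as_algorithm for name, module in SAMPLER_MODULES.items()}
```

For example, `top_level["hmc"](logdensity_fn, step_size=1, inverse_mass_matrix=2)` yields a `SamplingAlgorithm` whose `step` moves position `1` to `3`. The generic loop only works because no module needs a special-cased entry-point name.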