This PR removes the algorithm classes. These are static classes that we never use directly: we only use them by calling `__new__`, which returns a `SamplingAlgorithm`. This PR:
- replaces these classes with instances of a single class, so each module now contains three functions: `init`, `build_kernel` and `as_sampling_algorithm`. The idea is that `as_sampling_algorithm` fixes the dependencies/parameters, in particular those that are not differentiable. `init` and `build_kernel` still exist and are exposed because we need that lower-level API, especially when composing algorithms (for example, when tuning the SMC inner kernel we need to be able to change parameters on every call).
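To make the proposed layout concrete, here is a minimal, self-contained sketch in plain Python (no JAX) of one algorithm module written this way. Everything in it is hypothetical and illustrative, not BlackJAX's actual code: the toy random-walk Metropolis sampler, the `RWState` and `SamplingAlgorithm` tuples, and the `AlgorithmModule` wrapper standing in for "instances of a single class". The point it shows is that `as_sampling_algorithm` closes over fixed parameters, while `build_kernel` returns a kernel that takes them on every call.

```python
import math
import random
from typing import Callable, NamedTuple


class SamplingAlgorithm(NamedTuple):
    """Hypothetical stand-in for the (init, step) pair handed to users."""
    init: Callable
    step: Callable


class RWState(NamedTuple):
    position: float
    logdensity: float


def init(position, logdensity_fn):
    """Build the initial state from a position."""
    return RWState(position, logdensity_fn(position))


def build_kernel():
    """Low-level kernel: every parameter is passed on each call, so a caller
    such as an adaptation or SMC loop can change it between steps."""

    def kernel(rng, state, logdensity_fn, step_size):
        proposal = state.position + step_size * rng.gauss(0.0, 1.0)
        proposal_logdensity = logdensity_fn(proposal)
        # Metropolis acceptance for a symmetric proposal.
        accept_prob = math.exp(min(0.0, proposal_logdensity - state.logdensity))
        if rng.random() < accept_prob:
            return RWState(proposal, proposal_logdensity)
        return state

    return kernel


def as_sampling_algorithm(logdensity_fn, step_size):
    """High-level API: fixes logdensity_fn and step_size once and for all."""
    kernel = build_kernel()

    def init_fn(position):
        return init(position, logdensity_fn)

    def step_fn(rng, state):
        return kernel(rng, state, logdensity_fn, step_size)

    return SamplingAlgorithm(init_fn, step_fn)


class AlgorithmModule:
    """Hypothetical single wrapper class: each instance exposes one module."""

    def __init__(self, init_fn, build_kernel_fn, as_sampling_algorithm_fn):
        self.init = init_fn
        self.build_kernel = build_kernel_fn
        self._as_sampling_algorithm = as_sampling_algorithm_fn

    def __call__(self, *args, **kwargs):
        # Calling the instance reads like the old class call.
        return self._as_sampling_algorithm(*args, **kwargs)


random_walk = AlgorithmModule(init, build_kernel, as_sampling_algorithm)

# High-level usage: parameters are fixed when the algorithm is built.
logdensity_fn = lambda x: -0.5 * x * x
rng = random.Random(0)
algo = random_walk(logdensity_fn, step_size=0.5)
state = algo.init(0.0)
for _ in range(100):
    state = algo.step(rng, state)

# Low-level usage: the caller changes step_size on every call.
kernel = random_walk.build_kernel()
for step_size in (1.0, 0.5, 0.25):
    state = kernel(rng, state, logdensity_fn, step_size)
```

In this sketch, because the wrapper instance is callable, `random_walk(...)` looks exactly like the old class-based call, so top-level user code would not need to change.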
By doing this we can still call algorithms directly, like `blackjax.hmc()`. What we do lose is the (light) type annotations we currently use, for example in `window_adaptation`. I have been thinking about this kind of annotation, and I think we should remove them.
The reasoning is the following: Python fosters duck and structural typing, in contrast to the nominal typing you find in, say, Java. In the case of window adaptation, we want the type to mean "HMC family", i.e. algorithms that take an `inverse_mass_matrix` and a step size. But the way it is implemented right now, it actually means "whatever the `hmc` class or the `nuts` class does"! So the classes exist mostly so that we can name them (i.e. to use nominal typing). Since most of our codebase is functional, from a typing perspective most samplers are `Callable`s that take in matrices, doubles and pytrees and return something of the same flavour. There's no way to say "this is the type of a callable that takes in an `inverse_mass_matrix` and a step size and uses them in some consistent way", because that is neither duck typing nor structural typing; in other words, we are trying to statically type using tools that are not Pythonic. I'd suggest we replace this kind of "algorithm-level" type annotation with docs and tests. See, for example, `smc_compatibility_test`.
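As a sketch of what "docs and tests instead of nominal types" could look like, the snippet below checks a duck-typed contract rather than a class name: any factory that accepts `step_size` and `inverse_mass_matrix` and returns an `init`/`step` pair passing a few behavioural checks counts as "HMC family". It is plain Python for self-containment; the two toy factories, `make_metropolis_walk` and `check_hmc_family` are hypothetical and are not the contents of `smc_compatibility_test`.

```python
import math
import random
from typing import Callable, NamedTuple


class SamplingAlgorithm(NamedTuple):
    init: Callable
    step: Callable


def make_metropolis_walk(noise):
    """Build a hypothetical "HMC-family" factory: a Metropolis random walk whose
    per-coordinate proposal scale is step_size * sqrt(inverse_mass_matrix[i])."""

    def factory(logdensity_fn, step_size, inverse_mass_matrix):
        scales = [step_size * math.sqrt(m) for m in inverse_mass_matrix]

        def init(position):
            return (list(position), logdensity_fn(position))

        def step(rng, state):
            position, logdensity = state
            proposal = [x + s * noise(rng) for x, s in zip(position, scales)]
            proposal_logdensity = logdensity_fn(proposal)
            # Symmetric proposal, so the plain Metropolis ratio applies.
            if rng.random() < math.exp(min(0.0, proposal_logdensity - logdensity)):
                return (proposal, proposal_logdensity)
            return state

        return SamplingAlgorithm(init, step)

    return factory


gaussian_walk = make_metropolis_walk(lambda rng: rng.gauss(0.0, 1.0))
uniform_walk = make_metropolis_walk(lambda rng: rng.uniform(-1.0, 1.0))


def check_hmc_family(factory):
    """Duck-typed contract: the factory accepts step_size and
    inverse_mass_matrix keywords and returns a working init/step pair."""
    logdensity_fn = lambda x: -0.5 * sum(v * v for v in x)
    rng = random.Random(0)

    algo = factory(logdensity_fn, step_size=0.5, inverse_mass_matrix=[1.0, 2.0])
    state = algo.init([0.0, 0.0])
    for _ in range(10):
        state = algo.step(rng, state)
    position, logdensity = state
    assert len(position) == 2  # the step preserves the position's shape
    assert logdensity == logdensity_fn(position)  # cached value stays in sync

    # A zero step size must leave the chain where it started.
    frozen = factory(logdensity_fn, step_size=0.0, inverse_mass_matrix=[1.0, 2.0])
    start = frozen.init([1.0, -1.0])
    assert frozen.step(rng, start)[0] == [1.0, -1.0]


for hmc_family_member in (gaussian_walk, uniform_walk):
    check_hmc_family(hmc_family_member)
```

Here the contract lives in the test rather than in a type: a new sampler joins the "HMC family" by passing `check_hmc_family`, not by being a particular named class.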