We could add an optimiser for the CV estimator. It could use Adam from JAX's optimizers module or from Optax. The idea is to provide a good default optimiser setup that is easy to use.
- The step size could have a default value (`1e-3` is usually good). But then it could no longer be the first positional argument, as it is in the SGMCMC functions (for example `build_sgld`), unless every other argument also had a default. So it is probably better not to give it a default.
- `compile` would default to `True`. With `compile=False`, the optimiser would run as a plain Python loop.
- The optimiser returns the optimal parameters as well as an array of the loss values recorded throughout the optimisation.
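To make the proposed interface concrete, here is a minimal sketch. It assumes JAX's `jax.example_libraries.optimizers` module (older JAX releases exposed it as `jax.experimental.optimizers`). The names `build_optimizer` and `num_iters` are hypothetical and not part of the library. The sketch follows the points above: the step size is a required first argument, `compile=True` runs the whole loop under `jax.jit` with `lax.scan`, `compile=False` falls back to a Python loop, and the result is the final parameters together with an array of loss values.

```python
import jax
import jax.numpy as jnp
from jax import lax
from jax.example_libraries import optimizers


def build_optimizer(step_size, loss_fn, num_iters, compile=True):
    """Hypothetical builder (not part of the library).

    Returns run(params) -> (optimal_params, losses), where losses[i] is the
    loss evaluated just before the i-th Adam update.
    """
    # step_size has no default, so it can stay the first positional argument.
    init_fn, update_fn, get_params = optimizers.adam(step_size)
    value_and_grad = jax.value_and_grad(loss_fn)

    def step(opt_state, i):
        # Record the loss at the current parameters, then take one Adam step.
        loss, grads = value_and_grad(get_params(opt_state))
        opt_state = update_fn(i, grads, opt_state)
        return opt_state, loss

    @jax.jit
    def run_compiled(opt_state):
        # The whole optimisation loop is compiled once via lax.scan.
        return lax.scan(step, opt_state, jnp.arange(num_iters))

    def run(params):
        opt_state = init_fn(params)
        if compile:
            opt_state, losses = run_compiled(opt_state)
        else:
            # Plain Python loop: slower, but easy to debug or print from.
            losses = []
            for i in range(num_iters):
                opt_state, loss = step(opt_state, i)
                losses.append(loss)
            losses = jnp.stack(losses)
        return get_params(opt_state), losses

    return run


# Usage: minimise a toy quadratic whose minimum is at params = [3., 3.].
def loss_fn(params):
    return jnp.sum((params - 3.0) ** 2)


run = build_optimizer(1e-1, loss_fn, num_iters=300)
params, losses = run(jnp.zeros(2))
```

Swapping in Optax would mostly change the inner step: `optax.adam(step_size)` provides an `init`/`update` pair, and the resulting updates are applied with `optax.apply_updates`. For the CV estimator, the returned parameters would presumably serve as the centering value.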
On the other hand, perhaps this could be considered out of the scope of the library?