TuringLang / AdvancedHMC.jl

Robust, modular and efficient implementation of advanced Hamiltonian Monte Carlo algorithms
https://turinglang.org/AdvancedHMC.jl/
MIT License

In sample progress=true prints more than a progress-bar #364

Open sebapersson opened 7 months ago

sebapersson commented 7 months ago

Thanks for this package. I tried the example in the README with progress=true (on Julia 1.10 and AdvancedHMC v0.6.1):

using AdvancedHMC, ForwardDiff
using LogDensityProblems
using LinearAlgebra

# Define the target distribution using the `LogDensityProblems` interface
struct LogTargetDensity
    dim::Int
end
LogDensityProblems.logdensity(p::LogTargetDensity, θ) = -sum(abs2, θ) / 2  # standard multivariate normal
LogDensityProblems.dimension(p::LogTargetDensity) = p.dim
LogDensityProblems.capabilities(::Type{LogTargetDensity}) = LogDensityProblems.LogDensityOrder{0}()

# Choose parameter dimensionality and initial parameter value
D = 10; initial_θ = rand(D)
ℓπ = LogTargetDensity(D)

# Set the number of samples to draw and warmup iterations
n_samples, n_adapts = 2_000, 1_000

# Define a Hamiltonian system
metric = DiagEuclideanMetric(D)
hamiltonian = Hamiltonian(metric, ℓπ, ForwardDiff)

# Define a leapfrog solver, with the initial step size chosen heuristically
initial_ϵ = find_good_stepsize(hamiltonian, initial_θ)
integrator = Leapfrog(initial_ϵ)

# Define an HMC sampler with the following components
#   - multinomial sampling scheme,
#   - generalised No-U-Turn criteria, and
#   - windowed adaptation for the step size and diagonal mass matrix
kernel = HMCKernel(Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn()))
adaptor = StanHMCAdaptor(MassMatrixAdaptor(metric), StepSizeAdaptor(0.8, integrator))

# Run the sampler to draw samples from the specified Gaussian, where
#   - `samples` will store the samples
#   - `stats` will store diagnostic statistics for each sample
samples, stats = sample(hamiltonian, kernel, initial_θ, n_samples, adaptor, n_adapts; progress=true, verbose=false)  

This shows a progress bar, but in addition the following information is printed:

  iterations:                                   1272
  ratio_divergent_transitions:                  0.0
  ratio_divergent_transitions_during_adaption:  0.0
  n_steps:                                      7
  is_accept:                                    true
  acceptance_rate:                              0.7875153541240331
  log_density:                                  -7.884029732795562
  hamiltonian_energy:                           12.447784679126514
  hamiltonian_energy_error:                     0.3198459860125489
  max_hamiltonian_energy_error:                 0.4934198534404306
  tree_depth:                                   3
  numerical_error:                              false
  step_size:                                    0.6715960136781923

Is this intended, and is there any way to print only the progress bar?

torfjelde commented 6 months ago

This is possible through the AbstractMCMC interface by passing a different `callback` keyword argument.

To do it properly, one should implement a callback similar to

https://github.com/TuringLang/AdvancedHMC.jl/blob/37f0995ea466ec5a9c9b09ecacc04c317ba77886/src/abstractmcmc.jl#L175-L194

https://github.com/TuringLang/AdvancedHMC.jl/blob/37f0995ea466ec5a9c9b09ecacc04c317ba77886/src/abstractmcmc.jl#L203

but which only shows the progress bar, without all the extra information.
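For the "proper" route, a minimal sketch of such a callback might look like the following. It assumes only the `ProgressMeter` package and the callback signature `callback(rng, model, sampler, transition, state, iteration; kwargs...)` documented by AbstractMCMC; `BarOnlyCallback` is a hypothetical name, not something AdvancedHMC provides.

```julia
using ProgressMeter

# Hypothetical callback (not part of AdvancedHMC): advances a progress bar
# once per iteration and prints none of the per-sample statistics.
struct BarOnlyCallback
    pm::ProgressMeter.Progress
end

# Construct with the total number of iterations (warmup + sampling).
BarOnlyCallback(n::Integer; desc::AbstractString = "Sampling: ") =
    BarOnlyCallback(ProgressMeter.Progress(n; desc = desc))

# Called after every step; extra keyword arguments forwarded by the
# sampler (e.g. `nadapts`) are accepted and ignored.
function (cb::BarOnlyCallback)(rng, model, sampler, transition, state, i; kwargs...)
    ProgressMeter.next!(cb.pm)
    return nothing
end
```

It would be passed as `callback = BarOnlyCallback(n_adapts + n_samples)` together with `progress = false`: supplying any callback already bypasses `HMCProgressCallback`, and leaving `progress = true` would draw AbstractMCMC's own bar on top of this one.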

But because the AbstractMCMC-interface implementation only uses `HMCProgressCallback` when no callback is specified manually, see:

https://github.com/TuringLang/AdvancedHMC.jl/blob/37f0995ea466ec5a9c9b09ecacc04c317ba77886/src/abstractmcmc.jl#L82-L85

we can simply provide a callback that does nothing and set `progress=true`. That falls through to the default progress meter in `AbstractMCMC.sample`, which shows only the bar: exactly what you want.
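The decision described above can be summarised with a toy function. This is a simplified, hypothetical model of the branch (the name `progress_display` is made up for illustration), not AdvancedHMC's actual source:

```julia
# Toy model (hypothetical, not AdvancedHMC's code) of which progress display
# ends up active, depending on the `callback` and `progress` keywords.
function progress_display(callback, progress::Bool)
    if callback === nothing
        # No user callback: AdvancedHMC falls back to its HMCProgressCallback,
        # whose output also lists the per-sample statistics.
        # (Simplified: the `progress` flag is ignored in this branch.)
        return :hmc_progress_callback
    end
    # A user callback was supplied, so HMCProgressCallback is skipped and
    # `progress` controls AbstractMCMC's default, bar-only meter.
    return progress ? :abstractmcmc_bar : :none
end
```

With a do-nothing callback and `progress = true`, the toy returns `:abstractmcmc_bar`, which is the behaviour the workaround relies on.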

In short, try the following:

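using AbstractMCMC, LogDensityProblemsAD

# Wrap the target (with ForwardDiff gradients) as an AbstractMCMC model, and
# bundle the kernel, metric and adaptor defined above into a sampler
model = AbstractMCMC.LogDensityModel(LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), ℓπ))
sampler = HMCSampler(kernel, metric, adaptor)

# HACK: a callback that does nothing to avoid hitting `HMCProgressCallback`.
# AbstractMCMC forwards keyword arguments to callbacks, so accept `kwargs...`.
callback(rng, model, sampler, transition, state, i; kwargs...) = nothing

samples = AbstractMCMC.sample(
    model,
    sampler,
    n_adapts + n_samples;
    nadapts = n_adapts,
    initial_params = initial_θ,
    callback = callback,
    progress = true,  # will use the default progress meter in `AbstractMCMC`
)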