TuringLang / AdvancedHMC.jl

Robust, modular and efficient implementation of advanced Hamiltonian Monte Carlo algorithms
https://turinglang.org/AdvancedHMC.jl/
MIT License

Component Arrays with DenseEuclideanMetric fails #344

Open erathorn opened 1 year ago

erathorn commented 1 year ago

Dear Team,

Thank you for your amazing package!

I tried out your wonderful package and noticed that combining ComponentArrays with a DenseEuclideanMetric does not work. Here is an example adapted from the tests:

using AdvancedHMC, ComponentArrays, ForwardDiff, LogDensityProblems

p1 = ComponentVector(μ = 2.0, σ = 1)
struct DemoProblemComponentArrays end

function LogDensityProblems.logdensity(::DemoProblemComponentArrays, p::ComponentArray)
    return -((1 - p.μ) / p.σ)^2
end
LogDensityProblems.dimension(::DemoProblemComponentArrays) = 2
LogDensityProblems.capabilities(::Type{DemoProblemComponentArrays}) =
    LogDensityProblems.LogDensityOrder{0}()

ℓπ = DemoProblemComponentArrays()

# Define a Hamiltonian system
D = length(p1)          # number of parameters
metric = DenseEuclideanMetric(D) # !!! THE ORIGINAL TESTS USE DiagEuclideanMetric HERE !!!

# choose AD framework or provide a function manually
hamiltonian = Hamiltonian(metric, ℓπ, Val(:ForwardDiff); x = p1)

# Choose an initial step size heuristically (to be used by a leapfrog integrator)
initial_ϵ = find_good_stepsize(hamiltonian, p1)

This code fails with:

ERROR: MethodError: no method matching AdvancedHMC.PhasePoint(::ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(μ = 1, σ = 2)}}}, ::ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(μ = 1, σ = 2)}}}, ::AdvancedHMC.DualValue{Float64, ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(μ = 1, σ = 2)}}}}, ::AdvancedHMC.DualValue{Float64, Vector{Float64}})
Closest candidates are:
  AdvancedHMC.PhasePoint(::T, ::T, ::V, ::V) where {T, V} at ~/JULIAPACKAGES/103/packages/AdvancedHMC/jiCaS/src/hamiltonian.jl:53
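To see why no method matches, here is a minimal sketch that mirrors only the constructor signature shown in the error message, `PhasePoint(::T, ::T, ::V, ::V) where {T, V}`. `MyDualValue` and `MyPhasePoint` are hypothetical, simplified stand-ins, not AdvancedHMC's actual definitions. Because both `DualValue` arguments must share a single concrete type `V`, a gradient stored as a `ComponentVector` and another stored as a plain `Vector` cannot both match.

```julia
using ComponentArrays

# Hypothetical stand-ins (NOT AdvancedHMC's real types) that reproduce the
# signature from the error message: PhasePoint(::T, ::T, ::V, ::V) where {T, V}
struct MyDualValue{V,G}
    value::V      # a scalar such as a log density or kinetic energy
    gradient::G   # its gradient
end

struct MyPhasePoint{T,V}
    θ::T   # position
    r::T   # momentum
    ℓπ::V  # log density and its gradient w.r.t. θ
    ℓκ::V  # kinetic energy and its gradient w.r.t. r
end

θ = ComponentVector(μ = 2.0, σ = 1.0)
r = ComponentVector(μ = 0.1, σ = -0.3)

ℓπ = MyDualValue(-1.0, θ)                 # gradient kept as a ComponentVector
ℓκ_bad = MyDualValue(0.05, [0.1, -0.3])   # gradient as a plain Vector
ℓκ_ok = MyDualValue(0.05, copy(r))        # gradient as a ComponentVector

# V would have to be two different concrete types at once, so no method applies.
println(applicable(MyPhasePoint, θ, r, ℓπ, ℓκ_bad))  # false
println(applicable(MyPhasePoint, θ, r, ℓπ, ℓκ_ok))   # true
```

Julia's default outer constructor for `MyPhasePoint` has exactly this `where {T, V}` form, and a type variable used for several arguments only matches when they all have the same concrete type.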

I think the problem is here: https://github.com/TuringLang/AdvancedHMC.jl/blob/eb9b2e0d60ef3dd85768d6e6a9f19de15b8f7130/src/hamiltonian.jl#L45C4-L45C4

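The result of this operation is a plain Vector rather than a ComponentVector. As the error message shows, the kinetic-energy DualValue (the last argument) then has a different type from the position DualValue, so no PhasePoint method matches. Calling safe_rsimilar (https://github.com/TuringLang/AdvancedHMC.jl/blob/eb9b2e0d60ef3dd85768d6e6a9f19de15b8f7130/src/hamiltonian.jl#L94) on the result of the multiplication might solve the issue.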
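A minimal sketch of the idea behind the proposed fix, assuming the goal is to give the product of the dense inverse mass matrix and the momentum `r` the same array type and axes as `r`. It uses `similar` plus in-place broadcasting instead of `safe_rsimilar`, and `dense_kinetic_gradient` is a hypothetical helper, not a function in AdvancedHMC.

```julia
using ComponentArrays

# Hypothetical helper (NOT part of AdvancedHMC): computes Minv * r but returns
# the result with the same array type and axes as the momentum r.
function dense_kinetic_gradient(Minv::AbstractMatrix, r::AbstractVector)
    out = similar(r)   # for a ComponentVector, this keeps the axes (μ, σ)
    out .= Minv * r    # in the issue's setting the raw product is a plain Vector
    return out
end

r = ComponentVector(μ = 0.5, σ = -1.0)  # momentum with named axes
Minv = [2.0 0.5; 0.5 1.0]               # dense, symmetric inverse mass matrix

g = dense_kinetic_gradient(Minv, r)
println(typeof(g) == typeof(r))  # true
println(g.μ, " ", g.σ)           # 0.5 -0.75
```

Writing into `similar(r)` keeps the axes when `r` is a ComponentVector and simply returns a Vector when `r` is a Vector, so the same code path would serve both cases at the cost of one extra copy.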

If you agree that this is the right fix, I am happy to submit a PR.

yebai commented 1 year ago

@erathorn Thanks for the kind words. Of course, a PR with some tests is welcome.