TuringLang / DynamicPPL.jl

Implementation of domain-specific language (DSL) for dynamic probabilistic programming

Give user an option to use `SimpleVarInfo` with `sample` function #606

Closed: sunxd3 closed this 1 month ago

sunxd3 commented 4 months ago

Partially addresses https://github.com/TuringLang/Turing.jl/issues/2213

An example

julia> using AbstractMCMC, DynamicPPL

julia> model = DynamicPPL.TestUtils.DEMO_MODELS[1]
Model{typeof(DynamicPPL.TestUtils.demo_dot_assume_dot_observe), (:x, Symbol("##arg#289")), (), (), Tuple{Vector{Float64}, DynamicPPL.TypeWrap{Vector{Float64}}}, Tuple{}, DefaultContext}(DynamicPPL.TestUtils.demo_dot_assume_dot_observe, (x = [1.5, 2.0], var"##arg#289" = DynamicPPL.TypeWrap{Vector{Float64}}()), NamedTuple(), DefaultContext())

julia> chn = sample(model, SampleFromUniform(), 10; trace_type = SimpleVarInfo)
10-element Vector{SimpleVarInfo{OrderedDict{Any, Any}, Float64, DynamicPPL.NoTransformation}}:
 SimpleVarInfo(OrderedDict{Any, Any}(s[1] => 0.18821524284155777, s[2] => 0.33957731437677985, m[1] => 0.027047762098432387, m[2] => -0.3396883816169604), -27.049779424102084)

Work with Turing

This should work with Turing's inference pipeline with almost no modification. The only change is to widen the method signature at https://github.com/TuringLang/Turing.jl/blob/56f64ec5909cec4a5ded4e28555c2b289020bbe1/src/mcmc/Inference.jl#L319 to

function getparams(model::DynamicPPL.Model, vi::Union{DynamicPPL.VarInfo, DynamicPPL.SimpleVarInfo})

This allows bundle_samples to use this function.
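
A minimal sketch of why widening the signature is enough: in Julia, a method whose argument is annotated with a `Union` accepts either type, so a caller such as `bundle_samples` needs no changes. `ToyVarInfo`, `ToySimpleVarInfo`, `toy_pairs`, and `toy_getparams` are hypothetical stand-ins for illustration, not DynamicPPL or Turing APIs.

```julia
# Hypothetical stand-ins for DynamicPPL.VarInfo and DynamicPPL.SimpleVarInfo.
struct ToyVarInfo
    names::Vector{Symbol}
    vals::Vector{Float64}
end

struct ToySimpleVarInfo
    values::Dict{Symbol,Float64}
end

# Per-type accessors that return (name, value) tuples.
toy_pairs(vi::ToyVarInfo) = collect(zip(vi.names, vi.vals))
toy_pairs(vi::ToySimpleVarInfo) = [(k, v) for (k, v) in vi.values]

# One method with a Union-typed argument serves both trace types.
toy_getparams(vi::Union{ToyVarInfo,ToySimpleVarInfo}) = sort(toy_pairs(vi); by=first)

a = toy_getparams(ToyVarInfo([:s, :m], [1.0, 0.5]))
b = toy_getparams(ToySimpleVarInfo(Dict(:s => 1.0, :m => 0.5)))
println(a == b)
```

Both calls dispatch to the same method, so downstream code sees the same parameter list regardless of which trace type produced it.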


julia> AbstractMCMC.step(Random.default_rng(), model, DynamicPPL.Sampler(HMC(0.2, 20), DynamicPPL.Selector()); trace_type = SimpleVarInfo)
(Turing.Inference.Transition{Vector{Tuple{AbstractPPL.VarName{sym, Accessors.IndexLens{Tuple{Int64}}} where sym, Float64}}, Float64, @NamedTuple{n_steps::Int64, is_accept::Bool, acceptance_rate::Float64, log_density::Float64, hamiltonian_energy::Float64, hamiltonian_energy_error::Float64, numerical_error::Bool, step_size::Float64, nom_step_size::Float64}}(Tuple{AbstractPPL.VarName{sym, Accessors.IndexLens{Tuple{Int64}}} where sym, Float64}[(s[1], 1.6822421472438154), (s[2], 0.8921514354736135), (m[1], -0.1272569385613846), (m[2], 0.8103126419880976)], -7.598386060870171, (n_steps = 20, is_accept = true, acceptance_rate = 1.0, log_density = -7.598386060870171, hamiltonian_energy = 9.707595087582115, hamiltonian_energy_error = -0.0094109681431096, numerical_error = true, step_size = 0.2, nom_step_size = 0.2)), Turing.Inference.HMCState{DynamicPPL.SimpleVarInfo{OrderedCollections.OrderedDict{Any, Float64}, Float64, DynamicPPL.DynamicTransformation}, AdvancedHMC.HMCKernel{AdvancedHMC.FullMomentumRefreshment, AdvancedHMC.Trajectory{AdvancedHMC.EndPointTS, AdvancedHMC.Leapfrog{Float64}, AdvancedHMC.FixedNSteps}}, AdvancedHMC.Hamiltonian{AdvancedHMC.UnitEuclideanMetric{Float64, Tuple{Int64}}, AdvancedHMC.GaussianKinetic, Base.Fix1{typeof(LogDensityProblems.logdensity), LogDensityProblemsADForwardDiffExt.ForwardDiffLogDensity{LogDensityFunction{DynamicPPL.SimpleVarInfo{OrderedCollections.OrderedDict{Any, Any}, Float64, DynamicPPL.DynamicTransformation}, DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_dot_assume_dot_observe), (:x, Symbol("##arg#289")), (), (), Tuple{Vector{Float64}, DynamicPPL.TypeWrap{Vector{Float64}}}, Tuple{}, DynamicPPL.DefaultContext}, DynamicPPL.SamplingContext{DynamicPPL.Sampler{HMC{AutoForwardDiff{nothing, Nothing}, (), AdvancedHMC.UnitEuclideanMetric}}, DynamicPPL.DefaultContext, TaskLocalRNG}}, ForwardDiff.Chunk{4}, ForwardDiff.Tag{DynamicPPL.DynamicPPLTag, Float64}, ForwardDiff.GradientConfig{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag, Float64}, 
Float64, 4, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag, Float64}, Float64, 4}}}}}, Turing.Inference.var"#∂logπ∂θ#32"{LogDensityProblemsADForwardDiffExt.ForwardDiffLogDensity{LogDensityFunction{DynamicPPL.SimpleVarInfo{OrderedCollections.OrderedDict{Any, Any}, Float64, DynamicPPL.DynamicTransformation}, DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_dot_assume_dot_observe), (:x, Symbol("##arg#289")), (), (), Tuple{Vector{Float64}, DynamicPPL.TypeWrap{Vector{Float64}}}, Tuple{}, DynamicPPL.DefaultContext}, DynamicPPL.SamplingContext{DynamicPPL.Sampler{HMC{AutoForwardDiff{nothing, Nothing}, (), AdvancedHMC.UnitEuclideanMetric}}, DynamicPPL.DefaultContext, TaskLocalRNG}}, ForwardDiff.Chunk{4}, ForwardDiff.Tag{DynamicPPL.DynamicPPLTag, Float64}, ForwardDiff.GradientConfig{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag, Float64}, Float64, 4, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag, Float64}, Float64, 4}}}}}}, AdvancedHMC.PhasePoint{Vector{Float64}, AdvancedHMC.DualValue{Float64, Vector{Float64}}}, AdvancedHMC.Adaptation.NoAdaptation}(DynamicPPL.SimpleVarInfo{OrderedCollections.OrderedDict{Any, Float64}, Float64, DynamicPPL.DynamicTransformation}(OrderedCollections.OrderedDict{Any, Float64}(s[1] => 0.5201275150675577, s[2] => -0.11411939010121486, m[1] => -0.1272569385613846, m[2] => 0.8103126419880976), -7.598386060870171, DynamicPPL.DynamicTransformation()), 1, AdvancedHMC.HMCKernel{AdvancedHMC.FullMomentumRefreshment, AdvancedHMC.Trajectory{AdvancedHMC.EndPointTS, AdvancedHMC.Leapfrog{Float64}, AdvancedHMC.FixedNSteps}}(AdvancedHMC.FullMomentumRefreshment(), Trajectory{AdvancedHMC.EndPointTS}(integrator=Leapfrog(ϵ=0.2), tc=AdvancedHMC.FixedNSteps(20))), Hamiltonian(metric=UnitEuclideanMetric([1.0, 1.0, 1.0, 1.0]), kinetic=AdvancedHMC.GaussianKinetic()), AdvancedHMC.PhasePoint{Vector{Float64}, AdvancedHMC.DualValue{Float64, Vector{Float64}}}([0.5201275150675577, -0.11411939010121486, -0.1272569385613846, 0.8103126419880976], 
[0.2195756027868936, 2.0379584520488434, 0.11023899654296329, 0.06911815570849916], AdvancedHMC.DualValue{Float64, Vector{Float64}}(-7.598386060870171, [0.4248179767985423, -1.5238746846234323, -1.042961549856163, -0.4252357850238818]), AdvancedHMC.DualValue{Float64, Vector{Float64}}(-2.109209026711945, [-0.2195756027868936, -2.0379584520488434, -0.11023899654296329, -0.06911815570849916])), AdvancedHMC.Adaptation.NoAdaptation()))

julia> chn = sample(model, HMC(0.2, 20), 10; trace_type = SimpleVarInfo)
Chains MCMC chain (10×14×1 Array{Float64, 3}):

Iterations        = 1:1:10
Number of chains  = 1
Samples per chain = 10
Wall duration     = 2.04 seconds
Compute duration  = 2.04 seconds
parameters        = s[1], s[2], m[1], m[2]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   ess_per_sec 
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64       Float64 

        s[1]    1.6551    0.9727    0.3219     8.1555    10.0000    1.6048        3.9939
        s[2]    0.9875    0.1954    0.0942     4.5670    10.0000    1.5820        2.2365
        m[1]    0.7053    1.0723    0.4889     5.0832    10.0000    1.4156        2.4893
        m[2]    1.1302    0.5511    0.1743    10.0000    10.0000    0.9406        4.8972

  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

        s[1]    0.6404    0.9317    1.5316    2.0851    3.5047
        s[2]    0.7580    0.8095    1.0099    1.1126    1.3119
        m[1]   -0.9806   -0.1058    1.0414    1.3320    2.2166
        m[2]    0.3819    0.7973    1.0316    1.6741    1.8228

sunxd3 commented 4 months ago

An alternative approach I can envision is fully adopting LogDensityFunction in Turing through the ExternalSampler interface, but this might require much more serious work encapsulating the InferenceAlgorithms. (Do we have plans to do this in the coming months?)

Also, for SimpleVarInfo I opted for OrderedDict. Using NamedTuple requires predetermining the variable names (ref https://github.com/TuringLang/DynamicPPL.jl/blob/48487cc471543030699e3a39776dd939756d0165/src/simple_varinfo.jl#L62-L65); in principle this can be done, but it needs a bit of work. This also means that SimpleVarInfo, when used through the changes proposed in this PR, may be less performant than a NamedTuple-backed one.
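
To illustrate the trade-off with plain Julia containers rather than `SimpleVarInfo` itself (an assumption made to keep the sketch self-contained): a dictionary can record variables as they are encountered, whereas a `NamedTuple` carries its field names in its type, so they have to be known when it is constructed. `Dict` stands in for `OrderedDict` here, and the string keys are placeholders, not real `VarName`s.

```julia
# A dictionary-backed trace can grow as variables are encountered.
dict_trace = Dict{Any,Any}()
dict_trace["s[1]"] = 0.19   # placeholder keys, not real VarNames
dict_trace["m[1]"] = 0.03

# A NamedTuple fixes its field names (and value types) in its type,
# so every variable must be supplied up front.
nt_trace = (s = [0.19, 0.34], m = [0.03, -0.34])

# "Adding" a variable builds a new NamedTuple of a different type.
nt_extended = merge(nt_trace, (z = 1.0,))

println(length(dict_trace))
println(keys(nt_trace))
println(typeof(nt_extended) == typeof(nt_trace))
```

The last line shows that extending a `NamedTuple` changes its type, which is why the variable names need to be settled before sampling starts.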

sunxd3 commented 4 months ago

@torfjelde @yebai @devmotion does this PR make sense? If this is desirable, then I'll extend tests in https://github.com/TuringLang/DynamicPPL.jl/blob/master/test/sampler.jl.

coveralls commented 4 months ago

Pull Request Test Coverage Report for Build 9092154133

Changes Missing Coverage    Covered Lines    Changed/Added Lines    %
src/simple_varinfo.jl       0                2                      0.0%
src/sampler.jl              9                15                     60.0%
Total                       9                17                     52.94%

Files with Coverage Reduction    New Missed Lines    %
src/sampler.jl                   2                   84.75%
Total                            2

Totals
Change from base Build 9062140401: -1.0%
Covered Lines: 2775
Relevant Lines: 3582

💛 - Coveralls

coveralls commented 4 months ago

Pull Request Test Coverage Report for Build 9116442883

Changes Missing Coverage    Covered Lines    Changed/Added Lines    %
src/simple_varinfo.jl       0                1                      0.0%
src/sampler.jl              17               19                     89.47%
Total                       17               20                     85.0%

Totals
Change from base Build 9099752668: 0.03%
Covered Lines: 2660
Relevant Lines: 3428

💛 - Coveralls

yebai commented 4 months ago

cc @willtebbutt

yebai commented 3 months ago

@willtebbutt can you help review this PR?

torfjelde commented 3 months ago

Having a look @sunxd3 :+1:

sunxd3 commented 3 months ago

@torfjelde is my understanding of SimpleVarInfo with NamedTuple correct at https://github.com/TuringLang/DynamicPPL.jl/pull/606#issuecomment-2111867290?

sunxd3 commented 3 months ago

> we need to make it possible to use the NamedTuple version

I agree.

Other than performance, I thought SimpleVarInfo is also less error-prone for AD (correct me if wrong), but I am unsure whether the AbstractDict version of SimpleVarInfo works better than VarInfo.

torfjelde commented 3 months ago

> is my understanding of SimpleVarInfo with NamedTuple correct at

Yep! If you don't "seed" SimpleVarInfo{<:NamedTuple} with the correct values, then it will only be sensible for models containing only varnames of the form VarName{sym,typeof(identity)}. Now that we have "debugging" capabilities, we could "check" this, but that would ofc not be 100% reliable.

> Other than performance, I thought SimpleVarInfo is also less error-prone for AD (correct me if wrong)

Not really. Or rather, I don't think VarInfo is particularly error-prone for AD either. SimpleVarInfo{<:NamedTuple} might be less error-prone for some of the more recent AD backends, e.g. Tapir.jl and Enzyme.jl, but only because it improves type stability (which is not the case for SimpleVarInfo{<:Dict}).
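
The type-stability point can be sketched with plain containers (an illustration that assumes nothing about DynamicPPL internals; `sum_nt` and `sum_dict` are hypothetical helpers): field access on a `NamedTuple` has a concrete inferred type, while lookups in a `Dict{Any,Any}` do not.

```julia
nt = (s = 1.0, m = 0.5)
d = Dict{Any,Any}(:s => 1.0, :m => 0.5)

# The same computation against each container.
sum_nt(v) = v.s + v.m
sum_dict(v) = v[:s] + v[:m]

# Ask the compiler which return type it infers for each container type.
inferred_nt = Base.return_types(sum_nt, (typeof(nt),))
inferred_dict = Base.return_types(sum_dict, (typeof(d),))

println(inferred_nt)
println(inferred_dict)
```

The `NamedTuple` version should be inferred as `Float64`, while the `Dict{Any,Any}` version comes back as `Any`, which is the kind of instability that some AD backends handle less well.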

sunxd3 commented 3 months ago

Let me look into it and see if I can make the NamedTuple variant of SimpleVarInfo work, or at least write up clear TODOs.

torfjelde commented 3 months ago

> Let me look into it and see if I can make the NamedTuple variant of SimpleVarInfo work, or at least write up clear TODOs.

Lovely :) And I don't mean to be negative about this, btw. I'm just not a huge fan of adding additional kwargs, etc. unless there's a clear reason to use them (because otherwise nobody will ever use them, and they suddenly end up riddled with uncaught bugs because nobody uses them). So if we're going to add an additional kwarg to use SimpleVarInfo, we should simultaneously prove that it has utility, i.e.:

  1. Make it work with SimpleVarInfo{<:NamedTuple}, because that definitively has utility.
  2. Before we merge this into DynamicPPL.jl, test it properly in Turing.jl and make sure that SimpleVarInfo is indeed used and that nowhere do we implicitly convert to VarInfo.

sunxd3 commented 3 months ago

Sounds good. I started this as a quick-and-dirty prototype; it definitely needs more work before we can justify complicating the interface.

yebai commented 1 month ago

This is no longer necessary since Tapir now works with all tracing types.