TuringLang / SliceSampling.jl

Slice sampling algorithms in Julia
https://turinglang.org/SliceSampling.jl/
MIT License

Slice Sampling on Turing model returns constant LP values in Chain #15

Open · dlakelan opened this issue 1 day ago

dlakelan commented 1 day ago

Rather than returning the actual lp values in the returned chain, the sampler returns what appears to be the lp of the initial values for every sample. Here is an MWE:

```julia
using Turing, SliceSampling, StatsPlots

@model function foo()
    a ~ MvNormal(fill(0,3),1.0)
end

sam = sample(
    foo(),
    externalsampler(SliceSampling.HitAndRun(SliceSteppingOut(2.0))),
    10;
    initial_params = fill(10.0, 3),
)

plot(sam["a[1]"])  # trace of a[1]: moves around
plot(sam[:lp])     # trace of lp: a flat horizontal line
```

This lp plot is a horizontal line, whereas the samples of a[1] clearly move around and so should have different lp values.

This is with SliceSampling 0.6.1 and Turing 0.34.1
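Because the model is just a 3-dimensional standard normal, its log density can be written by hand, which helps interpret a flat lp trace. The plain-Julia sketch below (the name `logdensity` is hypothetical, not part of Turing or SliceSampling.jl) gives about -152.76 at the MWE's `initial_params = fill(10.0, 3)`. That agrees with the first row of the `logjoint` output posted later in this thread, while draws that move toward the origin should have lp approaching the mode's value of about -2.76.

```julia
# Hypothetical hand-written log density for the MWE's model,
# a ~ MvNormal(zeros(3), I): independent standard normals with unit variance.
# A sketch for checking numbers, not a Turing or Distributions.jl API.
logdensity(a) = -0.5 * sum(abs2, a) - 0.5 * length(a) * log(2π)

lp_init = logdensity(fill(10.0, 3))   # at the MWE's initial_params, ≈ -152.757
lp_mode = logdensity(zeros(3))        # at the mode, ≈ -2.757
```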

Red-Portal commented 1 day ago

Hi @torfjelde @mhauru, is there an additional interface I have to implement on my side to fix this, or is it a bug on the Turing side?

torfjelde commented 1 day ago

This honestly looks like a bug on Turing.jl's end; I don't think lp should even be in the chain here 🤷. At least with the current implementation, that's not what externalsampler is intended to do.

dlakelan commented 1 day ago

Having the lp values seems really important: one of the easiest ways to diagnose convergence is to check that lp has become stationary and settled in the same region for all chains.

Red-Portal commented 1 day ago

@torfjelde I was guessing that this function would be invoked by Turing to recompute lp whenever necessary, but I guess not?

@dlakelan lp can certainly be used for that purpose, but it doesn't necessarily need to be, and it isn't obviously the best quantity for doing so, no? Is looking at the $\widehat{R}$ of the parameters insufficient?

dlakelan commented 1 day ago

@Red-Portal the lp trace gives you a trace plot, whereas $\widehat{R}$ gives you a single summary statistic for the entire chain. If $\widehat{R}$ is not 1, it doesn't tell you much about what went wrong. For example, maybe 5 out of 6 chains converged to one lp region while the 6th got stuck; or maybe, if you keep only the samples after the 100th, all chains have converged to the same region and the $\widehat{R}$ of that subset is 1.

I find the traceplot of lp much more informative than summary stats.
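To make the contrast concrete, here is a minimal sketch of the classic (non-split) Gelman-Rubin $\widehat{R}$ in plain Julia, assuming the draws are stored as an iterations × chains matrix. It is not the rank-normalized split-$\widehat{R}$ that MCMCChains reports, and the six lp traces are made-up numbers: five chains wiggle around -2 and a sixth is stuck around -50.

```julia
using Statistics  # mean, var

# Classic Gelman-Rubin R-hat (no chain splitting, no rank normalization).
# `draws` is an n × m matrix: n iterations for each of m chains.
function rhat_basic(draws::AbstractMatrix{<:Real})
    n = size(draws, 1)
    W = mean(var(draws; dims = 1))             # mean within-chain variance
    B = n * var(vec(mean(draws; dims = 1)))    # between-chain variance
    var_plus = (n - 1) / n * W + B / n         # pooled variance estimate
    return sqrt(var_plus / W)
end

# Synthetic lp traces (illustrative only): an alternating wiggle around -2,
# and one chain shifted to around -50 to mimic a chain stuck elsewhere.
n = 100
wiggle = [isodd(i) ? -1.0 : 1.0 for i in 1:n]
good = hcat([wiggle .- 2.0 for _ in 1:6]...)
stuck = hcat([wiggle .- 2.0 for _ in 1:5]..., wiggle .- 50.0)

rhat_good = rhat_basic(good)     # slightly below 1: all chains agree
rhat_stuck = rhat_basic(stuck)   # far above 1, but says nothing about which chain
```

Both calls return a single number; finding out that it is the sixth chain that is stuck still requires looking at the traces themselves, which is the argument made above.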

Red-Portal commented 1 day ago

> I find the traceplot of lp much more informative than summary stats.

But wouldn't any trace, such as that of one of the parameters, do the same job?

dlakelan commented 1 day ago

No, it's entirely possible for some parameters to be fully converged and others to be stuck in a local optimum or wandering around lost. LP depends on ALL the parameters.

Red-Portal commented 1 day ago

Ah you're talking in terms of comparing across chains. Okay yes that makes sense.

torfjelde commented 23 hours ago

> having the lp values seems really important

Sure! But these are quantities that can easily be computed after the fact too :)

```julia
julia> @model function foo()
           a ~ MvNormal(fill(0,3),1.0)
       end
foo (generic function with 2 methods)

julia> sam = sample(foo(),externalsampler(SliceSampling.HitAndRun(SliceSteppingOut(2.))),10,initial_params=fill(10.0,3))
Sampling 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| Time: 0:00:01
Chains MCMC chain (10×4×1 Array{Float64, 3}):

Iterations        = 1:1:10
Number of chains  = 1
Samples per chain = 10
Wall duration     = 1.93 seconds
Compute duration  = 1.93 seconds
parameters        = a[1], a[2], a[3]
internals         = lp

Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   ess_per_sec 
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64       Float64 

        a[1]    4.2874    2.7543    0.8710     8.2309    10.0000    1.3090        4.2735
        a[2]    4.8392    5.4982    2.7238     4.2663    10.0000    1.5820        2.2151
        a[3]    5.8212    2.4718    1.0158     6.3101    10.0000    1.3013        3.2763

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

        a[1]   -0.0302    3.2358    4.1998    5.2023    9.1064
        a[2]   -3.5720    0.1340    6.1599    8.5229   11.5717
        a[3]    2.9935    3.7541    5.4587    7.8328    9.6795

julia> logjoint(foo(), sam)
10×1 Matrix{Float64}:
 -152.75681559961401
 -124.57841649146756
  -79.9584047071337
  -63.01607272527707
  -37.45955307019793
  -37.33701541789175
  -33.948795464804974
  -31.05996528484012
  -21.954809961421393
  -21.59660608994128
```

Long term, we definitely want to introduce some convenient way for external samplers to save more information in the chains, but right now I think the best approach is to compute these things after the fact.

dlakelan commented 20 hours ago

That's a super useful function to know about; thanks for the pointer.

My only concern with computing it after the fact is when the model is expensive, for example if you have to solve a PDE for a minute to get the lp. I think we tend to have something like a linear regression in mind rather than a computational fluid dynamics or pharmacokinetics problem. For some types of models there are good reasons to avoid recalculating lp.
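The recomputation trade-off can be illustrated with a small plain-Julia sketch. It is not SliceSampling.jl's implementation and does not use Turing's `externalsampler` interface: it is a hand-written hit-and-run slice sampler with stepping out and shrinkage (as in Neal's 2003 slice sampling paper), run on a hypothetical hand-coded stand-in `logdensity` for `foo()`, with width `w = 2.0` loosely mirroring `SliceSteppingOut(2.)`. Because the shrinkage loop has already evaluated the log density at the accepted point, the sampler can keep that value with each draw at no extra cost, whereas recomputing afterwards (as `logjoint` does) costs one more evaluation per draw.

```julia
using Random, LinearAlgebra

# Hypothetical stand-in for the model's log density (the MWE's `foo()` prior,
# independent standard normals), written by hand instead of via Turing.
logdensity(x) = -0.5 * sum(abs2, x) - 0.5 * length(x) * log(2π)

# One slice-sampling update along the line x + t * dir, with stepping out
# (initial width w, at most max_steps steps) and shrinkage. Returns the new
# point together with its log density, which the shrinkage loop already computed.
function slice_along(rng, f, x, fx, dir, w; max_steps = 100)
    logy = fx - randexp(rng)                 # log height of the slice
    lo = -w * rand(rng)                      # random interval of width w around t = 0
    hi = lo + w
    j = floor(Int, max_steps * rand(rng))    # split the step budget between both ends
    k = max_steps - 1 - j
    while j > 0 && f(x .+ lo .* dir) > logy  # step out to the left
        lo -= w
        j -= 1
    end
    while k > 0 && f(x .+ hi .* dir) > logy  # step out to the right
        hi += w
        k -= 1
    end
    while true                               # shrinkage toward t = 0
        t = lo + (hi - lo) * rand(rng)
        xnew = x .+ t .* dir
        fnew = f(xnew)
        if fnew > logy
            return xnew, fnew                # point and its lp, no extra evaluation
        end
        if t < 0
            lo = t
        else
            hi = t
        end
    end
end

# Hit-and-run: each iteration draws a uniformly random unit direction and does
# one slice update along it. The lp of every draw is stored as it is produced.
function hit_and_run(rng, f, x0, n; w = 2.0)
    d = length(x0)
    draws = Matrix{Float64}(undef, n, d)
    lps = Vector{Float64}(undef, n)
    x = float.(x0)
    fx = f(x)
    for i in 1:n
        dir = normalize(randn(rng, d))
        x, fx = slice_along(rng, f, x, fx, dir, w)
        draws[i, :] = x
        lps[i] = fx
    end
    return draws, lps
end

rng = MersenneTwister(1)
draws, lps = hit_and_run(rng, logdensity, fill(10.0, 3), 10)

# Recomputing after the fact gives the same numbers but costs one more
# log-density evaluation per draw (10 here; for a PDE model, 10 more solves).
recomputed = [logdensity(row) for row in eachrow(draws)]
```

The only design choice here is that each update returns the log density alongside the point; how an external sampler would hand such values over to Turing's chain is the interface that, per this thread, does not exist yet.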

torfjelde commented 19 hours ago

Definitely :) As I said, we do want to make it possible for a sampler wrapped in externalsampler to specify what information should be kept in the chain. However, at the moment we mainly see models where one additional evaluation per sample isn't a big issue, so this probably won't be addressed very soon (though it is on our TODO list 👍).