probcomp / Gen.jl

A general-purpose probabilistic programming system with programmable inference
https://gen.dev
Apache License 2.0

Throw informative error for `elliptical_slice` applied to scalar parameter. #240

Open ali-ramadhan opened 4 years ago

ali-ramadhan commented 4 years ago

Please let me know if I'm misusing Gen, but the elliptical_slice docstring and the elliptical slice smoke test suggest that mu should be a vector and cov a matrix. However, it errors when applied to a scalar parameter, even with mu as a 1-element array and cov as a 1x1 array.

I suspect this is because the line at https://github.com/probcomp/Gen/blob/a24877a703b9944e1a115a67270716b35fe7e8b7/src/inference/elliptical_slice.jl#L19 draws nu with mvnormal, which returns a 1-element array rather than the scalar that the normal choice at :x expects. Maybe it should be changed to something like

if length(mu) == 1
    nu = normal(0, sqrt(cov))  # zero mean; normal takes a std. dev., cov is a variance
else
    nu = mvnormal(zeros(length(mu)), cov)
end

so that scalar μ and Σ are accepted. Alternatively, normal could be extended to handle the multivariate case as well, but that might be a big breaking change.
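For context, the standard elliptical slice sampling update draws an auxiliary point ν from a zero-mean normal with the prior covariance, then proposes states on the ellipse through the current state x and ν, shrinking the angle bracket until the slice condition on the likelihood is met:

$$
\nu \sim \mathcal{N}(0, \Sigma), \qquad
x'(\theta) = \mu + (x - \mu)\cos\theta + \nu\sin\theta, \qquad
\theta \in [\theta_{\min}, \theta_{\max}].
$$

In the scalar case Σ = σ² is a variance, so the matching Gen draw is normal(0, sqrt(σ²)): ν must have mean 0 rather than μ, and Gen's normal is parameterized by a standard deviation rather than a variance.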


Minimal working example:

using LinearAlgebra, Gen

@gen function foo()
    x ~ normal(0, 1)
    y ~ gamma(1, 1)
end

trace, _ = generate(foo, (), choicemap((:y, 0.5)))
μ, Σ = zeros(1), Matrix{Float64}(I(1))
elliptical_slice(trace, :x, μ, Σ)

produces

ERROR: MethodError: Cannot `convert` an object of type Array{Float64,1} to an object of type Float64
Closest candidates are:
  convert(::Type{R}, ::T) where {R<:Real, T<:ReverseDiff.TrackedReal} at /home/alir/.julia/packages/ReverseDiff/SCRbd/src/tracked.jl:251
  convert(::Type{T}, ::T) where T<:Number at number.jl:6
  convert(::Type{T}, ::Number) where T<:Number at number.jl:7
  ...
Stacktrace:
 [1] traceat(::Gen.GFUpdateState, ::Gen.Normal, ::Tuple{Int64,Int64}, ::Symbol) at /home/alir/.julia/packages/Gen/8CuZM/src/dynamic/update.jl:47
 [2] ##foo#2586(::Gen.GFUpdateState) at ./REPL[9]:2
 [3] exec(::DynamicDSLFunction{Any}, ::Gen.GFUpdateState, ::Tuple{}) at /home/alir/.julia/packages/Gen/8CuZM/src/dynamic/dynamic.jl:54
 [4] update(::Gen.DynamicDSLTrace{DynamicDSLFunction{Any}}, ::Tuple{}, ::Tuple{}, ::DynamicChoiceMap) at /home/alir/.julia/packages/Gen/8CuZM/src/dynamic/update.jl:186
 [5] #elliptical_slice#272(::Bool, ::EmptyChoiceMap, ::typeof(elliptical_slice), ::Gen.DynamicDSLTrace{DynamicDSLFunction{Any}}, ::Symbol, ::Array{Float64,1}, ::Array{Float64,2}) at /home/alir/.julia/packages/Gen/8CuZM/src/inference/elliptical_slice.jl:32
 [6] elliptical_slice(::Gen.DynamicDSLTrace{DynamicDSLFunction{Any}}, ::Symbol, ::Array{Float64,1}, ::Array{Float64,2}) at /home/alir/.julia/packages/Gen/8CuZM/src/inference/elliptical_slice.jl:15
 [7] top-level scope at REPL[15]:1
ali-ramadhan commented 4 years ago

With this patch

diff --git a/src/inference/elliptical_slice.jl b/src/inference/elliptical_slice.jl
index 3dc2b72..1c330ad 100644
--- a/src/inference/elliptical_slice.jl
+++ b/src/inference/elliptical_slice.jl
@@ -16,7 +16,11 @@ function elliptical_slice(
     argdiffs = map((_) -> NoChange(), args)

     # sample nu
-    nu = mvnormal(zeros(length(mu)), cov)
+    if length(mu) == 1
+        nu = normal(0, sqrt(cov))
+    else
+        nu = mvnormal(zeros(length(mu)), cov)
+    end

     # sample u
     u = uniform(0, 1)

the following modified minimal working example works:

using LinearAlgebra, Gen

@gen function foo()
    x ~ normal(0, 1)
    y ~ gamma(1, 1)
end

trace, _ = generate(foo, (), choicemap((:y, 0.5)))
elliptical_slice(trace, :x, 0.0, 0.1)
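To check that nothing in the algorithm itself breaks for a scalar parameter, here is a standalone sketch of the scalar update written without Gen. The names ess_step_scalar, run_chain and loglik_obs are hypothetical and chosen only for illustration. The toy target is a conjugate model (prior N(0, 1), one observation 2.0 with unit noise) whose exact posterior is N(1, 0.5), so the sample mean and variance can be compared against known values.

```julia
using Random

# Hypothetical standalone sketch (not Gen's implementation): one elliptical
# slice sampling step for a scalar x with Gaussian prior N(mu, var) and
# log-likelihood function `loglik`.
function ess_step_scalar(rng::AbstractRNG, x::Real, mu::Real, var::Real, loglik)
    nu = sqrt(var) * randn(rng)        # auxiliary draw: nu ~ N(0, var), zero mean
    logy = loglik(x) + log(rand(rng))  # log of the slice height at the current state
    theta = 2pi * rand(rng)            # initial angle on the ellipse
    theta_min, theta_max = theta - 2pi, theta
    f = x - mu                         # centred current state
    while true
        x_new = mu + f * cos(theta) + nu * sin(theta)
        loglik(x_new) > logy && return x_new
        # Rejected: shrink the bracket towards theta = 0 (the current state) and retry.
        if theta < 0
            theta_min = theta
        else
            theta_max = theta
        end
        theta = theta_min + (theta_max - theta_min) * rand(rng)
    end
end

# Run n steps starting from x0 and return every sample.
function run_chain(rng::AbstractRNG, n::Integer, x0::Real, mu::Real, var::Real, loglik)
    xs = Vector{Float64}(undef, n)
    x = float(x0)
    for i in 1:n
        x = ess_step_scalar(rng, x, mu, var, loglik)
        xs[i] = x
    end
    return xs
end

# Toy likelihood: one observation 2.0 with unit noise. With prior N(0, 1),
# the exact posterior is N(1, 0.5).
loglik_obs(x) = -0.5 * (2.0 - x)^2
xs = run_chain(MersenneTwister(0), 20_000, 0.0, 0.0, 1.0, loglik_obs)
```

The sample mean and variance of xs should land close to 1 and 0.5 respectively.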
marcoct commented 4 years ago

@ali-ramadhan Thanks for proposing this, but I don't think that elliptical slice sampling is particularly useful for a univariate normal draw. Using mh(trace, select(:x)) should be just as effective in this case. Perhaps we can change the docstring to clarify this.
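A minimal sketch of the suggested alternative, reusing the foo model from the issue. mh(trace, select(:x)) proposes a new value for :x using the model's internal proposal (here, the prior of :x) and returns the updated trace together with a Bool acceptance flag. The helper resample_x and the iteration count are illustrative choices, not part of Gen.

```julia
using Gen

@gen function foo()
    x ~ normal(0, 1)
    y ~ gamma(1, 1)
end

# Condition on the observed y, as in the issue's example.
trace, _ = generate(foo, (), choicemap((:y, 0.5)))

# Hypothetical helper: repeatedly apply an MH move to the scalar choice :x
# and count how many proposals were accepted.
function resample_x(trace, n_iters)
    n_accepted = 0
    for _ in 1:n_iters
        trace, accepted = mh(trace, select(:x))
        n_accepted += accepted
    end
    return trace, n_accepted
end

trace, n_accepted = resample_x(trace, 1000)
```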

ali-ramadhan commented 4 years ago

Thanks for the advice. I'll try changing my model to sample from an mvnormal.
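A minimal sketch of that change, assuming the rest of the model stays as in the issue: :x becomes a length-1 vector drawn from mvnormal, so its value has the same shape as the vector that elliptical_slice proposes. The model name foo_mv is hypothetical, and μ and Σ must match the prior passed to mvnormal.

```julia
using LinearAlgebra, Gen

# Hypothetical variant of the issue's model: x is a length-1 vector drawn
# from mvnormal instead of a scalar drawn from normal.
@gen function foo_mv()
    x ~ mvnormal(zeros(1), Matrix{Float64}(I, 1, 1))
    y ~ gamma(1, 1)
end

trace, _ = generate(foo_mv, (), choicemap((:y, 0.5)))

# Mean and covariance of the prior on :x, matching the mvnormal call above.
μ, Σ = zeros(1), Matrix{Float64}(I, 1, 1)
trace = elliptical_slice(trace, :x, μ, Σ)
```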

But surely elliptical_slice shouldn't fail with a conversion error when drawing from a univariate normal?

If it's not meant to be used in such cases, then maybe an ArgumentError should be thrown when length(mu) == 1 (and this should be mentioned in the docstring)?

marcoct commented 4 years ago

Yes, we should throw a more informative error, thanks!

It looks like the docstring does mention that it's for use with a multivariate normal prior; we can add (mvnormal) to make this more explicit.
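A minimal sketch of the more informative error discussed in this thread, written as a standalone check. The helper name check_elliptical_slice_args and its (value, mu, cov) signature are hypothetical and not part of Gen; inside elliptical_slice, value would be trace[addr]. It rejects the scalar case from this issue up front with an ArgumentError, instead of the MethodError that update raises later.

```julia
using LinearAlgebra

# Hypothetical argument check (not part of Gen's API). `value` is the current
# value of the random choice, `mu` and `cov` the mean and covariance of its prior.
function check_elliptical_slice_args(value, mu, cov)
    if !(value isa AbstractVector)
        throw(ArgumentError(
            "elliptical_slice requires a vector-valued random choice with a " *
            "multivariate normal (mvnormal) prior; got a value of type " *
            "$(typeof(value)). For a scalar choice, consider mh(trace, select(addr))."))
    end
    if !(mu isa AbstractVector) || !(cov isa AbstractMatrix)
        throw(ArgumentError("mu must be a vector and cov must be a matrix"))
    end
    n = length(value)
    if length(mu) != n || size(cov) != (n, n)
        throw(ArgumentError(
            "dimension mismatch: choice has length $n, mu has length " *
            "$(length(mu)), cov has size $(size(cov))"))
    end
    return nothing
end

# A vector-valued choice passes; the scalar choice from this issue is rejected.
check_elliptical_slice_args([0.3], zeros(1), Matrix(1.0I, 1, 1))
threw = try
    check_elliptical_slice_args(0.3, zeros(1), Matrix(1.0I, 1, 1))
    false
catch e
    e isa ArgumentError
end
```

Whether length-1 vectors should also be rejected, as suggested earlier in the thread, is a separate policy choice; this sketch only rejects non-vector values and mismatched dimensions.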