marius311 / MuseInference.jl

Fast approximate high-dimensional hierarchical Bayesian inference
https://cosmicmar.com/MuseInference.jl
MIT License

Soss.jl interface #4

Closed: marius311 closed this 2 years ago

marius311 commented 2 years ago

Basic version working:

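using Soss, Distributions, MuseInference, Zygote, DynamicHMC

# Funnel-shaped hierarchical model: hyperparameter θ, 512 latent z, data x
funnel = @model () begin
    θ ~ Normal(0, 3)                      # prior on θ
    z ~ MvNormal(zeros(512), exp(θ/2))    # isotropic latent field with sd exp(θ/2)
    x ~ MvNormal(z, 1)                    # unit-variance observations of z
end

(;x,) = predict(funnel, θ=0)              # simulate data at θ = 0
prob = SossMuseProblem(funnel | (;x), autodiff=MuseInference.ZygoteBackend())
result = @time muse(prob, (θ=0,), get_covariance=true)  # MUSE solve starting from θ = 0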

And it's 10X faster than Turing :tada: This was also really easy to write, and it takes very few lines of code.
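For readers new to Soss, the generative process of this model can be written out by hand. The sketch below uses plain Julia (only the Random stdlib), not Soss or MuseInference. The name `simulate_funnel` is hypothetical, and reading `predict(funnel, θ=0)` as "fix θ, then draw z and x" is our interpretation of that call, not a documented guarantee.

```julia
using Random

# Hypothetical helper: draw (θ, z, x) from the funnel by hand, mirroring the
# three `~` lines of the Soss model above (n = 512 as in the example).
function simulate_funnel(rng::AbstractRNG; n::Int = 512, θ::Union{Nothing,Real} = nothing)
    # θ ~ Normal(0, 3), unless the caller fixes it
    θval = θ === nothing ? 3 * randn(rng) : float(θ)
    # z ~ MvNormal(zeros(n), exp(θ/2)): n iid normals with sd exp(θ/2)
    z = exp(θval / 2) .* randn(rng, n)
    # x ~ MvNormal(z, 1): unit-variance noise around z
    x = z .+ randn(rng, n)
    return (θ = θval, z = z, x = x)
end

# Our reading of `predict(funnel, θ=0)`: fix θ, then simulate z and x.
data = simulate_funnel(MersenneTwister(0); θ = 0)
x = data.x
```

Conditioning the model on this `x` and asking MUSE for θ is then the inference problem the snippet above sets up.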

Remaining TODO:

cscherrer commented 2 years ago

I'd probably write this as

julia> using MeasureTheory

julia> funnel = @model σ begin
           θ ~ Normal(σ=σ)
           z ~ Normal(σ=exp(θ/2)) ^ 4
           x ~ For(z) do zj
               Normal(μ=zj)
               end
           end;

(I used 4 instead of 512 so it's easier to display.) Then, for a little demo:

julia> truth = rand(funnel(3))
(θ = -0.0379977, z = [-0.864762, 1.39213, -0.800246, 0.0144543], x = [-1.41632, 1.65007, -0.875127, 1.05682])

julia> obs = (x = truth.x,)
(x = [-1.41632, 1.65007, -0.875127, 1.05682],)

julia> using SampleChainsDynamicHMC

julia> post = funnel(3) | obs
ConditionalModel given
    arguments    (:σ,)
    observations (:x,)
@model σ begin
        θ ~ Normal(σ = σ)
        z ~ Normal(σ = exp(θ / 2)) ^ 4
        x ~ For(z) do zj
                Normal(μ = zj)
            end
    end

julia> s = sample(post, dynamichmc())
4000-element MultiChain with 4 chains and schema (θ = Float64, z = Vector{Float64})
(θ = -1.2±1.8, z = [-0.47±0.69, 0.54±0.74, -0.29±0.61, 0.33±0.64])

For transformations, we use TransformVariables.jl. You can get the transform for a given model using xform, like this:

julia> import TransformVariables as TV

julia> tr = xform(post)
TransformTuple{NamedTuple{(:θ, :z), Tuple{Identity, ArrayTransform{Identity, 1}}}}((θ = asℝ, z = ArrayTransform{Identity, 1}(asℝ, (4,))), 5)

julia> TV.dimension(tr)
5

julia> TV.transform(tr, randn(5))
(θ = -0.462708, z = [1.17351, -0.631122, -0.637643, 1.28246])
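Because every component of this model is unconstrained (each shows up as asℝ), the transform above amounts to slicing a length-5 vector into θ and z. A hand-rolled stand-in, in plain Julia; `unflatten` and `flatten` are hypothetical names, not part of TransformVariables.jl:

```julia
# Hypothetical equivalent of the transform above: every component is
# unconstrained (asℝ), so mapping ℝ^5 -> (θ, z) is just slicing.
unflatten(v::AbstractVector) = (θ = v[1], z = v[2:5])   # θ scalar, z of length 4
flatten(nt::NamedTuple) = vcat(nt.θ, nt.z)               # inverse direction

v = randn(5)
nt = unflatten(v)
```

For a constrained parameter, TransformVariables.jl would apply a nonlinear map instead (for example, asℝ₊ maps through exp), and the log-Jacobian of that map would then enter the log density.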
cscherrer commented 2 years ago

Also, I want to extend this functionality further, but for now you can do:

julia> prior(funnel, :x)
@model σ begin
        θ ~ Normal(σ = σ)
        z ~ Normal(σ = exp(θ / 2)) ^ 4
    end

julia> likelihood(funnel, :x)
@model z begin
        x ~ For(z) do zj
                Normal(μ = zj)
            end
    end
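The split works because the joint log density factorizes into the prior over (θ, z) plus the likelihood of x given z. A plain-Julia sketch of that decomposition for the MeasureTheory version of the model; the function names are hypothetical and the normal log density is written out by hand rather than taken from any library:

```julia
# Log density of a Normal with mean μ and sd σ, evaluated at x.
normlogpdf(μ, σ, x) = -((x - μ) / σ)^2 / 2 - log(σ) - log(2π) / 2

# Counterpart of prior(funnel, :x): θ ~ Normal(σ = σ), z ~ Normal(σ = exp(θ/2))^n.
logprior(σ, θ, z) = normlogpdf(0, σ, θ) + sum(zj -> normlogpdf(0, exp(θ / 2), zj), z)

# Counterpart of likelihood(funnel, :x): x ~ For(z) do zj Normal(μ = zj) end.
# It depends on z only, not on θ or σ.
loglik(z, x) = sum(normlogpdf.(z, 1, x))

# The joint is the sum of the two pieces.
logjoint(σ, θ, z, x) = logprior(σ, θ, z) + loglik(z, x)
```

For MUSE, as discussed next, the prior needed is only P(θ), which here is the first term, normlogpdf(0, σ, θ).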
marius311 commented 2 years ago

Thanks for the replies! Couple followups.

With the model defined as above using MeasureTheory, it looks like Zygote gradients are failing. Previously I had only tried Distributions, and there both Zygote and ForwardDiff work. Is this a known issue, and/or is my call OK?

using Soss, MeasureTheory, ForwardDiff, Zygote

(;x,z,θ) = rand(funnel(3))
logdensity(funnel(3) | (;x, z), (;θ))                                    # ok
ForwardDiff.derivative(θ -> logdensity(funnel(3) | (;x, z), (;θ)), θ)    # ok
Zygote.gradient(θ -> logdensity(funnel(3) | (;x, z), (;θ)), θ)[1]
# ERROR: Need an adjoint for constructor MappedArrays.ReadonlyMappedArray

Btw, I noticed that with Distributions, Zygote.gradient(θ -> logdensity(funnel(3) | (;x, z), (;θ)), θ) performs slightly better than Zygote.gradient(θ -> logdensity(funnel(3), (;θ, x, z)), θ), since the former seems to drop the gradients w.r.t. (x,z) entirely, which is what I want here. That's why I did it that way in this interface. Is that OK?
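With x and z held fixed, only two terms of the funnel log joint depend on θ, so the derivative these calls compute has a closed form: ∂/∂θ = -θ/σ² + Σz²/(2e^θ) - n/2. A plain-Julia sketch (no Soss, no AD; names hypothetical), with σ = 3 matching funnel(3) and a finite-difference check:

```julia
# Terms of the funnel log joint that depend on θ when x and z are held fixed
# (additive constants dropped). σ is the prior sd of θ; σ = 3 matches funnel(3).
# log p(x | z) has no θ dependence, so it drops out entirely.
logjoint_θ(θ, z; σ = 3) = -θ^2 / (2σ^2) - sum(abs2, z) / (2 * exp(θ)) - length(z) * θ / 2

# Closed-form derivative of the expression above with respect to θ.
dlogjoint_dθ(θ, z; σ = 3) = -θ / σ^2 + sum(abs2, z) / (2 * exp(θ)) - length(z) / 2

# Central finite difference, used only as a sanity check.
fd(f, θ; h = 1e-6) = (f(θ + h) - f(θ - h)) / (2h)
```

This is why conditioning on (x, z) is harmless: the θ-partial is the same whether or not AD also tracks x and z, it just skips work.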

Re the likelihood/prior thing, yeah, those are super useful. For MUSE the prior I need is P(θ), so I ended up with

https://github.com/marius311/MuseInference.jl/blob/55fe15f50062c34a90b7b54aa512f28eabaebf64/src/soss.jl#L26

which, from what I can tell, also propagates the argvals in case the prior depends on them.

And thanks for the HMC code and transform code, that's helpful!

cscherrer commented 2 years ago

Thanks, I hadn't tried Zygote in a while. Some context may help here.

For builds a product measure from some parameterization. It's not itself an array, but you can describe it in terms of an array of measures, like this:

julia> d = For(randn(3)) do μ Normal(μ=μ) end
For{Normal{(:μ,), Tuple{Float64}}}(#7, [0.1560796508305534, 0.6416014978488057, 0.3847545293905175])

julia> marginals(d)
3-element mappedarray(#7, ::Vector{Float64}) with eltype Normal{(:μ,), Tuple{Float64}}:
 Normal(μ = 0.15608,)
 Normal(μ = 0.641601,)
 Normal(μ = 0.384755,)

Those are the same as if you had used map:

julia> map(d.f, d.inds...)
3-element Vector{Normal{(:μ,), Tuple{Float64}}}:
 Normal(μ = 0.15608,)
 Normal(μ = 0.641601,)
 Normal(μ = 0.384755,)

But map allocates a new array, which we want to avoid, so in some cases we delegate to MappedArrays.jl.
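The trade-off can be illustrated with toy types: a For-like object only needs to store the function and its parameters, and its log density can be summed over marginals built on the fly. The types below (`ToyNormal`, `ToyFor`) are hypothetical and are not MeasureTheory's internals; the sketch uses a plain generator instead of MappedArrays.jl:

```julia
# Toy types, purely illustrative and not MeasureTheory's internals.
struct ToyNormal
    μ::Float64
end
logdensity(d::ToyNormal, x) = -(x - d.μ)^2 / 2 - log(2π) / 2

# A "For"-like product measure: keep the function and its parameters,
# not an array of marginal distributions.
struct ToyFor{F,T}
    f::F
    inds::T
end

# Lazy evaluation: the generator builds each marginal on the fly and sums
# its log density, so no intermediate Vector{ToyNormal} is allocated.
logdensity(d::ToyFor, xs) = sum(logdensity(d.f(i), x) for (i, x) in zip(d.inds, xs))

# Eager alternative, like `map(d.f, d.inds...)` above: materializes the marginals.
marginals_eager(d::ToyFor) = map(d.f, d.inds)
```

Both paths give the same number; the lazy one just skips the intermediate array, which is the point of delegating to a mapped (lazy) array.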

It looks like MappedArrays is missing the ChainRules definitions that Zygote needs to do its thing. It does apparently work OK in the dev branch of MeasureTheory, which will be the basis for the next Soss release:

julia> x = randn(3)
3-element Vector{Float64}:
 -0.930351
 -1.24049
  0.355545

julia> f(β) = For(x) do μ Normal(μ=β + μ) end
f

julia> Zygote.gradient(β -> logdensityof(f(β), x), 0.3)
(-0.9,)
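The -0.9 can be checked by hand: each component is Normal(μ = β + xj) evaluated at xj, so every residual is -β, the log density is -nβ²/2 - n·log(2π)/2, and its gradient is -nβ = -0.9 for n = 3 and β = 0.3. A plain-Julia check (no MeasureTheory; `logdens` and `dlogdens` are hypothetical names, and we assume unit variance and a fully normalized density):

```julia
x = [-0.930351, -1.24049, 0.355545]

# Each term is a unit-variance Normal with mean β + xj, evaluated at xj,
# so every residual is -β. Fully normalized (includes -log(2π)/2 per term).
logdens(β, x) = sum(xj -> -(xj - (β + xj))^2 / 2 - log(2π) / 2, x)

# Analytic gradient: d/dβ of -n*β^2/2 is -n*β.
dlogdens(β, x) = -length(x) * β
```

At β = 0.3 this log density evaluates to about -2.89182, independent of the particular x values.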

> Btw, I noticed that with Distributions, Zygote.gradient(θ -> logdensity(funnel(3) | (;x, z), (;θ)), θ) performs slightly better than Zygote.gradient(θ -> logdensity(funnel(3), (;θ, x, z)), θ), since the former seems to drop the gradients w.r.t. (x,z) entirely, which is what I want here. That's why I did it that way in this interface. Is that OK?

If I'm thinking about it right, that should just give you the partial derivatives with respect to θ, so it should be fine. Maybe it's worth adding a test just to be safe?

BTW, I'll need to figure out how to square Soss.likelihood with MeasureTheory.likelihood, which works like this:

julia> f(β) = For(x) do μ Normal(μ=β + μ) end
f

julia> post = Lebesgue() ⊙ likelihood(f, x)
Lebesgue(ℝ) ⊙ Likelihood(f, [-0.9303505027587362, -1.2404857664253213, 0.3555452843296289])

julia> logdensityof(post, 0.3)
-2.89182

# compare
julia> logdensityof(f(0.3), x)
-2.89182

I'm setting things up this way because in some cases, like linear models, the likelihood can be evaluated much more efficiently than with the naive approach.
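As an illustration of the linear-model case: for a Gaussian linear model, ||y - Xβ||² = y'y - 2β'X'y + β'X'Xβ, so y'y, X'y and X'X can be computed once and each later likelihood evaluation no longer touches the n data rows. A hypothetical sketch in plain Julia (LinearAlgebra stdlib only), not MeasureTheory's implementation; `LinLik` and `loglik_naive` are illustrative names:

```julia
using LinearAlgebra

# Hypothetical Gaussian linear model: y ~ Normal(μ = X*β, σ = 1) componentwise.
# Naive: form the residual y - X*β on every call, O(n*p) work.
loglik_naive(β, X, y) = -sum(abs2, y - X * β) / 2 - length(y) * log(2π) / 2

# Cached: since ||y - Xβ||² = y'y - 2β'X'y + β'X'Xβ, precompute y'y, X'y and
# X'X once; each later call costs O(p^2), independent of the number of rows n.
struct LinLik
    yty::Float64
    Xty::Vector{Float64}
    XtX::Matrix{Float64}
    n::Int
end
LinLik(X::AbstractMatrix, y::AbstractVector) = LinLik(dot(y, y), X' * y, X' * X, length(y))

# Make the cached object callable as a function of β.
(L::LinLik)(β) = -(L.yty - 2 * dot(β, L.Xty) + dot(β, L.XtX * β)) / 2 - L.n * log(2π) / 2
```

The cached form pays off when n is much larger than p and the likelihood is evaluated many times, as in MCMC or MUSE iterations.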

marius311 commented 2 years ago

Everything is basically working. Here's my benchmark, if you're curious:

[Image: benchmark plot]

This is with the model written as

Soss.@model () begin
    θ ~ Normal(0, 3)
    z ~ MvNormal(zeros(512), exp(θ/2))
    x ~ MvNormal(z, 1)
end

rather than your way, since I want to use Zygote (which is faster than ForwardDiff here). All the codes are set up to run the same MUSE solution, so it's basically benchmarking posterior gradient evaluations. Full code here: https://gist.github.com/marius311/295164a0bbfbdb281a2dd6fb473597b6

Will probably merge shortly, but of course I'm happy to iterate on things pending Soss changes or any comments you might have.

cscherrer commented 2 years ago

Wow! This is great! I haven't done much performance tuning, so I'm surprised to see it doing so well. At the same time, I think that as MeasureTheory progresses, we should be able to get close to that "Julia" time.