zenna / Omega.jl

Causal, Higher-Order, Probabilistic Programming
MIT License

Conditioning on the multivariate #198

Open ga72kud opened 2 years ago

ga72kud commented 2 years ago

In the following code, I condition the multivariate normal random variable x. Its mean depends on a univariate random variable x₁=normal(-4.,2.). I am afraid that condx=cond(x, x₁>=-0.5), which conditions the multivariate x on a predicate over x₁ (x and x₁ are different distributions!), gives wrong results.

using StatsPlots
using Distributions
using Omega
using Random
using Flux

AM_SAMPLES=150
plot(layout=(2,2), aspect_ratio=:equal, axes=:equal)
x₁=normal(-4.,2.)
samples_x₁=[rand(x₁) for i in 1:AM_SAMPLES]
display(histogram!(samples_x₁, lab="p(x₁)", subplot=4))
condx1=cond(x₁, x₁>=-3)
samples_condx₁=[rand(condx1, alg=RejectionSample) for i in 1:AM_SAMPLES]
display(histogram!(samples_condx₁, lab="p(x₁|x₁>=-3)", subplot=4))
x₂=normal(-4.,1.)
function x_(rng)
    rand(MvNormal([x₁(rng);x₂(rng)], [1.0 0.0;0.0 1.0]))
end

#
#Initial Distribution
#
x = ciid(x_)
samples=[rand(x) for i=1:AM_SAMPLES]
samples=permutedims(hcat(samples...))
display(scatter!(samples[:,1], samples[:,2], subplot=1, lab="", alpha=.4))
title!("Initial distribution", subplot=1)
display(histogram2d!(samples[:,1], samples[:,2], nbins = 20, subplot=1, alpha=.7))
#
#Conditional Distribution
#
condx=cond(x, x₁>=-0.5)
samples_condx=[rand(condx, alg=RejectionSample) for i=1:AM_SAMPLES]
samples_condx=permutedims(hcat(samples_condx...))
display(scatter!(samples_condx[:,1], samples_condx[:,2], subplot=2, lab="", alpha=.4))
title!("Conditional distribution", subplot=2)
display(histogram2d!(samples_condx[:,1], samples_condx[:,2], nbins = 20, subplot=2, alpha=.7))
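
One way to check whether the conditional plot is plausible is to compare it with a hand-written rejection sampler. The sketch below does not use Omega at all; it relies only on Julia's standard library and assumes that the second argument of normal(-4.,2.) is a standard deviation. The helper names sample_joint and rejection_sample are hypothetical and exist only for this illustration.

```julia
using Random
using Statistics

# Draw the latent x₁ together with the 2-D output x.
# Assumption: normal(μ, σ) takes a standard deviation as its second argument.
function sample_joint(rng::AbstractRNG)
    x1 = -4.0 + 2.0 * randn(rng)        # x₁ ~ N(-4, 2)
    x2 = -4.0 + 1.0 * randn(rng)        # x₂ ~ N(-4, 1)
    x  = [x1, x2] .+ randn(rng, 2)      # x ~ MvNormal([x₁, x₂], I)
    return x1, x
end

# Keep only the draws of x whose latent x₁ satisfies x₁ >= threshold.
function rejection_sample(rng::AbstractRNG, threshold::Float64, n::Int)
    accepted = Vector{Vector{Float64}}()
    while length(accepted) < n
        x1, x = sample_joint(rng)
        x1 >= threshold && push!(accepted, x)
    end
    return permutedims(reduce(hcat, accepted))   # n×2 matrix, one row per sample
end

rng = MersenneTwister(1)
reference = rejection_sample(rng, -0.5, 150)
mean_first  = mean(reference[:, 1])   # should move to the right of the unconditioned mean -4
mean_second = mean(reference[:, 2])   # should stay near -4, since x₂ is not constrained
```

If the samples drawn from condx look roughly like reference (first coordinate shifted to the right, second coordinate still centred near -4), the conditioning is probably doing what was intended.
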
ga72kud commented 2 years ago

I also tried something like this, but it yields a condx on which I cannot call rand():

x = MvNormal(collect(rand((x₁,x₂))), [1.0 0.0;0.0 1.0])
...
#
#Conditional Distribution
#
condx=cond(x, x.μ[1]>=-0.5)

I get the error:

MethodError: no method matching apl(::Bool, ::Omega.Space.TaggedΩ{Vector{UInt64}, NamedTuple{(:err, :rng), Tuple{Base.RefValue{Real}, Random._GLOBAL_RNG}}, Omega.Space.ΩProj{LinearΩ{Vector{UInt64}, UnitRange{Int64}, Vector{Any}}, Vector{UInt64}}})
Closest candidates are:
  apl(!Matched::RandVar, ::Omega.Space.TaggedΩ{I, T, ΩT}) where {I, T, ΩT<:Omega.Space.ΩProj} at /home/michael/.julia/packages/Omega/SOWrW/src/nondet/randvarapply.jl:53
condf(tω::Omega.Space.TaggedΩ{Vector{UInt64}, NamedTuple{(:err, :rng), Tuple{Base.RefValue{Real}, Random._GLOBAL_RNG}}, Omega.Space.ΩProj{LinearΩ{Vector{UInt64}, UnitRange{Int64}, Vector{Any}}, Vector{UInt64}}}, x::FullNormal, y::Bool) at trackerror.jl:28
(::Omega.Cond.var"#1#2"{FullNormal, Bool})(ω::Omega.Space.TaggedΩ{Vector{UInt64}, NamedTuple{(:err, :rng), Tuple{Base.RefValue{Real}, Random._GLOBAL_RNG}}, Omega.Space.ΩProj{LinearΩ{Vector{UInt64}, UnitRange{Int64}, Vector{Any}}, Vector{UInt64}}}) at pred.jl:13
ppapl(rv::Omega.NonDet.URandVar{Omega.Cond.var"#1#2"{FullNormal, Bool}, Tuple{}}, ωπ::Omega.Space.TaggedΩ{Vector{UInt64}, NamedTuple{(:err, :rng), Tuple{Base.RefValue{Real}, Random._GLOBAL_RNG}}, Omega.Space.ΩProj{LinearΩ{Vector{UInt64}, UnitRange{Int64}, Vector{Any}}, Vector{UInt64}}}) at urandvar.jl:24
macro expansion at randvarapply.jl:58 [inlined]
apl(rv::Omega.NonDet.URandVar{Omega.Cond.var"#1#2"{FullNormal, Bool}, Tuple{}}, tω::Omega.Space.TaggedΩ{Vector{UInt64}, NamedTuple{(:err, :rng), Tuple{Base.RefValue{Real}, Random._GLOBAL_RNG}}, LinearΩ{Vector{UInt64}, UnitRange{Int64}, Vector{Any}}}) at randvarapply.jl:58
apl at randvarapply.jl:53 [inlined]
applytrackerr(x::Omega.NonDet.URandVar{Omega.Cond.var"#1#2"{FullNormal, Bool}, Tuple{}}, ω::Omega.Space.TaggedΩ{Vector{UInt64}, NamedTuple{(:rng,), Tuple{Random._GLOBAL_RNG}}, Omega.Space.ΩProj{LinearΩ{Vector{UInt64}, UnitRange{Int64}, Vector{Any}}, Vector{UInt64}}}, errinit::Bool) at trackerror.jl:11
indomain(x::Omega.NonDet.URandVar{Omega.Cond.var"#1#2"{FullNormal, Bool}, Tuple{}}, ω::Omega.Space.TaggedΩ{Vector{UInt64}, NamedTuple{(:rng,), Tuple{Random._GLOBAL_RNG}}, Omega.Space.ΩProj{LinearΩ{Vector{UInt64}, UnitRange{Int64}, Vector{Any}}, Vector{UInt64}}}, errinit::Bool) at trackerror.jl:23
indomain(x::Omega.NonDet.URandVar{Omega.Cond.var"#1#2"{FullNormal, Bool}, Tuple{}}, ω::Omega.Space.TaggedΩ{Vector{UInt64}, NamedTuple{(:rng,), Tuple{Random._GLOBAL_RNG}}, Omega.Space.ΩProj{LinearΩ{Vector{UInt64}, UnitRange{Int64}, Vector{Any}}, Vector{UInt64}}}) at trackerror.jl:23
(::Omega.Soft.var"#8#9"{Omega.NonDet.URandVar{Omega.Cond.var"#1#2"{FullNormal, Bool}, Tuple{}}})(ω::Omega.Space.TaggedΩ{Vector{UInt64}, NamedTuple{(:rng,), Tuple{Random._GLOBAL_RNG}}, Omega.Space.ΩProj{LinearΩ{Vector{UInt64}, UnitRange{Int64}, Vector{Any}}, Vector{UInt64}}}) at trackerror.jl:24
ppapl at urandvar.jl:24 [inlined]
macro expansion at randvarapply.jl:58 [inlined]
apl(rv::Omega.NonDet.URandVar{Omega.Soft.var"#8#9"{Omega.NonDet.URandVar{Omega.Cond.var"#1#2"{FullNormal, Bool}, Tuple{}}}, Tuple{}}, tω::Omega.Space.TaggedΩ{Vector{UInt64}, NamedTuple{(:rng,), Tuple{Random._GLOBAL_RNG}}, LinearΩ{Vector{UInt64}, UnitRange{Int64}, Vector{Any}}}) at randvarapply.jl:58
rand at fail.jl:21 [inlined]
rand(rng::Random._GLOBAL_RNG, x::Omega.NonDet.URandVar{Omega.Cond.var"#1#2"{FullNormal, Bool}, Tuple{}}, n::Int64, alg::Omega.Inference.FailUnsatAlg; ΩT::Type{LinearΩ{Vector{UInt64}, UnitRange{Int64}, Vector{Any}}}) at fail.jl:35
(::Base.var"#rand##kw")(::NamedTuple{(:ΩT,), Tuple{DataType}}, ::typeof(rand), rng::Random._GLOBAL_RNG, x::Omega.NonDet.URandVar{Omega.Cond.var"#1#2"{FullNormal, Bool}, Tuple{}}, n::Int64, alg::Omega.Inference.FailUnsatAlg) at fail.jl:34
rand(x::Omega.NonDet.URandVar{Omega.Cond.var"#1#2"{FullNormal, Bool}, Tuple{}}; alg::Omega.Inference.FailUnsatAlg, ΩT::Type, kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}) at rand.jl:12
rand(x::Omega.NonDet.URandVar{Omega.Cond.var"#1#2"{FullNormal, Bool}, Tuple{}}) at rand.jl:12
top-level scope at 15_single_multivariate_causal_inference.jl:30
eval at boot.jl:360 [inlined]

ga72kud commented 2 years ago

This is a slightly shorter version. I am afraid I still do not get the right result. Should I try another sampling algorithm?

using StatsPlots
using Distributions
using Omega
using Random
using Flux
rng=Random.GLOBAL_RNG
AM_SAMPLES=150
plot(layout=(2,2), aspect_ratio=:equal, axes=:equal)
# x₁ and x₂ as defined in the first snippet
x₁=normal(-4.,2.)
x₂=normal(-4.,1.)

#
#Initial Distribution
#
x =~ rng -> rand(MvNormal([x₁(rng); x₂(rng)], [1.0 0.0;0.0 1.0]))
samples=[rand(x) for i=1:AM_SAMPLES]
samples=permutedims(hcat(samples...))
display(scatter!(samples[:,1], samples[:,2], subplot=1, lab="", alpha=.4))
title!("Initial distribution", subplot=1)
display(histogram2d!(samples[:,1], samples[:,2], nbins = 20, subplot=1, alpha=.7))
#
#Conditional Distribution
#
condx=cond(x, x₁>=-2.)
samples_condx=[rand(condx, alg=RejectionSample) for i=1:AM_SAMPLES]
samples_condx=permutedims(hcat(samples_condx...))
display(scatter!(samples_condx[:,1], samples_condx[:,2], subplot=2, lab="", alpha=.4))
title!("Conditional distribution", subplot=2)
display(histogram2d!(samples_condx[:,1], samples_condx[:,2], nbins = 20, subplot=2, alpha=.7))

Update: I think I made an error. I do not want to condition on the domain of the multivariate distribution, but rather on its output. Something like this, which does not work:

outMVN=MvNormal([2.,2.], [1.0 0.0; 0.0 1.0])
condMVN=cond(outMVN, rand(outMVN)[1]>2.0)
rand(condMVN)
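
A likely reason this fails: rand(outMVN)[1]>2.0 is evaluated once, eagerly, so cond receives a plain Bool rather than a predicate that can be re-evaluated for each sample. That would be consistent with the MethodError above, which complains about apl(::Bool, ...). As a point of comparison, conditioning on the output can be written as a plain rejection loop. This is only a sketch using Julia's standard library, not Omega's API; sample_given_output is a hypothetical helper, and the identity covariance is hard-coded to match the example.

```julia
using Random

# Draw y ~ N(μ, I) repeatedly and return the first draw for which pred(y) holds.
# pred is a function of the sampled output, so it is re-evaluated on every draw.
function sample_given_output(rng::AbstractRNG, μ::Vector{Float64}, pred; max_tries::Int = 10_000)
    for _ in 1:max_tries
        y = μ .+ randn(rng, length(μ))   # identity covariance, as in the example
        pred(y) && return y
    end
    error("no draw satisfied the predicate within $max_tries tries")
end

rng = MersenneTwister(2)
draws = [sample_given_output(rng, [2.0, 2.0], y -> y[1] > 2.0) for _ in 1:150]
```

The key difference from the failing snippet is that the condition is passed as a function y -> y[1] > 2.0 of the output, instead of as the result of a single comparison.
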
ga72kud commented 2 years ago

I tried to figure it out. I think I can reduce the question to this: how can I run rand(a) in the following?

f(x, y)=~MvNormal(μ, Σ)<=[1.0 0.0; 0.0 0.0]*[x;y]
a=f(2.,0.)
rand(a)

A slightly different approach:

μ=[0.0;0.0]
Σ=[2.0 0.0;0.0 2.0]
A=[1.0 0.0; 0.0 0.0]
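function myf(A, x, y)
    a=rand(MvNormal(μ, Σ))
    # compare componentwise; a plain a<=A*[x;y] on vectors is lexicographic
    if all(a .<= A*[x;y])
        return a
    else
        return [NaN, NaN]   # marks a rejected draw
    end
end

samples=[myf(A, 0., 0.) for i=1:AM_SAMPLES]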
samples=permutedims(hcat(samples...))
display(scatter!(samples[:,1], samples[:,2], subplot=1, lab="", alpha=.4))
title!("Initial distribution", subplot=1)
display(histogram2d!(samples[:,1], samples[:,2], nbins = 20, subplot=1, alpha=.4))
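
A note on comparing vectors in myf: in Julia, `<=` between two vectors is a lexicographic comparison, so it can be true even when a later component exceeds the bound. A componentwise condition needs broadcasting combined with `all`. The values below are made up purely to show the difference.

```julia
a     = [-1.0, 5.0]   # second component is above the bound
bound = [0.0, 0.0]

lexicographic = a <= bound         # decided by the first differing component: -1.0 < 0.0
componentwise = all(a .<= bound)   # requires every component to satisfy the bound

println("lexicographic: ", lexicographic)   # true
println("componentwise: ", componentwise)   # false
```

For a condition such as "the output lies in a box", the componentwise form is presumably the intended one.
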