zenna / Omega.jl

Causal, Higher-Order, Probabilistic Programming
MIT License

Intervention on multivariate normal distribution #197

Open ga72kud opened 2 years ago

ga72kud commented 2 years ago

I am trying to turn the following univariate example into a multivariate one:

using Omega, Plots

α=categorical([0.3, 0.7])
function x_(rng)
  #α=categorical([0.1, .9])
  if(α(rng)==1)
    x=normal(rng, 0.3, .1)
  elseif(α(rng)==2)
    x=normal(rng, -0.3, .1)
  elseif(α(rng)==3)  # only reachable after the intervention below gives α a third category
    x=normal(rng, 0.0, .01)
  end
end
AM_SAMPLES=10000
x = ciid(x_)
samples=[rand(x) for i=1:AM_SAMPLES]
display(histogram!(samples, subplot=1))

x_interv=replace(x, α=>categorical([0.7, 0.2, 0.1]))
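A minimal plain-Julia sketch of what the intervention on α is meant to do, assuming `replace(x, α => ...)` behaves like re-running `x_` with α drawn from the new weights. It does not use Omega: α is stood in for by a hypothetical `category` helper whose weights are passed explicitly, so the "intervention" is just a call with different weights. With the prior weights `[0.3, 0.7]` the third branch is never taken; with `[0.7, 0.2, 0.1]` roughly 10% of samples come from the narrow N(0, 0.01) component.

```julia
using Random

# Hypothetical stand-in for α (not Omega's categorical):
# maps one uniform draw from rng to a category index in 1:length(w).
function category(rng::AbstractRNG, w::Vector{Float64})
    i = findfirst(>(rand(rng)), cumsum(w))
    return i === nothing ? length(w) : i
end

# Same three-branch mixture as x_ above, with α's weights passed in explicitly.
function x_sketch(rng::AbstractRNG, w::Vector{Float64})
    a = category(rng, w)
    if a == 1
        return 0.3 + 0.1 * randn(rng)
    elseif a == 2
        return -0.3 + 0.1 * randn(rng)
    else
        return 0.0 + 0.01 * randn(rng)
    end
end

rng = MersenneTwister(0)
prior  = [x_sketch(rng, [0.3, 0.7]) for _ in 1:10_000]       # α as originally defined
interv = [x_sketch(rng, [0.7, 0.2, 0.1]) for _ in 1:10_000]  # α "replaced" by new weights
```

The prior mean is about 0.3·0.3 − 0.7·0.3 = −0.12 and the intervened mean about 0.7·0.3 − 0.2·0.3 = 0.15, which gives a quick sanity check on the histograms.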

I am wondering how to use mvnormal (Omega's multivariate distributions do not work for me), so I used the MvNormal distribution from Distributions.jl instead.

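using Omega, Distributions

α=categorical([0.3, 0.7])
# Note: rand(MvNormal(...)) below draws from Julia's global RNG, not from rng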
function x_(rng)
  if(α(rng)==1)
    rand(MvNormal([-2.0;-4.0], [1.0 0.0;0.0 1.0]))
  elseif(α(rng)==2)
    rand(MvNormal([2.0;-1.0], [1.0 0.0;0.0 1.0]))
  else
    rand(MvNormal([3.0;6.0], [1.0 0.0;0.0 1.0]))
  end
end
AM_SAMPLES=500
x = ciid(x_)
samples=[rand(x) for i=1:AM_SAMPLES]

This example works for me, but I am wondering why I need rand(...) inside the function x_.
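`MvNormal(μ, Σ)` only constructs a distribution object; `rand` is what turns it into a concrete vector. The catch is that `rand(d)` without an RNG argument uses Julia's global RNG, so the value does not depend on the `rng` that Omega passes into `x_`. Distributions.jl also provides `rand(rng, d)`; whether Omega's `rng` object is accepted there depends on the Omega version and is an assumption here. The sketch below avoids both Omega and Distributions.jl and shows the property that matters: every random choice comes from the RNG passed in, so the same seed reproduces the same draw. `mvnormal_draw`, `category` and `x_mv` are hypothetical names, and the Cholesky construction is a standard stand-in for `MvNormal` sampling.

```julia
using Random, LinearAlgebra

# Hypothetical helper: one draw from N(μ, Σ) using ONLY the rng passed in.
# With Σ = L*L' (Cholesky factor L), x = μ + L*z where z ~ N(0, I).
function mvnormal_draw(rng::AbstractRNG, μ::Vector{Float64}, Σ::Matrix{Float64})
    L = cholesky(Symmetric(Σ)).L
    return μ + L * randn(rng, length(μ))
end

# Hypothetical helper: category index in 1:length(w), drawn from rng.
function category(rng::AbstractRNG, w::Vector{Float64})
    i = findfirst(>(rand(rng)), cumsum(w))
    return i === nothing ? length(w) : i
end

const MEANS = [[-2.0, -4.0], [2.0, -1.0], [3.0, 6.0]]
const COV = [1.0 0.0; 0.0 1.0]

# Mixture of the three bivariate normals from x_ above; all randomness comes from rng.
x_mv(rng::AbstractRNG, w::Vector{Float64}) = mvnormal_draw(rng, MEANS[category(rng, w)], COV)

a = x_mv(MersenneTwister(42), [0.3, 0.7])
b = x_mv(MersenneTwister(42), [0.3, 0.7])
# a == b: the same seed reproduces the same draw because the global RNG is never used.
```

Calling `x_mv` with weights `[0.0, 0.0, 1.0]` forces the third component, which plays the same role as an intervention on α.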