zenna / Omega.jl

Causal, Higher-Order, Probabilistic Programming
MIT License

Overview of proposed ciid changes #122

Closed: zenna closed this issue 3 years ago

zenna commented 5 years ago

This issue proposes a major change to the structure of Omega and to the semantics of ciid (see #120).

Overall setup

Independence

Operationally, it works as follows. First, consider the simplest case, independent and identically distributed (i.i.d.) variables, and the following example:

a = 1 ~ unif
b = 2 ~ a
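As a minimal sketch of the i.i.d. case, assuming only Base Julia: instead of carrying the id inside a Cassette context, the hypothetical helpers below pass an id path explicitly as a second argument. OmegaDict, coord!, unif_rv and ciid_rv are illustrative names, not part of Omega or the prototype. Here i ~ f corresponds to ciid_rv(f, i), which prepends i to the path, so every primitive draw inside f reads a different coordinate of ω.

```julia
# Hypothetical, Cassette-free sketch of `i ~ f` (not Omega's API).
# ω is a lazily populated map from id tuples to Uniform(0, 1) draws.
const OmegaDict = Dict{Tuple{Vararg{Int}}, Float64}

# Read coordinate `id` of ω, drawing it on first access and memoizing it.
coord!(ω::OmegaDict, id) = get!(rand, ω, id)

# Primitive uniform. `path` stands in for the id stored in CIIDCtx metadata;
# outside any `~` it reads coordinate (1,), like the prototype's `unif(ω)`.
unif_rv(ω::OmegaDict, path) = coord!(ω, isempty(path) ? (1,) : path)

# `i ~ f`: evaluate f with i prepended to the current path, so a nested `~`
# produces a merged id such as (1, 2).
ciid_rv(f, i) = (ω::OmegaDict, path) -> f(ω, (i, path...))

a = ciid_rv(unif_rv, 1)   # a = 1 ~ unif  -> reads coordinate (1,)
b = ciid_rv(a, 2)         # b = 2 ~ a     -> reads coordinate (1, 2)

ω = OmegaDict()
println((a(ω, ()), b(ω, ())))   # two independent Uniform(0, 1) values
```

Because a reads coordinate (1,) and b reads (1, 2), b is identically distributed with a but independent of it. Prepending rather than appending mirrors the prototype, where evaluating b = 2 ~ a merges the ids into (1, 2).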

Conditional independence

For example, in the following snippet m1 and m2 are conditionally independent given x, while m3 is independent of them both:

x = 1 ~ uniform(0, 19)
measure(ω) = x(ω) + uniform(0, 1)(ω)
m1 = 3 ~ measure <| (x,)
m2 = 4 ~ measure <| (x,)
m3 = 5 ~ measure

Operationally, conditional independence needs only a single modification. For m1 to be ciid given x, the evaluation of m1(ω) should evaluate x(ω) in the context that x ordinarily uses, not in whatever context m1 has introduced. Hence we simply evaluate x(ω) without any context; x then reintroduces its own context.
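A minimal sketch of this rule, assuming only Base Julia and an explicit context argument instead of Cassette: a hypothetical Ctx struct stands in for the CIIDCtx metadata, and a hypothetical call_rv helper plays the role of the prototype's generic Cassette.overdub method, evaluating any parent listed in ctx.shared with no context while everything else inherits the current one. All names are illustrative, and random variables here must route calls through call_rv by hand, which is the bookkeeping Cassette automates.

```julia
# Hypothetical, Cassette-free sketch of `i ~ f <| shared` (not Omega's API).
const OmegaDict = Dict{Tuple{Vararg{Int}}, Float64}

# Stand-in for the CIIDCtx metadata: current id path and shared parents.
struct Ctx
    path::Tuple{Vararg{Int}}
    shared::Tuple
end
const ROOT = Ctx((), ())   # "no context"

# Shared parents run with no context (they reintroduce their own);
# any other random variable inherits `ctx`.
call_rv(rv, ω, ctx::Ctx) = rv in ctx.shared ? rv(ω, ROOT) : rv(ω, ctx)

# Primitive uniform: reads the coordinate named by the path ((1,) outside any `~`).
unif_rv(ω, ctx::Ctx) = get!(rand, ω, isempty(ctx.path) ? (1,) : ctx.path)

# `i ~ f <| shared`: prepend i to the path and extend the shared parents.
ciid_rv(f, i, shared = ()) =
    (ω, ctx::Ctx) -> f(ω, Ctx((i, ctx.path...), (ctx.shared..., shared...)))

uniform_rv(a, b) = (ω, ctx::Ctx) -> call_rv(unif_rv, ω, ctx) * (b - a) + a

x = ciid_rv(uniform_rv(0, 19), 1)   # x = 1 ~ uniform(0, 19)
measure(ω, ctx::Ctx) = call_rv(x, ω, ctx) + call_rv(uniform_rv(0, 1), ω, ctx)
m1 = ciid_rv(measure, 3, (x,))      # m1 = 3 ~ measure <| (x,)
m2 = ciid_rv(measure, 4, (x,))      # m2 = 4 ~ measure <| (x,)
m3 = ciid_rv(measure, 5)            # m3 = 5 ~ measure

ω = OmegaDict()
println((m1(ω, ROOT), m2(ω, ROOT), m3(ω, ROOT)))
```

In this sketch m1 and m2 both read x at coordinate (1,) and add noise from (3,) and (4,) respectively, while m3, with nothing shared, evaluates its own copy of x at (1, 5).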

Prototype

module MiniOmega

using Cassette
import Base: ~
export sample, unif, pointwise, <|, rt, Ω

const ID = NTuple{N, Int} where N

"Ω is a hypercube"
struct Ω
  data::Dict{ID, Real}
end

"Sample a random ω ∈ Ω"
sample(::Type{Ω}) = Ω(Dict{ID, Real}())
# Draw coordinate `id` lazily on first access; later accesses reuse the stored value
Base.getindex(ω::Ω, id) = get!(rand, ω.data, id)

sample(f) = f(sample(Ω))

unif(i::ID) = ω -> ω[i]   # primitive uniform at coordinate i

"Single primitive random variable"
unif(ω::Ω) = ω[(1,)]

# (Conditional) independence 

# Use Cassette to augment the environment with extra state
Cassette.@context CIIDCtx 

"""Conditionally independent and identically distributed given `shared` parents

Returns the `id`th element of an (exchangeable) sequence of random variables
that are identically distributed with `f` but conditionally independent given
random variables in `shared`.

This essentially constructs a plate in which all variables in `shared` are shared
and all other parents are not.
"""
function ciid(f, id::Integer, shared)
  let ctx = CIIDCtx(metadata = (shared = shared, id = (id,)))
    ω -> withctx(ctx, f, ω)
  end
end

withctx(ctx, f, ω) = Cassette.overdub(ctx, f, ω)
function Cassette.overdub(ctx::CIIDCtx, ::typeof(withctx), ctxinner, f, ω)
  # Nested ciid: merge the inner variable's id and shared parents into the outer context
  id = (ctxinner.metadata.id..., ctx.metadata.id...)
  shared = (ctxinner.metadata.shared..., ctx.metadata.shared...)
  Cassette.overdub(CIIDCtx(metadata = (shared = shared, id = id)), f, ω)
end

"I.I.D.: `ciid` with nothing shared"
ciid(f, id::Integer) = ciid(f, id, ())

# Inside a ciid context, a primitive draw reads the coordinate named by the context's id
Cassette.overdub(ctx::CIIDCtx, ::typeof(unif), ω::Ω) =
  unif(ctx.metadata.id)(ω)

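# Shared parents are evaluated outside the context (they reintroduce their own);
# every other random variable is re-evaluated inside it
function Cassette.overdub(ctx::CIIDCtx, x, ω::Ω)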
  if x in ctx.metadata.shared
    x(ω)
  else
    Cassette.recurse(ctx, x, ω)
  end
end

# Syntactic Sugar (to make model-building nicer)

"Random tuple"
rt(fs...) = ω -> map(f -> f(ω), fs)

"""Supports notation `i ~ x <| (y,z,)`
which is the ith element of an (exchangeable) sequence of random variables that are
identically distributed with x but conditionally independent given y and z.
"""
struct Plate{F, S}
  f::F
  shared::S
end

@inline <|(f, shared::Tuple) = Plate(f, shared)
~(id::Integer, f) = ciid(f, id)
~(id::Integer, plate::Plate) = ciid(plate.f, id, plate.shared)

# Pointwise
Cassette.@context PWCtx
const Lifted = Union{map(typeof, (+, -, /, *))...}   # arithmetic lifted over functions of ω
Cassette.overdub(::PWCtx, op::Lifted, x::Function) = ω -> op(x(ω))
Cassette.overdub(::PWCtx, op::Lifted, x::Function, y::Function) = ω -> op(x(ω), y(ω))
Cassette.overdub(::PWCtx, op::Lifted, x::Function, y) = ω -> op(x(ω), y)
Cassette.overdub(::PWCtx, op::Lifted, x, y::Function) = ω -> op(x, y(ω))
pointwise(f) = Cassette.overdub(PWCtx(), f)

# AutoId
Cassette.@context AutoIdCtx

end

using .MiniOmega

function simpletest()
  a = 1 ~ unif
  b = 2 ~ a
  sample(b)
end

function test()
  uniform(a, b) = ω -> unif(ω) * (b - a) + a

  x = 1 ~ uniform(0, 1)
  y = 2 ~ uniform(0, 1)
  d = 7 ~ y
  z = ω -> (x(ω), x(ω), y(ω), d(ω))
  @show sample(z)
  function a(ω)
    x_ = x(ω)
    local d = 3 ~ uniform(0, 4)   # `local`: don't overwrite test()'s d, which z captures
    e = 4 ~ uniform(0, 4)
    x_ + d(ω) + e(ω)
  end

  @show sample(rt(x, y, z, a))

  # Conditional Independence
  x = 1 ~ uniform(0, 19)
  measure(ω) = x(ω) + uniform(0, 1)(ω)
  m1 = 3 ~ measure <| (x,)
  m2 = 4 ~ measure <| (x,)
  m3 = 5 ~ measure

  @show sample(rt(m1, m2, m3))

  # Linear regression
  α = 1 ~ uniform(0, 1) 
  β = 2 ~ uniform(0, 1)
  f(x, i) = i ~ (ω -> x*α(ω) + β(ω) + uniform(0.0, 1.0)(ω)) <| (α, β)
  y1 = f(0.3, 3)
  y2 = f(0.3, 4)
  @show sample(rt(y1, y2))

  y1b, y2b = pointwise() do
    f2(x, i) = i ~ (x * α + β + uniform(0.0, 1.0)) <| (α, β)
    y1b = f2(0.3, 3)
    y2b = f2(0.3, 4)
    y1b, y2b
  end

  @show sample(rt(y1, y2, y1b, y2b))

  # A bunch of random variables with shared parents
  pointwise() do
    x = 1 ~ unif
    y = x * 100
    function a(ω)
      y(ω) + 5
    end
    function b(ω)
      y(ω) + 10
    end
    c = 2 ~ b <| (y,)
    d = 3 ~ b
    @show sample(rt(y, a, b, c, d))
  end
end
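The pointwise context in the prototype lifts +, -, * and / so that an expression such as x * α + β denotes the random variable ω -> x * α(ω) + β(ω). Adding those methods to Base operators for plain Function arguments outside a context would be type piracy, which defining them on Cassette.overdub for PWCtx avoids. For comparison only, here is the same lifting done with a different technique, ordinary operator overloading on a hypothetical RV wrapper type, with ω assumed to be a plain vector of uniforms for brevity.

```julia
# Hypothetical alternative to `pointwise` (not Omega's API): an RV wraps a
# function of ω and is itself callable on ω.
struct RV{F}
    f::F
end
(x::RV)(ω) = x.f(ω)

# Lift the same operators as the prototype's `Lifted` union.
for op in (:+, :-, :*, :/)
    @eval begin
        Base.$op(x::RV, y::RV) = RV(ω -> $op(x(ω), y(ω)))    # RV op RV
        Base.$op(x::RV, y::Number) = RV(ω -> $op(x(ω), y))   # RV op constant
        Base.$op(x::Number, y::RV) = RV(ω -> $op(x, y(ω)))   # constant op RV
    end
end
Base.:-(x::RV) = RV(ω -> -x(ω))   # unary minus

# Toy ω: a plain vector of uniform draws (an assumption for brevity).
α = RV(ω -> ω[1])
β = RV(ω -> ω[2])
y = 0.3 * α + β   # RV for ω -> 0.3 * ω[1] + ω[2]
ω = rand(2)
println(y(ω))
```

The trade-off is that model code must build RV values explicitly, whereas pointwise() lifts arithmetic on ordinary closures.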

Port to Omega?

There may be some small bugs in this approach (please check), but I'm fairly confident that it is mostly correct, and that it is both more general and simpler than what we are currently doing. So I should probably port the change to the real Omega.

A few concerns?

Overall, I'm inclined to say we should support any object that "purely" implements (::Type)(ω::Ω). We can still use typed and/or id'd objects for the above reasons, but only where we need them.
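To illustrate what supporting any object that implements (::Type)(ω::Ω) could look like, here is a sketch with a hypothetical typed, id'd random variable (a callable struct UniformRV) next to a plain function of ω; both act as random variables simply because they can be called on ω. OmegaDict, UniformRV and shifted are illustrative names, and ω is assumed to be a lazily populated Dict, as in the prototype's Ω.

```julia
# Hypothetical sketch (not Omega's API): anything callable on ω is a random variable.
const OmegaDict = Dict{Tuple{Vararg{Int}}, Float64}

# A typed, id'd random variable: a uniform on [lo, hi] read from coordinate `id`.
struct UniformRV
    id::Tuple{Vararg{Int}}
    lo::Float64
    hi::Float64
end
# "Pure": the value depends only on ω, because draws are memoized in ω.
(u::UniformRV)(ω::OmegaDict) = u.lo + (u.hi - u.lo) * get!(rand, ω, u.id)

# A plain function of ω is an equally valid random variable.
u = UniformRV((1,), 0.0, 1.0)
shifted(ω::OmegaDict) = u(ω) + 10

ω = OmegaDict()
println((u(ω), shifted(ω)))   # shifted(ω) is u(ω) + 10 for the same ω
```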

x = normal(0, 1)
f_(ω) = uniform(ω, 0, 1) + 10 * 2 + x(ω)
f = ciid(f_)   # unexpected! the x inside f will be an independent normal, not the shared x

This code would break or, more precisely, do something different from what it currently does. Currently this code shares the value of x, but ciid (aka ~) without any shared parameters creates an i.i.d. variable, so the x inside f is resampled independently. Fortunately, because we no longer distinguish between random variables and functions, this does not mean we always need to state explicitly what the parents are; in fact, we do not need ciid at all in this example:

x = normal(0, 1)
f(ω) = (1 ~ uniform(0, 1))(ω) + 10 * 2 + x(ω) # just use f! x will be a parent of both f and g below
g(ω) = (2 ~ normal(0, 1))(ω) + 2 + x(ω)
zenna commented 5 years ago

A quick primer on Cassette, mostly for @jburroni.

Cassette lets you implement many kinds of non-standard interpretations.

The general approach is to first create a context, and then add new definitions to that context.

using Cassette
Cassette.@context MyCtx   # Creates a new context type called MyCtx

In this new context, I will replace sin with cos:

# adding methods to Cassette.overdub is how to define new semantics
function Cassette.overdub(::MyCtx, ::typeof(sin), x)
  println("I will now do cos instead of sin")
  cos(x)
end

Try it out

myprogram() = sin(3)

# execute myprogram in MyCtx
Cassette.overdub(MyCtx(), myprogram)

One thing Cassette makes quite easy is the equivalent of adding extra state to the environment, in the operational-semantics sense. You do this by adding data ('metadata') to your context. In this example I'll add metadata that just counts the number of function calls:

mutable struct Counter
  count::Int
end

# This will intercept __every__ function call (except those in Base that Cassette, by default, does not recurse into)
function Cassette.overdub(ctx::MyCtx, f, args...)
  ctx.metadata.count += 1
  # Now continue running the call inside the context
  Cassette.recurse(ctx, f, args...)
end

# Let's try it out
ctx = MyCtx(metadata = Counter(0))
Cassette.overdub(ctx, myprogram)
println("Number of function calls is $(ctx.metadata.count)")
zenna commented 3 years ago

Better in discussions