cscherrer / Tilde.jl

WIP successor to Soss.jl
MIT License

Unbiased Tilde-models #23

Closed: mschauer closed this issue 2 years ago

mschauer commented 2 years ago

Below I create two logistic regression Tilde models, full_model and hat_model. Note that hat_model is stochastic and its gradient is an unbiased estimate of the full model's gradient scaled by K/N, because each observation is picked with probability K/N. For the global parameters, which are always picked (with probability 1), I use a trick: α ~ Tilde.Normal(0, sqrt(N1/K)) gives a log-likelihood contribution which, when rescaled by N/K together with the data terms, goes back to that of α ~ Tilde.Normal(0, 1).

full_model = @Tilde.model N1, C, c, y begin
    α ~ Tilde.Normal(0,1)
    cc ~ Tilde.Normal(0,1)^C
    for i in 1:N1
        v = α + cc[c[i]]
        y[i] ~ Soss.Bernoulli(logistic(v))
    end
end

hat_model = @Tilde.model N1, C, c, K, seed, y begin
    α ~ Tilde.Normal(0,sqrt(N1/K))
    cc ~ Tilde.Normal(0,sqrt(N1/K))^C
    sampler = Random.SamplerRangeNDL(1:N1)
    rng = ZZB.Rng(seed)
    for _ in 1:K
        i = rand(rng, sampler)
        v = α + cc[c[i]]
        y[i] ~ Soss.Bernoulli(logistic(v))
    end
end
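To see why the rescaled prior in hat_model works: (N1/K) · logpdf(Normal(0, sqrt(N1/K)), α) equals logpdf(Normal(0, 1), α) up to an α-independent constant, so the gradients agree exactly. A quick self-contained check (plain Julia, with the normal log-density written out by hand to avoid any package dependency):

```julia
# Log-density of Normal(0, σ) at x, written out explicitly
lognormal(x, σ) = -x^2 / (2σ^2) - log(σ) - 0.5 * log(2π)

N1, K = 1000, 50

# Rescaled subsample prior vs. full prior: the difference is constant in α
d(α) = (N1 / K) * lognormal(α, sqrt(N1 / K)) - lognormal(α, 1.0)
@assert d(0.7) ≈ d(1.3) ≈ d(-2.0)

# Gradients agree exactly: d/dα of the rescaled term is -α/(N1/K) * (N1/K) = -α,
# the same as for Normal(0, 1)
g_scaled(α) = (N1 / K) * (-α / (N1 / K))
@assert g_scaled(0.7) == -0.7
```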

It would be nice to tell Tilde to put a weight/factor on the log-likelihood, like this:

hat_model = @Tilde.model N1, C, c, K, seed, y begin
    α ~ Tilde.Normal(0,1)
    cc ~ Tilde.Normal(0,1)^C
    sampler = Random.SamplerRangeNDL(1:N1)
    rng = ZZB.Rng(seed)
    for _ in 1:K
        i = rand(rng, sampler)
        v = α + cc[c[i]]
        y[i] ~ Soss.Bernoulli(logistic(v), weight=N1/K)
    end
end
cscherrer commented 2 years ago

Does this do what you want? https://github.com/cscherrer/MeasureBase.jl/blob/master/src/combinators/powerweighted.jl

So I think you'd write

y[i] ~ Soss.Bernoulli(logistic(v)) ↑ (N1/K)

or (maybe compare performance on these)

y[i] ~ Soss.Bernoulli(logitp = v) ↑ (N1/K)

BTW I don't think you need the Soss. and Tilde. prefixes, since Normal and Bernoulli come from MeasureTheory.

mschauer commented 2 years ago

Awesome!