TuringLang / DynamicPPL.jl

Implementation of domain-specific language (DSL) for dynamic probabilistic programming
https://turinglang.org/DynamicPPL.jl/
MIT License

`minibatch` / `stochastic_gradient` operation on models #633

Open yebai opened 1 month ago

yebai commented 1 month ago

We have some interesting operators on DynamicPPL models, such as condition / decondition, fix, and generated_quantities. The advantage of these operators is that models can be written without anticipating them, which is in line with the broad principle of separating model specification from inference specification.

https://turinglang.org/DynamicPPL.jl/stable/api/#AbstractPPL.condition

The operators mentioned above prompt me to wonder whether we could introduce operators such as minibatch / stochastic_gradient for models that loop over IID data points. Such an operator would throw an error if the input model does not contain IID data points, and would return a new (minibatched) model if it does.
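To make the expected behaviour concrete: a minibatch operator would presumably replace the full-data log-likelihood with a rescaled sum over a random subset of the IID points. The plain-Julia sketch below (no DynamicPPL involved; `minibatch_loglik` and `batch_average` are hypothetical names) implements the usual estimator, N/B times the batch sum, and checks by enumerating every batch of size B that it is unbiased for the full sum when batches are drawn without replacement.

```julia
# Hypothetical helper, not part of DynamicPPL: rescale the log-likelihood of a
# batch of B IID data points so that it estimates the full-data sum over N points.
function minibatch_loglik(ℓ::AbstractVector{<:Real}, batch::AbstractVector{<:Integer})
    N, B = length(ℓ), length(batch)
    return (N / B) * sum(ℓ[batch])
end

# Average the estimator over every size-B subset of 1:N, enumerated via bitmasks
# (only practical for small N). For sampling without replacement this average
# equals the full sum, i.e. the estimator is unbiased.
function batch_average(ℓ::AbstractVector{<:Real}, B::Integer)
    N = length(ℓ)
    total, nbatches = 0.0, 0
    for mask in 0:(2^N - 1)
        count_ones(mask) == B || continue
        batch = [i for i in 1:N if (mask >> (i - 1)) & 1 == 1]
        total += minibatch_loglik(ℓ, batch)
        nbatches += 1
    end
    return total / nbatches
end

ℓ = [-0.3, -1.2, -0.7, -2.1, -0.5]  # per-point log-likelihoods (made-up numbers)
batch_average(ℓ, 2) ≈ sum(ℓ)         # true
```

Each point lands in a fraction B/N of all batches, which is exactly what the N/B factor cancels.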

cc @Red-Portal, who will find this useful for stochastic VI.

Red-Portal commented 1 month ago

Hi @yebai !

I think minibatching will require a more invasive solution than syntax akin to condition. For instance, if the model doesn't have any local latent variables, as in

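using Turing, FillArrays, LinearAlgebra

@model function logistic(X, y)
    d = size(X, 2)  # number of features
    θ ~ MvNormal(Zeros(d), I)
    y ~ arraydist(BernoulliLogit.(X * θ))  # one BernoulliLogit per observation
end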

Then we could leverage condition as model | (X=X[batch,:], y=y[batch]) and swap the context to MiniBatchContext.
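For the logistic model, conditioning on X[batch,:] and y[batch] and swapping the context should amount to evaluating the full prior plus N/B times the batch log-likelihood. The plain-Julia sketch below writes that stochastic log joint out by hand, with the Bernoulli-logit and standard normal log-densities coded directly and hypothetical function names. My understanding is that DynamicPPL's MiniBatchContext applies this kind of scaling to observe statements, but that is an assumption here; the sketch does not use it.

```julia
using LinearAlgebra

# Stable softplus(x) = log(1 + exp(x)).
softplus(x::Real) = x > 0 ? x + log1p(exp(-x)) : log1p(exp(x))

# Log-density of y ~ BernoulliLogit(η): log σ(η) if y is true, log(1 - σ(η)) otherwise.
bernoulli_logit_logpdf(y::Bool, η::Real) = y ? -softplus(-η) : -softplus(η)

# Log-density of θ ~ MvNormal(zeros(d), I).
std_normal_logpdf(θ::AbstractVector{<:Real}) = -0.5 * (length(θ) * log(2π) + dot(θ, θ))

# Stochastic log joint: exact prior plus the batch log-likelihood scaled by N / B.
# Only the likelihood is rescaled; the prior over θ is evaluated once, unscaled.
function logjoint_minibatch(θ, X::AbstractMatrix, y::AbstractVector{Bool}, batch)
    N, B = size(X, 1), length(batch)
    η = X[batch, :] * θ                       # logits of the batch rows only
    loglik = sum(bernoulli_logit_logpdf.(y[batch], η))
    return std_normal_logpdf(θ) + (N / B) * loglik
end

X = [1.0 0.5; -0.3 2.0; 0.8 -1.0; 0.1 0.1]
y = [true, false, true, false]
θ = [0.2, -0.4]
logjoint_minibatch(θ, X, y, [1, 3])   # noisy estimate from rows 1 and 3
logjoint_minibatch(θ, X, y, 1:4)      # full batch: the exact log joint
```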

The problem is models that do have local latent variables, such as matrix factorization:

@model function nmf(k, y)
    m, n = size(y)
    items ~ filldist(Gamma(1, 1), m, k)
    users ~ filldist(Gamma(1, 1), k, n)
    Λ = items*users
    @. y ~ Poisson(Λ)
end

Let's say we want to subsample over the users, which correspond to the columns of y. Unfortunately, we cannot simply slice y, since the dimensionality of users also has to change. So I think we do need some way to express this in the model.
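The dimension mismatch can be seen with plain arrays (made-up sizes, no Turing involved): slicing only the columns of y leaves the mean matrix items*users at full width, so the columns of users have to be subsampled with the same indices.

```julia
using Random

m, k, n, B = 5, 2, 8, 3
rng = MersenneTwister(1)
items = rand(rng, m, k)          # global latent variables: shared by every user
users = rand(rng, k, n)          # local latent variables: one column per user
y = rand(rng, 0:5, m, n)         # observed counts: one column per user

idx = randperm(rng, n)[1:B]      # subsample B users without replacement

size(items * users)              # (5, 8): still one column per user
size(y[:, idx])                  # (5, 3): only the subsampled users
Λ = items * users[:, idx]        # slice the local latents with the same indices
size(Λ) == size(y[:, idx])       # true
```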

The way Pyro does this, which uses pyro.plate, is by creating an index range. In Turing, I think it would look something like

@model function nmf(k, y)
    m, n = size(y)
    items ~ filldist(Gamma(1, 1), m, k)
    users ~ filldist(Gamma(1, 1), k, n)

    idx = @dataindex(1:n)
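    Λ = items * users[:, idx]
    # index the model argument `y` directly: a local such as `ysub = y[:, idx]`
    # on the LHS of `~` would be treated as a latent variable, not as data
    y[:, idx] .~ Poisson.(Λ)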
end

When subsampling, the backend would modify the output of @dataindex so that the indices are subsampled from the range 1:n. Ideally, we could also evaluate the prior density of users only over the subsampled latent variables, but this would make things a little more complicated, since part of the prior density would then need to be adjusted, not just the likelihood.
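If the prior of users were subsampled as well, the adjustment could look like the plain-Julia sketch below, with the log-densities spelled out (Gamma(1, 1) has log-density -x on x > 0). Everything indexed by a user, both the prior of that user's column of users and its likelihood terms, is scaled by n/B, while the prior of the global items is evaluated once and left unscaled. The function names are hypothetical and the example data are random.

```julia
using Random

# Gamma(1, 1) log-density, summed over an array of positive entries.
gamma11_logpdf(x::AbstractArray{<:Real}) = -sum(x)

# Poisson(λ) log-density at a non-negative integer count y: y*log(λ) - λ - log(y!).
poisson_logpdf(y::Integer, λ::Real) = y * log(λ) - λ - sum(log, 1:y; init = 0.0)

# Contribution of user j: prior of users[:, j] plus the likelihood of y[:, j].
function user_term(items, users, y, j)
    λ = items * users[:, j]
    return gamma11_logpdf(users[:, j]) + sum(poisson_logpdf.(y[:, j], λ))
end

# Stochastic log joint: exact global prior + (n / B) * sum of the per-user terms in idx.
function nmf_logjoint_minibatch(items, users, y, idx)
    n, B = size(y, 2), length(idx)
    return gamma11_logpdf(items) + (n / B) * sum(j -> user_term(items, users, y, j), idx)
end

rng = MersenneTwister(2)
m, k, n = 4, 2, 6
items, users = rand(rng, m, k) .+ 0.1, rand(rng, k, n) .+ 0.1
y = rand(rng, 0:4, m, n)
nmf_logjoint_minibatch(items, users, y, randperm(rng, n)[1:2])  # noisy estimate
nmf_logjoint_minibatch(items, users, y, 1:n)                    # exact log joint
```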

On a related note, to properly support subsampling, DynamicPPL will also need to split addlogprob! into addlogprior! and addloglike! so that only the likelihood is adjusted.

Long story short, I think new syntax will be needed to properly support subsampling.

@torfjelde any thoughts?

yebai commented 1 month ago

> The way Pyro does this, which uses pyro.plate, is by creating an index range. In Turing, I think it would look something like

We can treat the idx variable as a special argument, i.e.,

@model function nmf(k, y; idx = 1:size(y, 2)) # an inference algorithm can override the default `idx = 1:size(y, 2)`
    m, n = size(y)
    items ~ filldist(Gamma(1, 1), m, k)
    users ~ filldist(Gamma(1, 1), k, n)

    Λ = items * users[:, idx]
    # observe through the model argument `y`; a local `ysub` on the LHS would be treated as latent
    y[:, idx] .~ Poisson.(Λ)
end

Do you think that's enough?

> On a related note, to properly support subsampling, DynamicPPL will also need to split addlogprob! into addlogprior! and addloglike! so that only the likelihood is adjusted.

IIUC, we won't need to split addlogprob!, since the tilde pipeline knows whether the LHS is observed. The adjustment can be handled at the level of the tilde pipeline instead of inside addlogprob!.

Red-Portal commented 1 month ago

> We can treat the idx variable as a special argument, i.e.,

Oh yes, I think that would actually work for now, if we assume that subsampling over the prior is not supported.

Though, for generality, it would probably be better to make idx an iterable rather than an index, so that one could slice the data beforehand. Then the backend would need:

  1. An official API function that provides the full iterable of a model.
  2. A function that conditions the model on the subsampled range and changes the evaluation context to a MiniBatchContext with the appropriate likelihood adjustment.

> IIUC, we won't need to split addlogprob!, since the tilde pipeline knows whether the LHS is observed. The adjustment can be handled at the level of the tilde pipeline instead of inside addlogprob!.

I am not sure what you mean here by the tilde pipeline. I was thinking of the case where addlogprob! is used to increment not just the likelihood but also the prior. In that case, there is no way to tell which is which; is that not the case?
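The two positions can be contrasted with a toy accumulator. The names are hypothetical, and this is not DynamicPPL's tilde pipeline, only an illustration of the idea. An observe step knows its left-hand side is data, so it can scale its own contribution; a bare addlogprob! increment carries no such label and can only be treated as a whole, whereas split addlogprior!/addloglike!-style increments can be routed correctly.

```julia
# Toy log-density accumulator with a likelihood scale factor (e.g. N / B).
mutable struct ToyAccumulator
    logprior::Float64
    loglike::Float64
    scale::Float64
end
ToyAccumulator(scale) = ToyAccumulator(0.0, 0.0, scale)

# "Tilde" steps: an assume adds to the prior, an observe adds a scaled likelihood term.
toy_assume!(acc::ToyAccumulator, lp) = (acc.logprior += lp; acc)
toy_observe!(acc::ToyAccumulator, lp) = (acc.loglike += acc.scale * lp; acc)

# Split manual increments in the spirit of the proposed addlogprior!/addloglike!.
toy_addlogprior!(acc::ToyAccumulator, lp) = toy_assume!(acc, lp)
toy_addloglike!(acc::ToyAccumulator, lp) = toy_observe!(acc, lp)

# An undifferentiated increment must be treated as one or the other; here it is
# (arbitrarily) left unscaled, which is wrong whenever `lp` contains likelihood terms.
toy_addlogprob!(acc::ToyAccumulator, lp) = (acc.logprior += lp; acc)

total(acc::ToyAccumulator) = acc.logprior + acc.loglike

acc = ToyAccumulator(10.0)   # N / B = 10
toy_assume!(acc, -1.0)       # θ ~ prior: unscaled
toy_observe!(acc, -0.5)      # y[i] ~ likelihood: scaled to -5.0
toy_addloglike!(acc, -0.2)   # manual likelihood term: scaled to -2.0
total(acc)                   # -8.0
```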

yebai commented 1 month ago

> Though, for generality, it would probably be better to make idx an iterable rather than an index, so that one could slice the data beforehand.

For clarity (sorry for the lack of detail above): all of these operations can be wrapped in a minibatched = minibatch(conditioned_model, N, batch_size) API, and the standard log-density functions can then be called on minibatched to compute the (stochastic) log density and its gradient. A more concise design is to combine condition and minibatch into a minibatched condition, e.g., minibatched = condition(model, Minibatch((data = Y,), N, batch_size))
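One way to picture the proposed minibatched condition is a small wrapper that remembers the full data, N and the batch size, and hands back a batch slice together with the likelihood scale N / batch_size. The sketch below is plain Julia: `Minibatch` and `draw` here are hypothetical, not DynamicPPL API; it takes a single array rather than a NamedTuple for brevity, and slicing along the last dimension assumes data points are stored as columns (or vector entries).

```julia
using Random

# Hypothetical wrapper: full data set plus minibatching parameters.
struct Minibatch{T<:AbstractArray}
    data::T
    N::Int
    batch_size::Int
    function Minibatch(data::T, N::Integer, batch_size::Integer) where {T<:AbstractArray}
        size(data, ndims(data)) == N ||
            throw(ArgumentError("data must have N points along its last dimension"))
        1 <= batch_size <= N || throw(ArgumentError("need 1 <= batch_size <= N"))
        return new{T}(data, N, batch_size)
    end
end

# Draw a batch without replacement; return a view of the batch and the likelihood scale.
function draw(rng::AbstractRNG, mb::Minibatch)
    idx = randperm(rng, mb.N)[1:mb.batch_size]
    batch = selectdim(mb.data, ndims(mb.data), idx)
    return batch, mb.N / mb.batch_size
end

Y = rand(MersenneTwister(3), 3, 10)   # 10 data points stored as columns
batch, scale = draw(MersenneTwister(4), Minibatch(Y, 10, 4))
size(batch), scale                   # ((3, 4), 2.5)
```

Keeping N and batch_size in the wrapper means the scale factor travels with the data, which is what would let a single condition call set up both the slice and the likelihood adjustment.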

Red-Portal commented 1 month ago

Yes, how to do all this is pretty clear to me now! In terms of the details, though, I was thinking of something more like

sub_iter = sample(subsample_range, B, replace=false)
minibatched = minibatch(model, (sub_iter=sub_iter,))

where minibatch would use length(sub_iter) to set up the MiniBatchContext. So this wouldn't necessarily involve indices.
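This variant hands the model the subsampled iterable itself, so the likelihood scale can be computed from length(sub_iter) alone. In the plain-Julia sketch below, `loglik_over` is a hypothetical stand-in for a model whose likelihood is a Normal(μ, 1) loop over data points, and `shuffle` from the Random standard library plays the role of StatsBase's `sample(subsample_range, B, replace=false)`. The same function handles the full data and a batch; only the scale factor differs.

```julia
using Random

# Log-density of x ~ Normal(μ, 1).
normal_logpdf(x::Real, μ::Real) = -0.5 * (log(2π) + (x - μ)^2)

# Log-likelihood over any iterable of IID data points, rescaled by N / length(points)
# so that a subset stands in for all N points. Only the length of the subset matters.
loglik_over(points, μ::Real, N::Integer) =
    (N / length(points)) * sum(x -> normal_logpdf(x, μ), points)

data = [0.3, -1.1, 0.8, 2.0, -0.4, 1.5]
N, B = length(data), 3
sub_iter = shuffle(MersenneTwister(5), data)[1:B]   # B data points drawn without replacement

loglik_over(data, 0.0, N)       # full data: scale factor N / N = 1
loglik_over(sub_iter, 0.0, N)   # minibatch: scale factor N / B = 2
```

Because the batch is passed as data rather than as indices, this works for any iterable whose length is known; as in the matrix factorization example, local latent variables attached to each data point would still need to be matched to the batch.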