Hi @yebai!
I think minibatching will require a more invasive solution than a syntax akin to `condition`. For instance, if the model doesn't have any local latent variables, like
```julia
@model function logistic(X, y)
    d = size(X, 2)
    θ ~ MvNormal(Zeros(d), I)
    y .~ BernoulliLogit(X * θ)
end
```
then we could leverage `condition` as `model | (X=X[batch,:], y=y[batch])` and swap the context to be `MiniBatchContext`.
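For example, here is a minimal sketch of evaluating such a minibatched log density, assuming DynamicPPL's `MiniBatchContext` (which rescales the accumulated log likelihood by `npoints / batch_size`). The data is sliced up front rather than going through `condition`, and the exact evaluation entry points may vary across DynamicPPL versions:

```julia
using Turing, FillArrays, LinearAlgebra
using StatsBase: sample
using DynamicPPL: MiniBatchContext, DefaultContext, VarInfo, evaluate!!, getlogp

X, y = randn(1000, 5), rand(Bool, 1000)        # toy data, for illustration only
N, B = size(X, 1), 64
batch = sample(1:N, B; replace=false)
batch_model = logistic(X[batch, :], y[batch])  # model instantiated on the batch
ctx = MiniBatchContext(DefaultContext(); batch_size=B, npoints=N)
_, vi = evaluate!!(batch_model, VarInfo(batch_model), ctx)
getlogp(vi)  # ≈ log prior + (N / B) × minibatch log likelihood
```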
The problem is models that do have latent variables like matrix factorization:
```julia
@model function nmf(k, y)
    m, n = size(y)
    items ~ filldist(Gamma(1, 1), m, k)
    users ~ filldist(Gamma(1, 1), k, n)
    Λ = items * users
    @. y ~ Poisson(Λ)
end
```
Let's say we want to subsample over the users, which would correspond to the columns of `y`. Unfortunately, we cannot simply slice `y`, since the dimensionality of `users` also has to change. So I think we do need some sort of functionality to express this in the model.
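To see the mismatch concretely (toy dimensions, outside the model, purely for illustration):

```julia
# Slicing y without also slicing `users` breaks broadcasting:
m, n, k, B = 10, 100, 3, 16
items, users, y = rand(m, k), rand(k, n), rand(1:10, m, n)
y_sub = y[:, 1:B]       # m × B
Λ = items * users       # m × n — `users` still spans all n data columns
size(y_sub) == size(Λ)  # false, so `@. y_sub ~ Poisson(Λ)` cannot broadcast
```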
The way Pyro does this, via `pyro.plate`, is by creating an index range. In Turing, I think it would look something like
```julia
@model function nmf(k, y)
    m, n = size(y)
    items ~ filldist(Gamma(1, 1), m, k)
    users ~ filldist(Gamma(1, 1), k, n)
    idx = @dataindex(1:n)
    Λ = items * users[:, idx]
    ysub = y[:, idx]
    @. ysub ~ Poisson(Λ)
end
```
When subsampling, the backend would modify the output of `@dataindex` such that the indices are subsampled over the range `1:n`. Ideally, we could also evaluate the prior density of `users` over the subsampled latent variables, but this would make things a little more complicated, since part of the prior density will need to be adjusted, not just the likelihood.
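Writing it out for the model above: with a batch $S \subset \{1, \dots, n\}$ of size $B$, the standard unbiased estimate of the log joint has to rescale both the likelihood and the `users` prior terms of the subsampled columns, while the global `items` prior is left alone:

$$
\log p(\mathrm{items}) + \frac{n}{B} \sum_{i \in S} \left[ \log p(\mathrm{users}_{:,i}) + \sum_{j=1}^{m} \log p\left( y_{j,i} \mid \Lambda_{j,i} \right) \right]
$$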
On a related note, to properly support subsampling, DynamicPPL will also need to split `addlogprob!` into `addlogprior!` and `addloglike!` so that only the likelihood is adjusted.
So long story short, I think a new syntax will be needed to properly support subsampling.
@torfjelde any thoughts?
> The way Pyro does this, via `pyro.plate`, is by creating an index range. In Turing, I think it would look something like
We can treat the `idx` variable as a special argument, i.e.,
```julia
@model function nmf(k, y; idx = 1:size(y, 2)) # an inference algorithm can override the default `idx`
    m, n = size(y)
    items ~ filldist(Gamma(1, 1), m, k)
    users ~ filldist(Gamma(1, 1), k, n)
    Λ = items * users[:, idx]
    ysub = y[:, idx]
    @. ysub ~ Poisson(Λ)
end
```
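An inference algorithm could then override the default per iteration, e.g. drawing a random batch with `StatsBase.sample` (the likelihood would still need rescaling by `n / B`, e.g. via `MiniBatchContext`):

```julia
using StatsBase: sample

B = 32                               # batch size, for illustration
n = size(y, 2)
idx = sample(1:n, B; replace=false)  # random batch of data columns
batch_model = nmf(k, y; idx=idx)     # overrides the default `idx = 1:size(y, 2)`
```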
Do you think that's enough?
> On a related note, to properly support subsampling, DynamicPPL will also need to split `addlogprob!` into `addlogprior!` and `addloglike!` so that only the likelihood is adjusted.
IIUC, we won't need to split `addlogprob!`, since the tilde pipeline knows whether the LHS is observed. The adjustment can be handled at the tilde-pipeline level instead of inside `addlogprob!`.
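For reference, this is roughly what `MiniBatchContext` already does on the observe branch (simplified from DynamicPPL; exact method signatures vary across versions): only observations get rescaled, while latent variables go through `tilde_assume` and are left alone.

```julia
# Simplified: `loglike_scalar = npoints / batch_size` multiplies only the
# log density contributed by observe statements.
function tilde_observe(context::MiniBatchContext, right, left, vi)
    return context.loglike_scalar * tilde_observe(context.context, right, left, vi)
end
```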
> We can treat the `idx` variable as a special argument, i.e.,
Oh yes, I think that would actually work for now if we assume subsampling over the prior is not supported.
Though, for generality, it would probably be better to restrict `idx` to be an iterable rather than an index, so that one could slice the data beforehand. Then the backend would need:

- `MiniBatchContext` with the appropriate likelihood adjustment.

> IIUC, we won't need to split `addlogprob!`, since the tilde pipeline knows whether the LHS is observed. The adjustment can be handled at the tilde-pipeline level instead of inside `addlogprob!`.
I am not sure what you mean here by the tilde pipeline. I was thinking of the case where `addlogprob!` is used not just for incrementing the likelihood but also the prior. In that case, there is no way to differentiate which is which; is that not the case?
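For example (a hypothetical model; `@addlogprob!` is the existing DynamicPPL macro), both increments below land in the same accumulated log density, with nothing marking one as a prior term and the other as a likelihood term:

```julia
using Turing

@model function demo(y)
    θ ~ Normal()
    @addlogprob! -abs2(θ) / 10                  # a prior-like correction on θ
    @addlogprob! sum(logpdf.(Normal(θ, 1), y))  # a likelihood contribution
end
```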
> Though, for generality, it would probably be better to restrict `idx` to be an iterable rather than an index, so that one could slice the data beforehand.
For clarity (sorry for the lack of details above), all these operations can be wrapped in a `minibatched = minibatch(conditioned_model, N, batch_size)` API; then standard `logdensity` functions can be called on `minibatched` to compute (stochastic) log densities and gradients. A more concise design is to combine `condition` and `minibatch` into a minibatched condition, e.g. `minibatched = condition(model, Minibatch((data = Y,), N, batch_size))`.
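A sketch of what `minibatch` could look like under the hood, assuming DynamicPPL's existing `MiniBatchContext` and `contextualize` (names and signatures may differ across versions):

```julia
using DynamicPPL: MiniBatchContext, DefaultContext, contextualize

# Hypothetical wrapper: rescale the likelihood of a model that is already
# conditioned on a batch, so that its log joint is an unbiased estimate of
# the full-data log joint.
function minibatch(conditioned_model, N, batch_size)
    ctx = MiniBatchContext(DefaultContext(); batch_size=batch_size, npoints=N)
    return contextualize(conditioned_model, ctx)
end
```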
Yes, how to do all this is pretty clear to me now! Just in terms of the details, though, I was thinking of something more like
```julia
sub_iter = sample(subsample_range, B, replace=false)
minibatched = minibatch(model, (sub_iter=sub_iter,))
```
where `minibatch` would take `length(sub_iter)` to apply `MiniBatchContext`. So this wouldn't necessarily involve indices.
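Under that design, a stochastic-gradient style loop might look like the following sketch (`minibatch` is the hypothetical operator above; `logjoint` is DynamicPPL's existing function):

```julia
using StatsBase: sample

for step in 1:n_steps
    sub_iter = sample(subsample_range, B; replace=false)
    minibatched = minibatch(model, (sub_iter=sub_iter,))  # hypothetical API
    lp = logjoint(minibatched, θ)  # stochastic estimate of the log joint
    # ... take a gradient step on `lp` with respect to θ ...
end
```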
We have some interesting operators on DynamicPPL models, such as `condition`/`decondition`, `fix`, and `generated_quantities`. The advantage of these operators is that models can be specified without knowing about them, which is in line with the broad principle of separating modelling and inference specification.

https://turinglang.org/DynamicPPL.jl/stable/api/#AbstractPPL.condition
The operators mentioned above prompt me to wonder whether we can introduce operators like `minibatch`/`stochastic_gradient` on models involving a loop over IID data points. These operators would throw an error if the input model does not contain IID data points, but would return a new (`minibatched`) model if it does (see the rough sketch below).

cc @Red-Portal, who will find this useful for stochastic VI.
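A minimal sketch of that contract (`has_iid_observations` is a made-up placeholder for whatever static or runtime check would detect an IID likelihood block; it does not exist in DynamicPPL):

```julia
# Rough contract sketch for the proposed operator.
function minibatch(model::DynamicPPL.Model, args...)
    has_iid_observations(model) ||
        throw(ArgumentError("model does not contain IID data points"))
    # ... otherwise return a new, minibatched model, as sketched above ...
end
```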