seabbs opened 1 month ago
For within-chain parallelisation, the current impression is that this is left to the user, using whatever tools the Julia language provides.
Are there inspiring examples from other PPLs? At the moment it seems quite case-by-case what the best form of within-chain parallelism is, so I can see why Turing does not offer this directly.
One consideration here is that it is easier to parallelise the accumulation of log-posterior density than it is to parallelise the full `~` statement in Turing. In my opinion, this is an argument for having an "inference mode" in Turing, as discussed in https://github.com/TuringLang/DynamicPPL.jl/issues/510.
In terms of nice packages in wider Julia, FLoops.jl has nice functionality here, and you can, at least in forward mode, get it to run with Turing as per this Pluto nb. Note that to get this to work I had to use `@addlogprob!`.
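For concreteness, here is a minimal sketch of what "parallelising the accumulation of log-posterior density" looks like in plain Julia (no Turing, no packages; the function names and the chunking scheme are illustrative assumptions, loosely mirroring Stan's `reduce_sum` grainsize idea):

```julia
using Base.Threads

# Hand-rolled normal log-density so the sketch has no package dependencies.
normal_logpdf(x, μ, σ) = -0.5 * ((x - μ) / σ)^2 - log(σ) - 0.5 * log(2π)

# reduce_sum-style partial reduction: each task sums an independent slice of
# the conditionally independent log-likelihood terms, and the per-slice sums
# are combined at the end, so only the accumulation is parallelised.
function parallel_loglik(y, μ, σ)
    chunks = collect(Iterators.partition(eachindex(y), cld(length(y), nthreads())))
    partials = zeros(length(chunks))
    @threads for c in eachindex(chunks)
        s = 0.0
        for i in chunks[c]
            s += normal_logpdf(y[i], μ, σ)
        end
        partials[c] = s
    end
    return sum(partials)
end
```

This agrees with the serial `sum(normal_logpdf.(y, μ, σ))` up to floating-point reordering; the hard part in Turing is doing this inside the `~` machinery rather than by hand.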
....
> For within-chain parallelisation, the current impression is that this is left to the user, using whatever tools the Julia language provides.
So yes, but the only live example is restricted to having observations on the left-hand side. This would have very limited utility for us, especially if it only supports forward diff. As the original comment points out, we have some very clear use cases where multithreading would be useful, namely dispatching on submodels in a for loop. This would closely mirror how `reduce_sum` is used in Stan if limited to submodels that don't produce output used by other parts of the model (e.g. `StackObservationModels` would be this).
> At the moment it seems quite case-by-case what the best form of within-chain parallelism is, so I can see why Turing does not offer this directly.
I'm struggling to see support for this point, given there are many obvious cases where multithreading is what makes it possible to run very large models in any reasonable time. Given the current relatively poor performance of Turing, and the relative ease of supporting the Julia parallel ecosystem, this seems a higher priority here than in other PPLs (to offset the slowness).
I see your point about an inference mode, but I think it is slightly off topic.
`reduce_sum` from Stan (https://mc-stan.org/docs/stan-users-guide/parallelization.html#reduce-sum) does a lot of the heavy lifting for you.
NumPyro has approaches for this via JAX, some of which are automated and some of which are not; no example pinned down though. It has the explicit `plate` context, which looks like `reduce_sum` in Stan, but I have seen nothing that suggests it does in fact run in parallel.
The `reduce_sum` approach is the one that's easy to implement directly? In what you linked, Stan is offering the user the chance to rewrite their code so that conditionally independent log-posterior-density sums can be parallelised, whilst killing the syntactic sugar of the Stan language... I don't see what's different to using FLoops here?
E.g. they convert:
```stan
data {
  int N;
  array[N] int y;
  vector[N] x;
}
parameters {
  vector[2] beta;
}
model {
  beta ~ std_normal();
  y ~ bernoulli_logit(beta[1] + beta[2] * x);
}
```
into
```stan
functions {
  real partial_sum(array[] int y_slice,
                   int start, int end,
                   vector x,
                   vector beta) {
    return bernoulli_logit_lpmf(y_slice | beta[1] + beta[2] * x[start:end]);
  }
}
data {
  int N;
  array[N] int y;
  vector[N] x;
}
parameters {
  vector[2] beta;
}
model {
  int grainsize = 1;
  beta ~ std_normal();
  target += reduce_sum(partial_sum, y,
                       grainsize,
                       x, beta);
}
```
Whereas in the Turing → FLoops example (albeit for MvNormal),
```julia
@model function test_model(n::Integer)
    μ ~ Normal(1.0, 1.0)
    σ ~ truncated(Normal(0.0, 1.0), 0, Inf)
    y ~ MvNormal(μ * ones(n), σ * ones(n))
end
```
goes to
```julia
@model function test_model_floops(y, n::Integer)
    μ ~ Normal(1.0, 1.0)
    σ ~ truncated(Normal(0.0, 1.0), 0, Inf)
    lls = Vector{eltype(μ)}(undef, n)
    let v = [μ, σ]
        @floop for i in eachindex(y)
            lls[i] = logpdf(Normal(v[1], v[2]), y[i])
        end
    end
    Turing.@addlogprob! sum(lls)
end
```
Which seems to unlock the same kind of performance improvement, with the same kind of downsides, as the Stan example?
> NumPyro has approaches for this via JAX, some of which are automated and some of which are not; no example pinned down though.
@dylanhmorris, do you have any cool examples here?
> Which seems to unlock the same kind of performance improvement, with the same kind of downsides, as the Stan example?
This example is extremely limited vs what you can do in Stan (e.g. https://github.com/epinowcast/epinowcast/blob/886b45cb4bc5f338fa53d22e83d25335c55b1a4a/inst/stan/epinowcast.stan#L400, which runs all of the complicated observation model in parallel, i.e. effectively dispatching over submodels).
I don't think our current approach works naturally with rephrasing everything as `obs ~ lots of stuff`, hence the suggestion that we want to target being able to do this over submodels.
@SamuelBrand1 and I had a f2f, with the conclusion that in the first instance we are aiming to check whether `@submodel` can be dispatched in parallel.
It is possible that Turing.jl already supports running submodels in parallel, at least with some backends. If it doesn't now, this functionality may be added if there is user demand.

`StackObservationModels` is a clear target for us, as it is naively parallel and anything that is being stacked is likely to take sufficient compute to make this worthwhile. It is also likely that we will use this pattern again (e.g. for multiple renewal processes), so solving it here would solve it more widely in the tooling.

See https://discourse.julialang.org/t/within-chain-parallelization-with-turing-jl/103402/4 for some discussion around current within-chain support (TLDR: limited to calls where the LHS is fixed, i.e. observations).
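To make the target concrete, a minimal sketch (plain Julia, no Turing; the submodel functions are hypothetical stand-ins) of the naively-parallel submodel dispatch we have in mind for the `StackObservationModels` case: each independent submodel contributes a log-density term, and the contributions are summed.

```julia
using Base.Threads

# Hypothetical stand-ins for independent observation submodels: each takes the
# shared parameters and its own data, and returns a log-density contribution.
submodel_gaussian(θ, y) = -0.5 * sum(abs2, y .- θ)
submodel_laplace(θ, y) = -sum(abs, y .- θ)

# Naively-parallel dispatch: one task per submodel, contributions summed.
# In a Turing model the result would be added via Turing.@addlogprob!; whether
# the AD backends tolerate @spawn inside model evaluation is the open question.
function stacked_logdensity(θ, datasets)
    submodels = (submodel_gaussian, submodel_laplace)
    tasks = [Threads.@spawn f(θ, y) for (f, y) in zip(submodels, datasets)]
    return sum(fetch.(tasks))
end
```

Because the submodels share no intermediate results, the task-level parallelism here is safe regardless of how each submodel is implemented internally, which is exactly why the "no output used by other parts of the model" restriction matters.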