CDCgov / Rt-without-renewal

https://cdcgov.github.io/Rt-without-renewal/
Apache License 2.0

Explore running StackObservationModels in parallel #254

Open seabbs opened 1 month ago

seabbs commented 1 month ago

It is possible that Turing.jl already supports running submodels in parallel, at least with some backends. If it does not yet, this functionality may be added if there is user demand.

StackObservationModels is a clear target for us, as it is embarrassingly parallel and anything being stacked is likely to need enough compute to make this worthwhile. We are also likely to use this pattern again (e.g. for multiple renewal processes), so solving it here would solve it more widely in the tooling.
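To make the parallelism point concrete, here is a minimal plain-Julia sketch that uses Base threads only, with no Turing. It assumes each stacked observation model can be represented as a function of the shared latent process that returns its own log-density term. The names normal_loglik and stacked_loglik, and the data, are hypothetical illustrations and not the package's StackObservationModels API. No term depends on another, so each can be computed in its own task and the results summed.

```julia
# Log-density of y under independent Normal(mu[i], sigma) observations.
function normal_loglik(y::AbstractVector, mu::AbstractVector, sigma::Real)
    n = length(y)
    return -n * log(sigma) - n * log(2π) / 2 - sum(abs2, (y .- mu) ./ sigma) / 2
end

# Hypothetical stand-in for stacking observation models: each model is a
# function of the shared latent process that returns its own log-density term.
# No model reads another model's output, so each one runs in its own task.
function stacked_loglik(models, latent)
    tasks = map(model -> Threads.@spawn(model(latent)), models)
    return sum(fetch, tasks)  # wait for every task and add up the terms
end

# Two made-up data streams that share one latent process (illustration only).
latent = [1.0, 2.0, 3.0]
y_cases = [1.2, 1.9, 3.1]
y_deaths = [0.9, 2.2, 2.8]
models = [
    x -> normal_loglik(y_cases, x, 0.5),
    x -> normal_loglik(y_deaths, x, 1.0),
]

parallel_total = stacked_loglik(models, latent)
serial_total = sum(model(latent) for model in models)
```

This covers only the log-density arithmetic. Whether Turing's ~ and @submodel machinery can be dispatched this way is the open question in this issue.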

See https://discourse.julialang.org/t/within-chain-parallelization-with-turing-jl/103402/4 for some discussion of current within-chain support (TL;DR: it is limited to statements where the left-hand side is fixed, i.e. observations).

SamuelBrand1 commented 1 month ago

For within-chain parallelisation, my impression is that, currently, this is left to the user to handle with whatever tools the Julia language provides.

Are there inspiring examples from other PPLs? At the moment, the best form of within-chain parallelism to use seems quite case-by-case, so I can see why Turing does not treat it as something to offer.

One consideration here is that it is easier to parallelise the accumulation of the log-posterior density than to parallelise the full ~ statement in Turing. In my opinion, this is an argument for having an "inference mode" in Turing, as discussed in https://github.com/TuringLang/DynamicPPL.jl/issues/510

In the wider Julia ecosystem, FLoops.jl has nice functionality here, and, at least with forward-mode AD, you can get it to run with Turing, as per this Pluto notebook. Note that to get this to work I had to use @addlogprob!.

seabbs commented 1 month ago

> For within-chain parallelisation, my impression is that, currently, this is left to the user to handle with whatever tools the Julia language provides.

So yes, but the only live example is restricted to having observations on the left-hand side. This would have very limited utility for us, especially if it only supports ForwardDiff. As the original comment points out, we have some very clear use cases where multithreading would be useful, namely dispatching on submodels in a for loop. This would closely mirror how reduce_sum is used in Stan if it were limited to submodels that don't produce output used by other parts of the model (StackObservationModels would be one such case).

> At the moment, the best form of within-chain parallelism to use seems quite case-by-case, so I can see why Turing does not treat it as something to offer.

I'm struggling to see support for this point, given there are many obvious cases where multithreading is what makes it possible to run very large models in any reasonable time. Given Turing's current relatively poor performance and the relative ease of supporting the Julia parallel ecosystem, this seems a higher priority here than in other PPLs (to offset that slowness).

I see your point about an inference mode, but I think it is slightly off topic.

Examples

reduce_sum from Stan (https://mc-stan.org/docs/stan-users-guide/parallelization.html#reduce-sum) does a lot of the heavy lifting for you.

NumPyro has approaches for this via JAX, some of which are automated and some of which are not. No example pinned down, though. It has an explicit plate context, which seems like reduce_sum in Stan, but I have seen nothing that suggests it does in fact run in parallel.

SamuelBrand1 commented 1 month ago

Isn't the reduce_sum approach the one that's easy to implement directly? In what you linked, Stan offers the user the chance to rewrite their code so that sums of conditionally independent log-posterior-density terms can be parallelised, at the cost of losing the syntactic sugar of the Stan language. I don't see what's different from using FLoops here?

SamuelBrand1 commented 1 month ago

e.g. they convert:

data {
  int N;
  array[N] int y;
  vector[N] x;
}
parameters {
  vector[2] beta;
}
model {
  beta ~ std_normal();
  y ~ bernoulli_logit(beta[1] + beta[2] * x);
}

into

functions {
  real partial_sum(array[] int y_slice,
                   int start, int end,
                   vector x,
                   vector beta) {
    return bernoulli_logit_lpmf(y_slice | beta[1] + beta[2] * x[start:end]);
  }
}
data {
  int N;
  array[N] int y;
  vector[N] x;
}
parameters {
  vector[2] beta;
}
model {
  int grainsize = 1;
  beta ~ std_normal();
  target += reduce_sum(partial_sum, y,
                       grainsize,
                       x, beta);
}
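For comparison, here is a rough plain-Julia analogue of the Stan rewrite above, using only Base threading (no Turing or FLoops). The names partial_sum and reduce_sum_threads and the chunking rule are assumptions for illustration, not an existing Julia API. Stan's scheduler picks chunk sizes differently. Like the Stan version, this only parallelises the accumulation of conditionally independent log-likelihood terms. It returns a number that would still need to be added to the target, e.g. with @addlogprob! in Turing.

```julia
# Numerically stable log(1 + exp(η)).
log1pexp_stable(η) = η > 0 ? η + log1p(exp(-η)) : log1p(exp(η))

# Bernoulli log-pmf on the logit scale: y*η - log(1 + exp(η)) for y ∈ {0, 1}.
bernoulli_logit_lpmf(y::Integer, η::Real) = y * η - log1pexp_stable(η)

# Counterpart of Stan's partial_sum: log-likelihood of the slice of y covering
# indices start:stop (`end` is a keyword in Julia, hence `stop`).
function partial_sum(y_slice, start, stop, x, beta)
    return sum(bernoulli_logit_lpmf(yi, beta[1] + beta[2] * xi)
               for (yi, xi) in zip(y_slice, view(x, start:stop)))
end

# Rough counterpart of Stan's reduce_sum (hypothetical helper): split 1:N into
# contiguous chunks of at least `grainsize` elements, evaluate each chunk in
# its own task, then add the partial sums.
function reduce_sum_threads(f, y, grainsize, args...)
    N = length(y)
    N == 0 && return 0.0
    chunk = max(grainsize, cld(N, Threads.nthreads()))
    tasks = map(Iterators.partition(1:N, chunk)) do r
        Threads.@spawn f(view(y, r), first(r), last(r), args...)
    end
    return sum(fetch, tasks)
end

# Made-up inputs, for illustration only.
x = [-1.0, 0.0, 0.5, 2.0, 3.0]
y = [0, 0, 1, 1, 1]
beta = [0.2, 1.1]
parallel_lp = reduce_sum_threads(partial_sum, y, 1, x, beta)
serial_lp = partial_sum(y, 1, length(y), x, beta)
```

With a single Julia thread this falls back to one chunk, so the parallel and serial results coincide. Start Julia with more threads (e.g. julia -t 4) to actually split the work.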
SamuelBrand1 commented 1 month ago

Whereas in the plain Turing → FLoops example (albeit for an MvNormal likelihood), the model

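using Turing, LinearAlgebra

@model function test_model(n::Integer)
    μ ~ Normal(1.0, 1.0)
    σ ~ truncated(Normal(0., 1.), 0, Inf)
    # Isotropic covariance σ²I (the vector-of-std-devs MvNormal constructor is deprecated)
    y ~ MvNormal(fill(μ, n), σ^2 * I)
end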

goes to

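using Turing, FLoops

@model function test_model_floops(y, n::Integer)
    μ ~ Normal(1.0, 1.0)
    σ ~ truncated(Normal(0., 1.), 0, Inf)

    # One log-likelihood slot per observation; each iteration writes only its own index
    lls = Vector{eltype(μ)}(undef, n)
    # Bind the parameters to a local vector before the threaded loop
    let v = [μ, σ]
        # eachindex(y), not axes(y): axes returns a tuple of ranges, so i would be a whole range
        @floop for i in eachindex(y)
            lls[i] = logpdf(Normal(v[1], v[2]), y[i])
        end
    end
    # y is passed in as data and is not on the LHS of ~, so its log-likelihood is added by hand
    Turing.@addlogprob! sum(lls)
end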

Which seems to unlock the same kind of performance improvement, with the same kind of downsides, as the Stan example?

SamuelBrand1 commented 1 month ago

> NumPyro has approaches for this via JAX, some of which are automated and some of which are not. No example pinned down, though.

@dylanhmorris, do you have any cool examples here?

seabbs commented 1 month ago

> Which seems to unlock the same kind of performance improvement, with the same kind of downsides, as the Stan example?

This example is extremely limited compared with what you can do in Stan (e.g. https://github.com/epinowcast/epinowcast/blob/886b45cb4bc5f338fa53d22e83d25335c55b1a4a/inst/stan/epinowcast.stan#L400, which runs all of the complicated observation model in parallel, i.e. effectively dispatching over submodels).

I don't think our current approach works naturally with rephrasing everything as obs ~ lots of stuff, hence the suggestion that we target being able to do this over submodels.

seabbs commented 1 month ago

@SamuelBrand1 and I had a face-to-face meeting. We concluded that, in the first instance, we aim to check whether @submodel can be dispatched in parallel.