stan-dev / math

The Stan Math Library is a C++ template library for automatic differentiation of any order using forward, reverse, and mixed modes. It includes a range of built-in functions for probabilistic modeling, linear algebra, and equation solving.
https://mc-stan.org
BSD 3-Clause "New" or "Revised" License

Implement closures #2197

Open bbbales2 opened 3 years ago

bbbales2 commented 3 years ago

Description

This is a placeholder issue for implementing closures as described in https://github.com/stan-dev/design-docs/blob/master/designs/0004-closures-fun-types.md

I say placeholder because that's the main specification, and there will also be associated pull requests in stanc3.

Current Version:

v3.3.0

nhuurre commented 3 years ago

The stanc3 pull request is open: stan-dev/stanc3#742. It probably has lots of bugs, but it's good enough that I can compile the Lotka-Volterra model:

data {
  int<lower = 0> N;          // number of measurement times
  real ts[N];                // measurement times > 0
  vector[2] y_init;            // initial measured populations
  real<lower = 0> y[N, 2];   // measured populations
}
parameters {
  real<lower = 0> alpha;
  real<lower = 0> beta;
  real<lower = 0> gamma;
  real<lower = 0> delta;
  vector<lower = 0>[2] z_init;  // initial population
  real<lower = 0> sigma[2];   // measurement errors
}
model {
  functions
  vector dz_dt(real t,      // time
               vector z     // system state {prey, predator}
               ) {
    real u = z[1];
    real v = z[2];

    real du_dt = (alpha - beta * v) * u;
    real dv_dt = (-gamma + delta * u) * v;

    return [ du_dt, dv_dt ]';
  }
  vector[2] z[N] = ode_bdf(dz_dt, z_init, 0., ts);

  alpha ~ normal(1, 0.5);
  gamma ~ normal(1, 0.5);
  beta ~ normal(0.05, 0.05);
  delta ~ normal(0.05, 0.05);
  sigma ~ lognormal(-1, 1);
  z_init ~ lognormal(log(10), 1);
  for (k in 1:2) {
    y_init[k] ~ lognormal(log(z_init[k]), sigma[k]);
    y[ , k] ~ lognormal(log(z[, k]), sigma[k]);
  }
}
bbbales2 commented 3 years ago

I converted the SIR model over here to practice with this (@rok-cesnovar, thanks for the binaries).

It is an absolute delight to not worry about passing in parameters or data.
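To make the point about not passing parameters or data concrete, here is a toy C++ sketch contrasting the two styles. It is not Stan Math code: euler_solve is a hypothetical forward-Euler stand-in for ode_rk45, and sir_rhs, solve_with_args and solve_with_closure are illustrative names. In the first style every parameter is threaded through the solver as an extra argument; in the second, a capturing lambda picks up beta, gamma and N from the enclosing scope, so the solver only ever sees f(t, x).

```cpp
#include <array>
#include <cassert>

using State = std::array<double, 3>;  // {S, I, R}

// Hypothetical forward-Euler integrator standing in for an ODE solver (not
// Stan Math). Trailing arguments are forwarded to the right-hand side f,
// mimicking the "pass every parameter explicitly" interface.
template <typename F, typename... Args>
State euler_solve(const F& f, State y, double t0, double h, int n_steps,
                  const Args&... args) {
  assert(h > 0.0 && n_steps >= 0);
  double t = t0;
  for (int i = 0; i < n_steps; ++i) {
    State dy = f(t, y, args...);
    for (int k = 0; k < 3; ++k) y[k] += h * dy[k];
    t += h;
  }
  return y;
}

// Argument-passing style: beta, gamma and N appear in the signature and must
// be passed again at every call site.
State sir_rhs(double /*t*/, const State& x, double beta, double gamma, int N) {
  double S = x[0], I = x[1];
  return {-beta * I * S / N, beta * I * S / N - gamma * I, gamma * I};
}

State solve_with_args(double beta, double gamma, int N, const State& y0) {
  return euler_solve(sir_rhs, y0, 0.0, 0.01, 1000, beta, gamma, N);
}

// Closure style: the lambda captures beta, gamma and N by reference, so the
// solver call no longer mentions them.
State solve_with_closure(double beta, double gamma, int N, const State& y0) {
  auto sir = [&](double /*t*/, const State& x) -> State {
    double S = x[0], I = x[1];
    return {-beta * I * S / N, beta * I * S / N - gamma * I, gamma * I};
  };
  return euler_solve(sir, y0, 0.0, 0.01, 1000);
}
```

With the same illustrative inputs (say beta = 2, gamma = 0.4, N = 763, y0 = {762, 1, 0}) both functions should produce the same final state up to rounding; the closure version simply has a shorter solver call.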

A few comments:

  1. I get an error when I define a function in the transformed parameters block (example model included below, data at the bottom):
sir_negbin_lambda.hpp:150:7: error: redefinition of 'sir_L23C9_cfunctor__'
class sir_L23C9_cfunctor__ {
      ^
sir_negbin_lambda.hpp:105:7: note: previous definition is here
class sir_L23C9_cfunctor__ {
      ^
sir_negbin_lambda.hpp:244:1: error: redefinition of 'sir_L23C9_impl__'
sir_L23C9_impl__(const int& N, const T1__& beta, const T2__& gamma,
^
sir_negbin_lambda.hpp:190:1: note: previous definition is here
sir_L23C9_impl__(const int& N, const T1__& beta, const T2__& gamma,
  2. If there is a parameter y defined in the transformed parameters block, I cannot have an argument to my function named y. I guess allowing this would require adding shadowing to the language, which is out of scope here?

  3. I'm not sold on the naming. What about:

    function dz_dt = (real t, vector z) -> vector { ... }

    I just copied the C++ lambda trailing-return-type syntax. I wouldn't mind automatically deduced types there. Would this then make it possible to pass functions as arguments without first defining them as variables, i.e., ode_bdf((real t, vector z) -> vector { ... }, ...)?
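The errors in item 1 mention a generated class sir_L23C9_cfunctor__ and a function sir_L23C9_impl__. As a rough, hand-written C++ sketch of what such a pair might look like (the names sir_closure and sir_impl are hypothetical; this is not stanc3's actual output), the class stores the captured environment and exposes only the user-visible (t, x) signature. It also shows why the error is fatal: C++ allows only one definition of a given class per translation unit, so emitting the same class twice cannot compile.

```cpp
#include <cassert>
#include <vector>

// Hypothetical body of the closure: captured values arrive as ordinary
// arguments after the user-visible ones.
std::vector<double> sir_impl(double /*t*/, const std::vector<double>& x, int N,
                             double beta, double gamma) {
  double S = x[0], I = x[1];
  return {-beta * I * S / N, beta * I * S / N - gamma * I, gamma * I};
}

// Hypothetical closure object: holds the captured environment and forwards
// to sir_impl. Defining this class a second time in the same translation
// unit would produce a "redefinition of class" error.
class sir_closure {
 public:
  sir_closure(int N, double beta, double gamma)
      : N_(N), beta_(beta), gamma_(gamma) {}

  std::vector<double> operator()(double t, const std::vector<double>& x) const {
    assert(x.size() == 3);  // state is {S, I, R}
    return sir_impl(t, x, N_, beta_, gamma_);
  }

 private:
  int N_;
  double beta_;
  double gamma_;
};
```

A solver would receive a sir_closure object and call it as f(t, x), never seeing N, beta or gamma directly.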

Here is the SIR model:

data {
  int<lower=1> n_days;
  vector[3] y0;
  real t0;
  real ts[n_days];
  int N;
  int cases[n_days];
}
transformed data {
  real x_r[0];
  int x_i[1] = { N };
}
parameters {
  real<lower=0> gamma;
  real<lower=0> beta;
  real<lower=0> phi_inv;
}
transformed parameters{
  vector[3] y[n_days];
  real phi = 1. / phi_inv;

  functions
  vector sir(real t, vector x) {
    real S = x[1];
    real I = x[2];
    real R = x[3];

    real dS_dt = -beta * I * S / N;
    real dI_dt =  beta * I * S / N - gamma * I;
    real dR_dt =  gamma * I;

    return [ dS_dt, dI_dt, dR_dt ]';
  }

  y = ode_rk45(sir, y0, t0, ts);
}
model {
  // priors
  beta ~ normal(2, 1);
  gamma ~ normal(0.4, 0.5);
  phi_inv ~ exponential(5);

  // sampling distribution
  cases ~ neg_binomial_2(y[, 2], phi);
}

generated quantities {
  real R0 = beta / gamma;
  real recovery_time = 1 / gamma;
  real pred_cases[n_days];
  pred_cases = neg_binomial_2_rng(y[, 2], phi);
}

Here is the Lotka-Volterra model from above, but with the RK45 solver:

data {
  int<lower = 0> N;          // number of measurement times
  real ts[N];                // measurement times > 0
  vector[2] y_init;          // initial measured populations
  real<lower = 0> y[N, 2];   // measured populations
}
parameters {
  real<lower = 0> alpha;
  real<lower = 0> beta;
  real<lower = 0> gamma;
  real<lower = 0> delta;
  vector<lower = 0>[2] z_init;  // initial population
  real<lower = 0> sigma[2];     // measurement errors
}
model {
  functions
  vector dz_dt(real t,      // time
               vector z     // system state {prey, predator}
               ) {
    real u = z[1];
    real v = z[2];

    real du_dt = (alpha - beta * v) * u;
    real dv_dt = (-gamma + delta * u) * v;

    return [ du_dt, dv_dt ]';
  }

  vector[2] z[N] = ode_rk45_tol(dz_dt, z_init, 0., ts,
                                1e-5, 1e-3, 500);

  alpha ~ normal(1, 0.5);
  gamma ~ normal(1, 0.5);
  beta ~ normal(0.05, 0.05);
  delta ~ normal(0.05, 0.05);
  sigma ~ lognormal(-1, 1);
  z_init ~ lognormal(log(10), 1);
  for (k in 1:2) {
    y_init[k] ~ lognormal(log(z_init[k]), sigma[k]);
    y[ , k] ~ lognormal(log(z[, k]), sigma[k]);
  }
}

Lotka-Volterra data: lv.dat.txt. SIR data: sir.dat.txt (I used the code at https://mc-stan.org/users/documentation/case-studies/boarding_school_case_study.html to generate the data file).

bbbales2 commented 3 years ago

@nhuurre does this accurately reflect the todo on this:

I don't know how we could get this to work with map_rect because I do not know how it works. @wds15 do you have the time or inclination to inspect this?

We're trying to plan out 2.26 (https://github.com/stan-dev/cmdstan/issues/957) and this is a feature that's pretty far into development.

nhuurre commented 3 years ago

Stan Math has an integrate_dae function, but it's not exposed in the language, so I guess it doesn't need closure support.

I think the only tricky part with map_rect is supporting MPI.

I deviated quite a bit from the design-doc. Main points:

bbbales2 commented 3 years ago

I deviated quite a bit from the design-doc

Cool. The reasons seem good to me. Are the known bugs fixed in the compiler now? I vaguely remember getting an ODE and a reduce_sum model running, but I might be misremembering the second one :D.

If it is, I wanna build a cmdstan tarball with the mac binaries so that @bob-carpenter can test. He said he'd be down to check this against the design doc (and presumably update the design doc or suggest changes).

nhuurre commented 3 years ago

Are the known bugs fixed in the compiler now?

I did push a fix for the redcard model.

Now, looking into special suffixes, I see that this (invalid) model causes a C++ compiler error.

functions {
  real foo(real(real) f) {
    return f(0) + f(1);
  }
  real bar_rng(real s) {
    return normal_rng(s,1);
  }
}
transformed data {
  real z = foo(bar_rng);
}

But models that should compile still do compile correctly as far as I know.
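The model above should ideally be rejected by stanc3's type checker rather than reaching the C++ compiler: foo is not an _rng function, yet passing it bar_rng would let it draw random numbers indirectly. A minimal sketch of such a suffix check follows. It is written in C++ only for consistency with the rest of this thread (stanc3 itself is OCaml); has_suffix and valid_function_argument are hypothetical, and treating _lp the same way as _rng is an assumption.

```cpp
#include <cassert>
#include <string>

// Hypothetical helper: does a Stan function name end in the given suffix?
bool has_suffix(const std::string& name, const std::string& suffix) {
  return name.size() >= suffix.size() &&
         name.compare(name.size() - suffix.size(), suffix.size(), suffix) == 0;
}

// Hypothetical rule: a function argument carrying a special suffix may only be
// passed to a higher-order function with the same suffix; otherwise the callee
// could use an RNG (or the log density) without the type checker noticing.
bool valid_function_argument(const std::string& callee, const std::string& arg) {
  assert(!callee.empty() && !arg.empty());
  const std::string suffixes[] = {"_rng", "_lp"};
  for (const std::string& suffix : suffixes) {
    if (has_suffix(arg, suffix) && !has_suffix(callee, suffix)) return false;
  }
  return true;
}
```

Under this rule foo(bar_rng) is an error, while renaming foo to foo_rng would make the call legal.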

bbbales2 commented 3 years ago

Cool I'll build a tarball then!

bbbales2 commented 3 years ago

Yo @bob-carpenter this contains a cmdstan & stanc3 that can compile closures: cmdstan-closures.zip

The goal here is to check the implementation with the design doc, and update stuff as needed.

To keep the zip file under 10 megabytes so GitHub would let me upload it, I removed the stan and math repos. Once you unzip it, you should be able to go into the folder and run bash checkout.sh, which will check out the stan and math repos you need to make this run.

Implementation of that script is just this:

git clone --depth=1 --single-branch --branch=develop https://github.com/stan-dev/stan stan
git clone --depth=1 --single-branch --branch=feature/ode-closures https://github.com/nhuurre/math.git stan/lib/stan_math

There's a Lotka-Volterra model in this comment that you can use to test that your build is working.

bbbales2 commented 3 years ago

@nhuurre I'm not sure what to do about map_rect still. I put up a quick pull for integrate_1d (https://github.com/stan-dev/math/pull/2397) and it probably won't be too hard to do a similar one for the algebra solvers.

I guess the worst thing that could happen is that we make closures that aren't MPI-compatible, or are just really bad in some way. A less bad outcome would be closures that are MPI-friendly but not implemented for map_rect. Best would be to implement map_rect support.

@wds15 I don't want to distract you from adjoint ODEs, so the easiest thing I can think of is a Hangouts call next week or something (@nhuurre you're invited too ofc. if you wanna come). I can walk you through closures and then we can talk about what it'd take to make them map_rect-friendly?

wds15 commented 3 years ago

I need to catch up on this. It looks super cool. Can we drop MPI map_rect support? Is that an option?

EDIT: So I am scheduled for a hangout next week?

bbbales2 commented 3 years ago

So I am scheduled for a hangout next week?

Pre-Stan meeting work? Or I'm free any day at the Stan time or the hour before.

Can we drop MPI map_rect support?

I wouldn't mind doing that. I wanna just give it one try before going that direction, but I also don't wanna keep delaying closures. It would be pretty easy to just throw runtime errors in Math, and presumably stanc3 could have an error on compile ("map_rect does not support closures" or something), so it wouldn't be the worst outcome in the world. But it also might not be that bad to do map_rect, I just don't know how it works.
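A minimal sketch of the "throw an error in Math" fallback, assuming closures could be recognized by their type. closure_tag, is_closure_v and check_map_rect_functor are hypothetical; in a real build the mpi_enabled flag would more likely be the STAN_MPI preprocessor macro, and the check could then become a compile-time static_assert instead of a runtime throw.

```cpp
#include <cassert>
#include <stdexcept>
#include <type_traits>

// Hypothetical marker base class that compiler-generated closure types would
// derive from, so library code can detect them.
struct closure_tag {};

template <typename F>
constexpr bool is_closure_v = std::is_base_of<closure_tag, F>::value;

// Hypothetical guard a map_rect-like function could call before dispatching
// work to MPI workers.
template <typename F>
void check_map_rect_functor(const F& /*f*/, bool mpi_enabled) {
  if (mpi_enabled && is_closure_v<F>) {
    throw std::domain_error("map_rect: closures are not supported with MPI");
  }
}
```

Plain functors pass the check in every configuration; closures pass only when MPI is off, which matches the "roll out closures without MPI map_rect" option.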

wds15 commented 3 years ago

Today or tomorrow at 10 am your time / 16h my time would work. Maybe easiest to email me? CU

rok-cesnovar commented 3 years ago

I would vote against removing MPI map_rect support.

If making closures work with MPI here would be a ton of work, we could make stanc3 aware that it's compiling for MPI, like we do for OpenCL (when STAN_OPENCL is set, stanc3 is called with --use-opencl), and then report a semantic error if closures are used with map_rect while MPI is on.

wds15 commented 3 years ago

I am not saying to cancel MPI map_rect in its current form. It's just that we may not want closures to be delayed significantly due to MPI stuff. Thus, we could roll out closures without MPI support (meaning that map_rect MPI will not work).

(map_rect with MPI would be quite nice, of course.)

rok-cesnovar commented 3 years ago

It's just that we may not want closures to be delayed significantly due to MPI stuff. Thus, we could roll out closures without MPI support (meaning that map_rect MPI will not work).

Oh, I am definitely on board with that.

bbbales2 commented 3 years ago

Talked to @wds15 this morning about map_rect. It looks difficult. I think we would need to be able to serialize the functors somehow, and also handle the optimization where data is only shipped to the worker nodes once (neither of which is straightforward).

Let's do integrate_1d first (#2397), and then algebra_solver, and then come back to make a try at map_rect.
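To illustrate the serialization problem: a plain functor has no state, so a worker can reconstruct it from its type alone, but a closure carries captured values that would have to be flattened, shipped to the worker, and rebuilt there. Below is a toy C++ sketch with a hypothetical lv_closure type. It ignores autodiff entirely; real captured parameters would be var objects whose gradients would also need to travel back, and captured data would ideally be sent only once.

```cpp
#include <cassert>
#include <vector>

// Hypothetical closure capturing two reals (the prey equation of the
// Lotka-Volterra model).
struct lv_closure {
  double alpha;
  double beta;

  // du/dt for prey u and predator v.
  double operator()(double u, double v) const { return (alpha - beta * v) * u; }

  // Flatten the captured values into a buffer that could be sent to a worker.
  std::vector<double> serialize() const { return {alpha, beta}; }

  // Rebuild an equivalent closure from a received buffer.
  static lv_closure deserialize(const std::vector<double>& buf) {
    assert(buf.size() == 2);
    return lv_closure{buf[0], buf[1]};
  }
};
```

The round trip only works because this sketch knows the exact layout of the captures; a general scheme would need every generated closure type to provide something like serialize and deserialize.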

wds15 commented 3 years ago

@nhuurre Have you already done a bit of benchmarking? This feature screams for application with ODEs... but these are super performance-critical, and I would be scared to put any friction into the turning wheels of the ODE integrators.

Other than that this looks amazing (and there are other domains than ODEs where this is highly useful)!
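One way to start on the overhead question is a micro-benchmark that evaluates the same right-hand side once with parameters forwarded as arguments and once through a capturing lambda. A minimal sketch with hypothetical names (run_rhs, time_rhs, lv_prey_rhs); timings from a loop like this say little about real ODE solves, since the optimizer may inline both versions, so benchmarking actual models through CmdStan would be the more meaningful test.

```cpp
#include <cassert>
#include <chrono>

// Evaluate a right-hand side n times. Extra arguments are forwarded on every
// call (argument-passing style); pass none when f is a closure.
template <typename F, typename... Args>
double run_rhs(const F& f, int n, const Args&... args) {
  assert(n >= 0);
  double acc = 0.0;
  for (int i = 0; i < n; ++i) acc += f(1.0 + 1e-6 * i, 2.0, args...);
  return acc;
}

// Return elapsed wall-clock seconds for run_rhs, storing its result in out so
// the loop's value is used and cannot be discarded as dead code.
template <typename F, typename... Args>
double time_rhs(double& out, const F& f, int n, const Args&... args) {
  auto start = std::chrono::steady_clock::now();
  out = run_rhs(f, n, args...);
  std::chrono::duration<double> elapsed = std::chrono::steady_clock::now() - start;
  return elapsed.count();
}

// Prey equation of the Lotka-Volterra model with explicit parameters.
double lv_prey_rhs(double u, double v, double alpha, double beta) {
  return (alpha - beta * v) * u;
}
```

For example, time_rhs(r1, lv_prey_rhs, 100000, alpha, beta) and time_rhs(r2, closure, 100000), with closure = [&](double u, double v) { return lv_prey_rhs(u, v, alpha, beta); }, give two timings and two results that should agree.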

SteveBronder commented 3 years ago

Not a biggie, but just wondering: is it possible in the spec to have anonymous lambdas like this?

  vector[2] z[N] = ode_rk45_tol(functions vector (real t, vector z) {
    real u = z[1];
    real v = z[2];

    real du_dt = (alpha - beta * v) * u;
    real dv_dt = (-gamma + delta * u) * v;

    return [ du_dt, dv_dt ]';
  }, z_init, 0., ts, 1e-5, 1e-3, 500);

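If the compiler handled type checking of the inputs and the return type, this could translate to roughly the following C++:

  std::vector<Eigen::Matrix<scalar_t, -1, 1>> z =
    ode_rk45_tol([&](auto&& t, auto&& z, std::ostream* msgs) {
      // compiled body same as current impl; alpha, beta, gamma and delta
      // are captured by reference rather than passed as extra arguments
  }, z_init, 0., ts, 1e-5, 1e-3, 500, pstream__);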
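For comparison, here is a self-contained C++ sketch of the same idea: an inline, capturing lambda passed straight to a solver. toy_ode and lotka_volterra are hypothetical, and toy_ode is a fixed-step Euler stand-in for ode_rk45_tol that mimics Stan Math's variadic ODE convention of calling the right-hand side as f(t, y, msgs). The [&] capture is what lets the body use alpha, beta, gamma and delta without them appearing in the signature, and the trailing -> Vec return type is needed because the body returns a braced list.

```cpp
#include <cassert>
#include <cstddef>
#include <iostream>
#include <ostream>
#include <vector>

using Vec = std::vector<double>;

// Hypothetical fixed-step Euler solver (not Stan Math). Returns the state at
// each output time in ts; the RHS is called as f(t, y, msgs).
template <typename F>
std::vector<Vec> toy_ode(const F& f, Vec y, double t0,
                         const std::vector<double>& ts, int steps_per_interval,
                         std::ostream* msgs) {
  std::vector<Vec> out;
  double t = t0;
  for (double t_next : ts) {
    assert(t_next > t && steps_per_interval > 0);
    double h = (t_next - t) / steps_per_interval;
    for (int s = 0; s < steps_per_interval; ++s) {
      Vec dy = f(t, y, msgs);
      for (std::size_t k = 0; k < y.size(); ++k) y[k] += h * dy[k];
      t += h;
    }
    t = t_next;  // snap to the output time to avoid accumulated drift
    out.push_back(y);
  }
  return out;
}

// The anonymous, capturing lambda is written inline at the call site.
std::vector<Vec> lotka_volterra(double alpha, double beta, double gamma,
                                double delta, const Vec& z_init,
                                const std::vector<double>& ts) {
  return toy_ode(
      [&](double /*t*/, const Vec& z, std::ostream* /*msgs*/) -> Vec {
        double u = z[0], v = z[1];
        return {(alpha - beta * v) * u, (-gamma + delta * u) * v};
      },
      z_init, 0.0, ts, 100, &std::cout);
}
```

Calling lotka_volterra(1.0, 0.05, 1.0, 0.05, {10.0, 10.0}, {1.0, 2.0}) returns the prey and predator populations at t = 1 and t = 2.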