stan-dev / stanc3

The Stan transpiler (from Stan to C++ and beyond).
BSD 3-Clause "New" or "Revised" License

Add new ODE interface #481

Closed bbbales2 closed 4 years ago

bbbales2 commented 4 years ago

Could we get a prototype for the new ODE stuff similar to what we have for reduce_sum (https://github.com/stan-dev/stanc3/pull/451)? No need for super error checking and whatnot at this point. Just wanna start playing with ODEs in the new interface.

@rok-cesnovar, @rybern, whoever feels like putting it together.

The mission here is similar. Instead of the ODE signature:

real[ , ] integrate_ode_bdf(function ode_rhs, real[] initial_state,
                            real initial_time, real[] times,
                            real[] theta, real[] x_r, int[] x_i,
                            data real rel_tol, data real abs_tol, data int max_num_steps)

with the ODE right-hand-side function:

real[] my_ode_func(real, real[], real[], data real[], data int[])

we want (these names are all subject to change; I just wanted something short for testing):

real[ , ] ode_bdf(function ode_rhs, real[] initial_state,
                  real initial_time, real[] times,
                  data real rel_tol, data real abs_tol, data int max_num_steps,
                  T1 arg1, T2 arg2, ...)

where the ODE right-hand-side function looks like:

real[] my_ode_func(real time, real[] state, T1 arg1, T2 arg2, ...)

The rel_tol/abs_tol/max_num_steps args are no longer optional with the new form.

The C++ signature of the new ODE integration function is:

template <typename F, typename T_initial, typename T_t0, typename T_ts, typename... T_Args>
std::vector<std::vector<typename stan::return_type<T_initial, T_t0, T_ts, T_Args...>::type>>
ode_bdf(const F& f, const std::vector<T_initial>& y0, const T_t0& t0,
        const std::vector<T_ts>& ts,
        double relative_tolerance,
        double absolute_tolerance,
        long int max_num_steps,
        std::ostream* msgs,
        const T_Args&... args)

An example C++ functor that will work with the new integrator signature is:

struct CosArg1 {
  template <typename T0, typename T1, typename... T_Args>
  inline std::vector<typename stan::return_type<T1, T_Args...>::type>
  operator()(const T0& t, const std::vector<T1>& y,
             const T_Args&... a, std::ostream* msgs) const {
    return ...;
  }
};
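
To make the argument forwarding concrete, here is a minimal, self-contained sketch of how trailing variadic arguments can travel from the integrator call into the right-hand-side functor. It is not the Stan Math implementation: `ode_euler_sketch`, `LotkaVolterraRhs`, and `substeps` are hypothetical names, the solver is a fixed-step explicit Euler stand-in rather than BDF, and everything is plain `double` with no autodiff types. One further assumption: the functor takes `msgs` before the extra arguments, which keeps the call site `f(t, y, msgs, args...)` simple and would keep a parameter pack in trailing position if the functor were templated.

```cpp
#include <cassert>
#include <cstddef>
#include <ostream>
#include <vector>

// Hypothetical right-hand side: Lotka-Volterra with the four rate
// parameters passed as separate scalars instead of a packed theta array.
// Assumption: msgs comes before the extra arguments.
struct LotkaVolterraRhs {
  std::vector<double> operator()(double /*t*/, const std::vector<double>& z,
                                 std::ostream* /*msgs*/, double alpha,
                                 double beta, double gamma,
                                 double delta) const {
    const double u = z[0];  // prey
    const double v = z[1];  // predator
    return {(alpha - beta * v) * u, (-gamma + delta * u) * v};
  }
};

// Toy integrator (explicit Euler, NOT BDF). Everything after msgs is
// forwarded unchanged to the functor on every evaluation.
template <typename F, typename... T_Args>
std::vector<std::vector<double>> ode_euler_sketch(
    const F& f, std::vector<double> y, double t0,
    const std::vector<double>& ts, int substeps, std::ostream* msgs,
    const T_Args&... args) {
  assert(substeps > 0);
  std::vector<std::vector<double>> out;
  double t = t0;
  for (double t_end : ts) {
    // Split each output interval into a fixed number of Euler steps.
    const double h = (t_end - t) / substeps;
    for (int s = 0; s < substeps; ++s) {
      const std::vector<double> dy = f(t + s * h, y, msgs, args...);
      for (std::size_t i = 0; i < y.size(); ++i) {
        y[i] += h * dy[i];
      }
    }
    t = t_end;
    out.push_back(y);  // state at each requested output time
  }
  return out;
}
```

A call such as `ode_euler_sketch(LotkaVolterraRhs{}, {30.0, 4.0}, 0.0, ts, 100, nullptr, alpha, beta, gamma, delta)` then mirrors the shape of the proposed Stan call: solver controls first, user arguments last.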

Here is an old Stan model (taken from Bob's Lotka-Volterra case study -- https://mc-stan.org/users/documentation/case-studies/lotka-volterra-predator-prey.html):

functions {
  real[] dz_dt(real t,       // time
               real[] z,     // system state {prey, predator}
               real[] theta, // parameters
               real[] x_r,   // unused data
               int[] x_i) {
    real u = z[1];
    real v = z[2];

    real alpha = theta[1];
    real beta = theta[2];
    real gamma = theta[3];
    real delta = theta[4];

    real du_dt = (alpha - beta * v) * u;
    real dv_dt = (-gamma + delta * u) * v;
    return { du_dt, dv_dt };
  }
}
data {
  int<lower = 0> N;           // number of measurement times
  real ts[N];                 // measurement times > 0
  real y_init[2];             // initial measured populations
  real<lower = 0> y[N, 2];    // measured populations
}
parameters {
  real<lower = 0> theta[4];   // { alpha, beta, gamma, delta }
  real<lower = 0> z_init[2];  // initial population
  real<lower = 0> sigma[2];   // measurement errors
}
transformed parameters {
  real z[N, 2]
  = integrate_ode_rk45(dz_dt, z_init, 0, ts, theta,
                       rep_array(0.0, 0), rep_array(0, 0),
                       1e-5, 1e-3, 5e2);
}
model {
  theta[{1, 3}] ~ normal(1, 0.5);
  theta[{2, 4}] ~ normal(0.05, 0.05);
  sigma ~ lognormal(-1, 1);
  z_init ~ lognormal(log(10), 1);
  for (k in 1:2) {
    y_init[k] ~ lognormal(log(z_init[k]), sigma[k]);
    y[ , k] ~ lognormal(log(z[, k]), sigma[k]);
  }
}

Here is code that should compile with the new interface:

functions {
  real[] dz_dt(real t,       // time
               real[] z,     // system state {prey, predator}
               real alpha,
               real beta,
               real gamma,
               real delta) {
    real u = z[1];
    real v = z[2];

    real du_dt = (alpha - beta * v) * u;
    real dv_dt = (-gamma + delta * u) * v;

    return { du_dt, dv_dt };
  }
}
data {
  int<lower = 0> N;          // number of measurement times
  real ts[N];                // measurement times > 0
  real y_init[2];            // initial measured populations
  real<lower = 0> y[N, 2];   // measured populations
}
parameters {
  real<lower = 0> alpha;
  real<lower = 0> beta;
  real<lower = 0> gamma;
  real<lower = 0> delta;
  real<lower = 0> z_init[2];  // initial population
  real<lower = 0> sigma[2];   // measurement errors
}
transformed parameters {
  real z[N, 2]
  = ode_bdf(dz_dt, z_init, 0, ts,
            1e-5, 1e-3, 500,
            alpha, beta, gamma, delta);
}
model {
  alpha ~ normal(1, 0.5);
  gamma ~ normal(1, 0.5);
  beta ~ normal(0.05, 0.05);
  delta ~ normal(0.05, 0.05);
  sigma ~ lognormal(-1, 1);
  z_init ~ lognormal(log(10), 1);
  for (k in 1:2) {
    y_init[k] ~ lognormal(log(z_init[k]), sigma[k]);
    y[ , k] ~ lognormal(log(z[, k]), sigma[k]);
  }
}

Here is data for it (this pull has the branch that should run it, but I wouldn't be surprised if it fails -- it's not the most tested: https://github.com/stan-dev/math/pull/1641):

N <- 20
ts <-
c(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
y_init <-
c(30, 4)
y <-
structure(c(47.2, 70.2, 77.4, 36.3, 20.6, 18.1, 21.4, 22, 25.4, 27.1, 40.3, 57, 76.6, 52.3, 19.5, 11.2, 7.6, 14.6, 16.2, 24.7, 6.1, 9.8,
35.2, 59.4, 41.7, 19, 13, 8.3, 9.1, 7.4, 8, 12.3, 19.5, 45.7, 51.1, 29.7, 15.8, 9.7, 10.1, 8.6),
.Dim = c(20, 2))

rok-cesnovar commented 4 years ago

It should be pretty simple to get this working in an experimental manner, without any checks. A bit of reordering of what is done for #451 and that should be it.

EDIT: The only real issue with the reduce_sum code is that it doesn't pick up the calls if they are used inside another function (reduce_sum + reduce_sum, for example, as + is the top-level function there). I didn't bother too much with it for now, as there were some perf. issues there. But I guess that is also not a huge need for ODEs for the first examples?
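
The nesting problem can be illustrated with a tiny expression tree. This is only a sketch of the general idea, not stanc3 code (stanc3 is written in OCaml and its AST looks different); `Expr`, `call`, `is_top_level_call`, and `count_calls` are hypothetical names. A check that looks only at the top-level node misses `reduce_sum(...) + reduce_sum(...)` because the root is `+`, whereas a recursive walk finds both calls.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical minimal expression node: a function/operator name plus
// its argument sub-expressions (leaves have no arguments).
struct Expr {
  std::string name;
  std::vector<std::shared_ptr<Expr>> args;
};

// Convenience constructor for building small trees.
std::shared_ptr<Expr> call(std::string name,
                           std::vector<std::shared_ptr<Expr>> args = {}) {
  return std::make_shared<Expr>(Expr{std::move(name), std::move(args)});
}

// Top-level-only check: sees the root node and nothing below it.
bool is_top_level_call(const Expr& e, const std::string& target) {
  return e.name == target;
}

// Recursive check: counts every node named `target` anywhere in the tree.
int count_calls(const Expr& e, const std::string& target) {
  int n = (e.name == target) ? 1 : 0;
  for (const auto& arg : e.args) {
    assert(arg != nullptr);
    n += count_calls(*arg, target);
  }
  return n;
}
```

For the tree `call("+", {call("reduce_sum", {call("f")}), call("reduce_sum", {call("g")})})`, the top-level check reports no match while the recursive count finds both calls.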