lindeloev / mcp

Regression with Multiple Change Points
http://lindeloev.github.io/mcp

stan backend #100

Open lindeloev opened 3 years ago

lindeloev commented 3 years ago

mcp 2.0 will support stan in addition to JAGS. It is far out in the future, but this issue collects work points.

jpzhangvincent commented 3 years ago

Awesome. Excited to see this on the roadmap. I'd love to contribute to this while still learning Bayesian modeling. Do you have any suggestions or a contributor guide? I would also be interested in implementing a Python version with a PyMC3 backend.

lindeloev commented 3 years ago

Thanks, @jpzhangvincent, that would be great! I think getting it to work is simply a matter of (a) re-writing a few JAGS models as stan models and learning whether they work well, and (b) writing an R function that generates these from mcp's internal representation of the model. I could really use some input on (a) here, as my stan skills are limited.

mcp is undergoing heavy internal restructuring and a few breaking changes, most of which are tracked in issue #90. I think it makes sense to wait until after that release, when things hopefully settle down. But I think the JAGS part is finished now. mcp 0.4 takes formulas like this:

model = list(
  y ~ 1 + x:group,
  ~ 0 + x,
  ~ 1 + sigma(1 + group)
)

which for data like

> head(df)
  x group         y          z
1 1     A -1.431554 -5.9042791
2 2     B 12.819796  1.6075971
3 3     C 17.218474  4.8689988
4 4     D  9.243459 -2.1581639
5 5     A  9.609940 10.1076712
6 6     B  9.544842  0.2298296

generates JAGS code like this:

model {
  # mcp helper values
  cp_0 = MINX
  cp_3 = MAXX

  # Priors for population-level effects
  cp_1 ~ dt(MINX, 1/((MAXX-MINX)/N_CP)^2, N_CP-1) T(cp_0, MAXX)
  cp_2 ~ dt(MINX, 1/((MAXX-MINX)/N_CP)^2, N_CP-1) T(cp_1, MAXX)
  Intercept_1 ~ dt(MEANLINKY, 1/(SDLINKY)^2, 3) 
  xgroupA_1 ~ dt(0, 1/(N_CP*SDLINKY/(MAXX-MINX))^2, 3) 
  xgroupB_1 ~ dt(0, 1/(N_CP*SDLINKY/(MAXX-MINX))^2, 3) 
  xgroupC_1 ~ dt(0, 1/(N_CP*SDLINKY/(MAXX-MINX))^2, 3) 
  xgroupD_1 ~ dt(0, 1/(N_CP*SDLINKY/(MAXX-MINX))^2, 3) 
  sigma_1 ~ dt(0, 1/(SDLINKY)^2, 3) T(0, )
  x_2 ~ dt(0, 1/(N_CP*SDLINKY/(MAXX-MINX))^2, 3) 
  Intercept_3 ~ dt(MEANLINKY, 1/(SDLINKY)^2, 3) 
  sigma_3 ~ dt(0, 1/(SDLINKY)^2, 3) T(0, )
  sigma_groupB_3 ~ dt(0, 1/(SDLINKY)^2, 3) 
  sigma_groupC_3 ~ dt(0, 1/(SDLINKY)^2, 3) 
  sigma_groupD_3 ~ dt(0, 1/(SDLINKY)^2, 3) 

  # Model and likelihood
  for (i_ in 1:length(x)) {
    # par_x local to each segment
    x_local_1_[i_] = min(x[i_], cp_1)
    x_local_2_[i_] = min(x[i_], cp_2) - cp_1
    x_local_3_[i_] = min(x[i_], cp_3) - cp_2

    # Formula for mu
    mu_[i_] =

      # Segment 1: y ~ 1 + x:group
      (x[i_] >= cp_0) * (x[i_] < cp_2) * inprod(rhs_data_[i_, c(1)], c(Intercept_1)) * 1 + 
      (x[i_] >= cp_0) * (x[i_] < cp_2) * inprod(rhs_data_[i_, c(2, 3, 4, 5)], c(xgroupA_1, xgroupB_1, xgroupC_1, xgroupD_1)) * x_local_1_[i_] + 

      # Segment 2: y ~ 1 ~ 0 + x
      (x[i_] >= cp_1) * (x[i_] < cp_2) * inprod(rhs_data_[i_, c(7)], c(x_2)) * x_local_2_[i_] + 

      # Segment 3: y ~ 1 ~ 1 + sigma(1 + group)
      (x[i_] >= cp_2) * inprod(rhs_data_[i_, c(8)], c(Intercept_3)) * 1

    # Formula for sigma
    sigma_[i_] = max(10^-9, sigma_tmp[i_])  # Count negative sigma as just-above-zero sigma
    sigma_tmp[i_] =  
      # Segment 1: y ~ 1 + x:group
      (x[i_] >= cp_0) * (x[i_] < cp_2) * inprod(rhs_data_[i_, c(6)], c(sigma_1)) * 1 + 

      # Segment 3: y ~ 1 ~ 1 + sigma(1 + group)
      (x[i_] >= cp_2) * inprod(rhs_data_[i_, c(9, 10, 11, 12)], c(sigma_3, sigma_groupB_3, sigma_groupC_3, sigma_groupD_3)) * 1

    # Likelihood and log-density for family = gaussian()
    y[i_] ~ dnorm((mu_[i_]), 1 / sigma_[i_]^2)  # SD as precision
    loglik_[i_] = logdensity.norm(y[i_], (mu_[i_]), 1 / sigma_[i_]^2)  # SD as precision
  }
}

Here, rhs_data_ is the model.matrix() but with x factored out of all terms; x is then "factored in" in JAGS, as you can see. inprod() is simply equivalent to %*% in base R.
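
For illustration, here is that equivalence in base R (the numbers are hypothetical, not mcp internals):

# One row of the design matrix and the matching coefficient vector
rhs_row = c(1, 0, 1, 0)           # hypothetical row of rhs_data_
coefs = c(2.5, 0.1, -0.3, 0.8)    # hypothetical parameter values

# JAGS' inprod(rhs_row, coefs) corresponds to either of these:
sum(rhs_row * coefs)              # elementwise product, then sum
drop(rhs_row %*% coefs)           # or a 1x1 matrix product, dropped to a scalar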

Some of the work points for generating an equivalent stan model are:

  1. I think some of the priors can be dropped in stan (JAGS requires priors for everything; stan defaults to improper flat priors on unconstrained parameters).
  2. I think stan allows for vectorizing, so we can get rid of the for-loop (see the sketch after this list).
  3. I have to learn more stan to see if some of it can be moved to a "data" chunk, etc.
  4. There are many equivalent ways to represent the formula part, but JAGS samples considerably faster for this particular one. I'd like to see whether stan is more robust, so that we needn't maintain multiple lines of code for each segment.
  5. In general, how can this be made to run most efficiently in stan? Can we use some of the new primitives? Can we make a model that runs on GPU?
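
To make the discussion concrete, here is a hand-written sketch of what a minimal two-segment stan model (y ~ 1 + x, then ~ 0 + x, one change point) could look like. This is not generated code: the prior scales are placeholders standing in for the MEANLINKY/SDLINKY-based defaults above, and the change point gets a uniform prior via its bounds rather than the truncated t used in JAGS:

data {
  int<lower=1> N;
  vector[N] x;
  vector[N] y;
  real MINX;
  real MAXX;
}
parameters {
  real<lower=MINX, upper=MAXX> cp_1;  // change point; bounds imply a uniform prior
  real Intercept_1;                   // segment 1 intercept
  real x_1;                           // segment 1 slope
  real x_2;                           // segment 2 slope
  real<lower=0> sigma;                // the bound replaces JAGS' T(0, ) truncation
}
model {
  vector[N] mu;
  for (i in 1:N) {
    real x_local_1 = fmin(x[i], cp_1);      // par_x local to each segment, as in JAGS
    real x_local_2 = fmax(x[i] - cp_1, 0);
    mu[i] = Intercept_1 + x_1 * x_local_1 + x_2 * x_local_2;
  }

  // Placeholder prior scales; mcp would derive these from the data
  Intercept_1 ~ student_t(3, 0, 10);
  x_1 ~ student_t(3, 0, 10);
  x_2 ~ student_t(3, 0, 10);
  sigma ~ student_t(3, 0, 10);

  y ~ normal(mu, sigma);  // vectorized likelihood: no per-observation dnorm
}
generated quantities {
  // Pointwise log-density, analogous to loglik_ in the JAGS code
  vector[N] loglik_;
  for (i in 1:N) {
    real x_local_1 = fmin(x[i], cp_1);
    real x_local_2 = fmax(x[i] - cp_1, 0);
    loglik_[i] = normal_lpdf(y[i] | Intercept_1 + x_1 * x_local_1 + x_2 * x_local_2, sigma);
  }
}

Note that y ~ normal(mu, sigma) is already vectorized even though mu is built in a loop; in recent stan versions fmin() and fmax() are, I believe, vectorized as well, so the loop itself could probably go too (point 2 above).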

Would love any tips, example stan models, or thoughts!

mattansb commented 3 years ago

As far as dependencies go, you can:

  1. Have JAGS/Stan in Suggests (rather than Imports).
  2. On startup:
    • If neither is installed, give the user a message.
    • If only one is installed, set some options() to use that one (a sketch follows below).
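
A minimal sketch of what that startup logic could look like; the .onAttach hook and the mcp.backend option name are assumptions for illustration, not mcp's actual implementation:

.onAttach = function(libname, pkgname) {
  # Check which backends are available without attaching them
  has_jags = requireNamespace("rjags", quietly = TRUE)
  has_stan = requireNamespace("rstan", quietly = TRUE)

  if (!has_jags && !has_stan) {
    packageStartupMessage("mcp: neither rjags nor rstan is installed. Install one to fit models.")
  } else if (xor(has_jags, has_stan)) {
    options(mcp.backend = if (has_jags) "jags" else "stan")  # hypothetical option name
  }
}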