stan-dev / rstan

RStan, the R interface to Stan
https://mc-stan.org
1.04k stars 269 forks source link

Is there any command in rstan package that does the same thing as uniroot() just for R #394

Open jasonqiangguo opened 7 years ago

jasonqiangguo commented 7 years ago

Hi, I am running MCMC in R with a Gibbs sampler. In that R code I use the function uniroot(), which searches the interval from lower to upper for a root (i.e., a zero) of the function f with respect to its first argument. Is there a function in RStan that does something similar? Thank you very much.

bob-carpenter commented 7 years ago

Not yet, but Charles Margossian is working on a root finder for an upcoming release.

On Feb 20, 2017, at 9:35 AM, Jason Qiang Guo notifications@github.com wrote:

Hi, I am running MCMC in R with a Gibbs sampler. In that R code I use the function uniroot(), which searches the interval from lower to upper for a root (i.e., a zero) of the function f with respect to its first argument. Is there a function in RStan that does something similar? Thank you very much.

— You are receiving this because you are subscribed to this thread. Reply to this email directly, view it on GitHub, or mute the thread.

simpsonm commented 7 years ago

I happen to have implemented this function myself by translating the C code used by R's uniroot() into Stan. It's a bit awkward: since you can't pass a function as an argument, you have to change the function being solved for each application, but it works. Below is some Stan code that uses it to estimate a finite mixture of normals with known mixture weights, given estimates of several quantiles of that mixture along with estimated standard errors (SEs) for those quantiles. The "uniroot" function is zeroin() - you should be able to modify it for your own application.

I'm not sure if this applies generally, but I had trouble getting Stan to work well with any model that used the root solver in its definition: poor adaptation and divergent transitions were both common. So, good luck!

Stan code (yes all of the input validation is necessary):

functions {
  /** CDF of a finite mixture of normal distributions.
   *
   * @param x       point at which the cdf is evaluated
   * @param mn      component means
   * @param sdev    component standard deviations
   * @param wts     mixture weights
   *
   * @return        sum over i of wts[i] * normal_cdf(x, mn[i], sdev[i])
   *
   * Loops over the elements of mn, so sdev and wts are expected to have the
   * same length as mn (zeroin() below checks this before calling).
   */
  real cdf_fun(real x, vector mn, vector sdev, vector wts){
    vector[num_elements(mn)] cdfs;
    for(i in 1:num_elements(mn)){
      cdfs[i] = normal_cdf(x, mn[i], sdev[i]);
    }
    return dot_product(cdfs, wts);
  }

  // Translation of zeroin.c from http://www.netlib.org/c/brent.shar
  // (Brent's root finder, the same algorithm used by R's uniroot()).
  /** Returns approximation to the target quantile of the cdf we care about;
   * mathematically, finds a root in [ax, bx] of:
   *       f(x) = cdf_fun(x, mn, sdev, wts) - prob
   *
   * @param ax      lower bound
   * @param bx      upper bound
   * @param tol     desired tolerance level
   * @param mn      mean parameters of target cdf
   * @param sdev    sd parameters of target cdf
   * @param wts     mixture weights of target cdf
   * @param prob    target probability
   *
   * @return b      approximation of the target quantile
   *
   * Calls reject() on invalid inputs, and when f(ax) and f(bx) are both
   * strictly positive or both strictly negative (no bracketed root).
   */
  real zeroin(real ax, real bx, real tol, vector mn, vector sdev, vector wts, real prob){
    real a; real b; real fa; real fb; real c; real fc; int whilecond; real eps;
    // input validation on all inputs
    // upper and lower bounds need be finite real numbers
    if(is_nan(ax) || is_nan(bx) || is_inf(ax) || is_inf(bx))
      reject("ax and bx must both be finite numbers; ax = ", ax, " bx = ", bx);
    // tol needs to be a strictly positive real number
    if(is_nan(tol) || is_inf(tol) || tol <= 0)
      reject("tol must be a finite strictly positive real number; tol = ", tol);
    // dimensions are checked first: the element loop below indexes sdev by
    // the length of mn, and cdf_fun relies on all three lengths matching
    if(num_elements(mn) != num_elements(wts))
      reject("mn and wts must be vectors of the same length; length(mn) = ",
             num_elements(mn), ", length(wts) = ", num_elements(wts));
    if(num_elements(mn) != num_elements(sdev))
      reject("mn and sdev must be vectors of the same length; length(mn) = ",
             num_elements(mn), ", length(sdev) = ", num_elements(sdev));
    // elements of mn need to be finite real numbers
    // elements of sdev need to be strictly positive finite real numbers
    for(i in 1:num_elements(mn)){
      if(is_nan(mn[i]) || is_inf(mn[i]))
        reject("mn[", i, "] must be a finite real number; mn[", i, "] = ", mn[i]);
      if(is_nan(sdev[i]) || is_inf(sdev[i]) || sdev[i] <= 0)
        reject("sdev[", i, "] must be a finite, positive real number; sdev[", i, "] = ", sdev[i]);
    }
    // wts must be real numbers between 0 and 1
    for(i in 1:num_elements(wts))
      if(is_nan(wts[i]) || wts[i] < 0.0 || wts[i] > 1.0)
        reject("wts[", i, "] must be a real number between 0 and 1; wts[", i, "] = ", wts[i]);
    // wts must sum to 1
    if(fabs(sum(wts) - 1.0) > 0.00001)
      reject("wts must sum to 1; sum(wts) = ", sum(wts));
    // the target probability must be a real number strictly between 0 and 1
    if(is_nan(prob) || prob <= 0.0 || prob >= 1.0)
      reject("prob must be a real number strictly between 0 and 1; prob = ", prob);
    eps = machine_precision();
    a = ax; b = bx;
    fa = cdf_fun(a, mn, sdev, wts) - prob;
    fb = cdf_fun(b, mn, sdev, wts) - prob;
    // as in R's uniroot(), require a sign change over [ax, bx]; without one
    // the iteration below is not guaranteed to return a root
    if((fa > 0 && fb > 0) || (fa < 0 && fb < 0))
      reject("f(ax) and f(bx) must not have the same sign; f(ax) = ", fa,
             ", f(bx) = ", fb);
    // c is maintained so that f(b) and f(c) have opposite signs (or one is
    // zero); a holds the previous iterate
    c = a; fc = fa; whilecond = 1;
    while(whilecond == 1){
      real prev_step; real tol_act; real new_step; real cb; real t1; real p; real q; real t2;
      prev_step = b - a;
      // keep b as the endpoint with the smaller |f|
      if(fabs(fc) < fabs(fb)){
        a = b; b = c; c = a; fa = fb; fb = fc; fc = fa;
      }
      tol_act = 2*eps*fabs(b) + tol/2;
      new_step = (c-b)/2.0;  // bisection step
      if(fabs(new_step) <= tol_act || fb == 0.0){
        whilecond = 0;  // converged
      }
      if(whilecond == 1){
        // try interpolation only if the previous step was large enough and
        // reduced |f|
        if(fabs(prev_step) >= tol_act && fabs(fa) > fabs(fb)){
          cb = c - b;
          if(a == c){
            // only two distinct points: linear (secant) interpolation
            t1 = fb/fa; p = cb*t1; q = 1 - t1;
          } else {
            // inverse quadratic interpolation, as in zeroin.c:
            //   p = t2 * (cb*q*(q - t1) - (b - a)*(t1 - 1))
            q = fa/fc; t1 = fb/fc; t2 = fb/fa;
            p = t2*(cb*q*(q-t1) - (b-a)*(t1 - 1));
            q = (q-1)*(t1-1)*(t2-1);
          }
          // make p non-negative, moving the sign into q
          if(p > 0)
            q = -q;
          else
            p = -p;
          // accept the interpolation step only if it lands inside [b, c]
          // (short of 3/4 of the way to c) and is smaller than half the
          // previous step; otherwise keep the bisection step
          if(p < (0.75*cb*q - fabs(tol_act*q)/2) && p < fabs(prev_step*q/2))
            new_step = p/q;
        }
        // never step by less than the tolerance
        if(fabs(new_step) < tol_act){
          if(new_step > 0)
            new_step = tol_act;
          else
            new_step = -tol_act;
        }
        a = b; fa = fb; b = b + new_step;
        fb = cdf_fun(b, mn, sdev, wts) - prob;
        // restore the bracket [b, c] if f(b) and f(c) now share a sign
        if((fb > 0 && fc > 0) || (fb < 0 && fc < 0)){
          c = a; fc = fa;
        }
      }
    }
    return b;
  }
}
data {
  int<lower = 1> nquant;                        // number of observed quantiles
  int<lower = 1> nmix;                          // number of mixture components
  vector[nquant] qobs;                          // observed (estimated) quantiles
  // standard errors of qobs; used as the scale of a normal in the model
  // block, so negative values are rejected when the data are read
  vector<lower = 0>[nquant] qse;
  vector<lower = 0, upper = 1>[nquant] quants;  // probability level of each quantile
  simplex[nmix] probs; // mixture weights assumed known
  // desired tolerance level in zeroin function (zeroin rejects tol <= 0)
  real<lower = 0> tol;
}
parameters {
  vector[nmix] mu;               // means of the mixture components
  vector<lower=0>[nmix] sigma;   // standard deviations of the mixture components
}
transformed parameters{
  // qactual[i]: quantile at level quants[i] of the normal mixture with means
  // mu, standard deviations sigma and weights probs, found numerically by
  // zeroin().
  // NOTE(review): gradients reach qactual by differentiating through every
  // zeroin iteration rather than via the implicit function theorem; this
  // may contribute to the divergent transitions seen with this model.
  vector[nquant] qactual;
  {
    real sigmax; real z; real lo; real hi;
    vector[nmix] comp_q;
    sigmax = max(sigma);
    for(i in 1:nquant){
      // comp_q[k] is the quants[i]-quantile of component k alone. The
      // mixture quantile lies between min(comp_q) and max(comp_q), and
      // padding by max(sigma) puts the mixture cdf at most Phi(z - 1) at lo
      // and at least Phi(z + 1) at hi. So [lo, hi] brackets the root for
      // every level in (0, 1). A fixed min(mu) - 3*max(sigma) ..
      // max(mu) + 3*max(sigma) interval is only guaranteed to bracket
      // levels between Phi(-3) and Phi(3).
      z = inv_Phi(quants[i]);
      comp_q = mu + sigma * z;
      lo = min(comp_q) - sigmax;
      hi = max(comp_q) + sigmax;
      qactual[i] = zeroin(lo, hi, tol, mu, sigma, probs, quants[i]);
    }
  }
}
model {
  // likelihood: each observed quantile is normal around the model-implied
  // quantile, with its standard error supplied as data
  qobs ~ normal(qactual, qse);
  // priors (sigma is declared with lower = 0, so its prior is half-normal)
  sigma ~ normal(0, 10);
  mu ~ normal(0, 10);
}
library(rstan)

#### simulate some data ----
nmix <- 2
mu <- c(0.2, 1.5)
sigma <- c(1.2, 0.2)

N <- 1000
probs <- c(0.3, 0.7)
# component label (1..nmix) of each draw, chosen with probabilities `probs`
ids <- sample.int(nmix, N, replace = TRUE, prob = probs)
y <- mu[ids] + sigma[ids] * rnorm(N)
quants <- c(0.3, 0.4, 0.5, 0.6, 0.7)
qest <- quantile(y, quants)

## SEs from the order-statistic asymptotic distribution:
##   se(q_hat_p) ~= sqrt(p * (1 - p) / N) / f(q_p),
## where f is the mixture density and q_p the true mixture quantile.
mix_cdf <- function(x) sum(probs * pnorm(x, mu, sigma))
mix_pdf <- function(x) sum(probs * dnorm(x, mu, sigma))

# True mixture quantiles via uniroot(). The component quantiles bracket the
# mixture quantile; widening the range by max(sigma) guarantees a sign change.
qtrue <- vapply(quants, function(p) {
  interval <- range(qnorm(p, mu, sigma)) + c(-1, 1) * max(sigma)
  uniroot(function(x) mix_cdf(x) - p, interval = interval, tol = 1e-10)$root
}, numeric(1))

qse <- sqrt(quants * (1 - quants) / N) / vapply(qtrue, mix_pdf, numeric(1))

standat <- list(nmix = nmix, nquant = length(quants), qobs = qest, qse = qse,
                quants = quants, probs = probs, tol = .Machine$double.eps)

## fit the model ----
# Single-chain, single-iteration run; its compiled model is reused below via
# `fit = fit0` for the real sampling run.
fit0 <- stan(file = "quantmix.stan", data = standat, iter = 1, chains = 1)

fit <- stan(fit = fit0, data = standat, cores = 4)
## expect many divergent transitions, but zeroin itself works

traceplot(fit)

summary(fit)$summary