stan-dev / math

The Stan Math Library is a C++ template library for automatic differentiation of any order using forward, reverse, and mixed modes. It includes a range of built-in functions for probabilistic modeling, linear algebra, and equation solving.
https://mc-stan.org
BSD 3-Clause "New" or "Revised" License

Fix Exponentially modified Gaussian overflows at lambda >= 40 #2803

Open spinkney opened 2 years ago

spinkney commented 2 years ago

Description

According to Wikipedia the exponentially modified Gaussian can be made more precise by a reparameterization and using the scaled erfc.

I have implemented the reparameterization in R and reproduce the issue below.
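
For reference, the reparameterization can be written out as a short derivation. This is a sketch that assumes the usual rate parameterization with tau = 1/lambda and the same z as in the R code below; it is not taken from the Stan Math source.

```latex
\begin{aligned}
z &= \frac{1}{\sqrt{2}}\left(\frac{\sigma}{\tau} - \frac{y-\mu}{\sigma}\right), \qquad \tau = \frac{1}{\lambda},\\
\log f(y) &= -\log(2\tau) + \frac{1}{2}\left(\frac{\sigma}{\tau}\right)^2 - \frac{y-\mu}{\tau} + \log\operatorname{erfc}(z),\\
\operatorname{erfc}(z) &= e^{-z^2}\operatorname{erfcx}(z), \qquad
z^2 = \frac{1}{2}\left(\frac{\sigma}{\tau}\right)^2 - \frac{y-\mu}{\tau} + \frac{1}{2}\left(\frac{y-\mu}{\sigma}\right)^2,\\
\log f(y) &= -\log(2\tau) - \frac{1}{2}\left(\frac{y-\mu}{\sigma}\right)^2 + \log\operatorname{erfcx}(z).
\end{aligned}
```

The large positive and negative terms cancel analytically in the last line, so nothing has to pass through erfc(z) when z is large.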

Example

library(Rcpp)
cppFunction('double my_erfc( double x){
    return erfc( x ) ;
}')

cppFunction('double my_erfcx (double x)
// from https://stackoverflow.com/questions/39777360/accurate-computation-of-scaled-complementary-error-function-erfcx
{
  double a, d, e, m, p, q, r, s, t;

  a = fmax (x, 0.0 - x); // NaN preserving absolute value computation

  /* Compute q = (a-4)/(a+4) accurately. [0,INF) -> [-1,1] */
  m = a - 4.0;
  p = a + 4.0;
  r = 1.0 / p;
  q = m * r;
  t = fma (q + 1.0, -4.0, a); 
  e = fma (q, -a, t); 
  q = fma (r, e, q); 

  /* Approximate (1+2*a)*exp(a*a)*erfc(a) as p(q)+1 for q in [-1,1] */
    p =             0x1.edcad78fc8044p-31;  //  8.9820305531190140e-10
    p = fma (p, q,  0x1.b1548f14735d1p-30); //  1.5764464777959401e-09
    p = fma (p, q, -0x1.a1ad2e6c4a7a8p-27); // -1.2155985739342269e-08
    p = fma (p, q, -0x1.1985b48f08574p-26); // -1.6386753783877791e-08
    p = fma (p, q,  0x1.c6a8093ac4f83p-24); //  1.0585794011876720e-07
    p = fma (p, q,  0x1.31c2b2b44b731p-24); //  7.1190423171700940e-08
    p = fma (p, q, -0x1.b87373facb29fp-21); // -8.2040389712752056e-07
    p = fma (p, q,  0x1.3fef1358803b7p-22); //  2.9796165315625938e-07
    p = fma (p, q,  0x1.7eec072bb0be3p-18); //  5.7059822144459833e-06
    p = fma (p, q, -0x1.78a680a741c4ap-17); // -1.1225056665965572e-05
    p = fma (p, q, -0x1.9951f39295cf4p-16); // -2.4397380523258482e-05
    p = fma (p, q,  0x1.3be1255ce180bp-13); //  1.5062307184282616e-04
    p = fma (p, q, -0x1.a1df71176b791p-13); // -1.9925728768782324e-04
    p = fma (p, q, -0x1.8d4aaa0099bc8p-11); // -7.5777369791018515e-04
    p = fma (p, q,  0x1.49c673066c831p-8);  //  5.0319701025945277e-03
    p = fma (p, q, -0x1.0962386ea02b7p-6);  // -1.6197733983519948e-02
    p = fma (p, q,  0x1.3079edf465cc3p-5);  //  3.7167515521269866e-02
    p = fma (p, q, -0x1.0fb06dfedc4ccp-4);  // -6.6330365820039094e-02
    p = fma (p, q,  0x1.7fee004e266dfp-4);  //  9.3732834999538536e-02
    p = fma (p, q, -0x1.9ddb23c3e14d2p-4);  // -1.0103906603588378e-01
    p = fma (p, q,  0x1.16ecefcfa4865p-4);  //  6.8097054254651804e-02
    p = fma (p, q,  0x1.f7f5df66fc349p-7);  //  1.5379652102610957e-02
    p = fma (p, q, -0x1.1df1ad154a27fp-3);  // -1.3962111684056208e-01
    p = fma (p, q,  0x1.dd2c8b74febf6p-3);  //  2.3299511862555250e-01

    /* Divide (1+p) by (1+2*a) ==> exp(a*a)*erfc(a) */
      d = a + 0.5;
    r = 1.0 / d;
    r = r * 0.5;
    q = fma (p, r, r); // q = (p+1)/(1+2*a)
    t = q + q;
    e = (p - q) + fma (t, -a, 1.0); // residual: (p+1)-q*(1+2*a)
    r = fma (e, r, q);

    /* Handle argument of infinity */
      if (a > 0x1.fffffffffffffp1023) r = 0.0;

    /* Handle negative arguments: erfcx(x) = 2*exp(x*x) - erfcx(|x|) */
      if (x < 0.0) {
        s = x * x;
        d = fma (x, x, -s);
        e = exp (s);
        r = e - r;
        r = fma (e, d + d, r); 
        r = r + e;
        if (e > 0x1.fffffffffffffp1023) r = e; // avoid creating NaN
      }
    return r;
}')

model_path <- cmdstanr::write_stan_file("
functions {
  real my_exp_mod_normal_lpdf(real y, real mu, real sigma, real lambda) {
    return exp_mod_normal_lpdf(y | mu, sigma, lambda);
  }
}
")
expose_cmdstanr_functions(model_path, expose_to_global_env = TRUE)

r_exp_mod_normal_lpdf <- function(y, mu, sigma, lambda) {
  tau <- 1 / lambda
  y_std <- (y - mu) / sigma
  z <- (sigma / tau - y_std) / sqrt(2)
  # log of the Gaussian amplitude h = 1 / (sigma * sqrt(2 * pi))
  logh <- -log(sigma) - 0.5 * log(2 * pi)
  # log(sqrt(pi / 2)), the constant multiplying the erfc / erfcx term
  half_log_pi_2 <- 0.5 * log(pi / 2)
  # common prefactor; simplifies to -log(2 * tau)
  x <- logh + log(sigma) - log(tau) + half_log_pi_2

  if (z < 0) {
    # erfc(z) is in (1, 2) here, so the direct form is safe
    return(x + 0.5 * (sigma / tau)^2 - (y - mu) / tau + log(my_erfc(z)))
  } else if (z <= 6.71e7) {
    # exp(-z^2) has been folded into erfcx, leaving only -0.5 * y_std^2
    return(x - 0.5 * y_std^2 + log(my_erfcx(z)))
  }

  # z > 6.71e7: erfcx(z) ~ 1 / (z * sqrt(pi)), which gives 1 / (1 - y_std * tau / sigma)
  return(logh - 0.5 * y_std^2 - log1p(-y_std * tau / sigma))
}

> my_exp_mod_normal_lpdf(1, 0, 1, 40)
[1] -Inf
> r_exp_mod_normal_lpdf(1, 0, 1, 40)
[1] -1.394277
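
The same comparison can be sketched in plain C++, which is closer to where a fix would live. This is a minimal sketch and not the Stan Math implementation: `erfcx_sketch`, `emg_lpdf_naive`, and `emg_lpdf_scaled` are hypothetical names, and `erfcx_sketch` is a simplified stand-in (direct product for moderate z, a truncated asymptotic series for large z) rather than the polynomial `my_erfcx` quoted above.

```cpp
#include <cassert>
#include <cmath>

// Simplified stand-in for erfcx(z) = exp(z^2) * erfc(z).
// Assumption: z is not a large negative number (exp(z^2) would overflow).
inline double erfcx_sketch(double z) {
  if (z < 20.0) {
    // erfc(z) is still a normal double here, so the direct product is fine.
    return std::exp(z * z) * std::erfc(z);
  }
  // Truncated asymptotic series:
  // 1/(z*sqrt(pi)) * (1 - 1/(2z^2) + 3/(4z^4) - 15/(8z^6))
  const double inv_sqrt_pi = 0.5641895835477563;
  const double w = 1.0 / (2.0 * z * z);
  return inv_sqrt_pi / z * (1.0 + w * (-1.0 + w * (3.0 - 15.0 * w)));
}

// Direct form: log(lambda/2) + lambda*(mu - y) + (lambda*sigma)^2/2 + log(erfc(z)).
inline double emg_lpdf_naive(double y, double mu, double sigma, double lambda) {
  assert(sigma > 0.0 && lambda > 0.0);
  const double z = (lambda * sigma - (y - mu) / sigma) / std::sqrt(2.0);
  return std::log(0.5 * lambda) + lambda * (mu - y)
         + 0.5 * lambda * lambda * sigma * sigma + std::log(std::erfc(z));
}

// Scaled form for z >= 0: log(lambda/2) - y_std^2/2 + log(erfcx(z)).
inline double emg_lpdf_scaled(double y, double mu, double sigma, double lambda) {
  assert(sigma > 0.0 && lambda > 0.0);
  const double y_std = (y - mu) / sigma;
  const double z = (lambda * sigma - y_std) / std::sqrt(2.0);
  if (z < 0.0) {
    return emg_lpdf_naive(y, mu, sigma, lambda);  // erfc(z) > 1, no underflow
  }
  return std::log(0.5 * lambda) - 0.5 * y_std * y_std
         + std::log(erfcx_sketch(z));
}
```

With y = 1, mu = 0, sigma = 1, lambda = 40, the direct form hits log(erfc(27.6)) = log(0) and returns -inf, while the scaled form should agree with the R reference value of about -1.394277.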
spinkney commented 2 years ago

Having the erfcx function in Stan will also help speed up the Wiener diffusion distribution and aid in the calculation of its cdf. See "Even faster and even more accurate first-passage time densities and distributions for the Wiener diffusion model" by Matthias Gondan, Steven P. Blurton, and Miriam Kesselmeier.

Also, if we implement a log_Phi function, erfcx lets us extend it to large negative values. Below is the calculation of log(Phi(x)) in stan-math, where erfc underflows to 0 for large negative x so the result is -Inf, followed by the same calculation using erfcx:

x <- -40
log(0.5) + log(my_erfc(-(1/sqrt(2)) * x))
[1] -Inf

log(0.5) + log(my_erfcx(-(1/sqrt(2)) * x)) -  x^2 * 0.5
[1] -804.6084
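
The identity behind the second line, written out as a sketch (standard definitions of Phi and erfcx assumed):

```latex
\Phi(x) = \tfrac{1}{2}\operatorname{erfc}\!\left(-\tfrac{x}{\sqrt{2}}\right)
        = \tfrac{1}{2}\, e^{-x^2/2}\operatorname{erfcx}\!\left(-\tfrac{x}{\sqrt{2}}\right)
\quad\Longrightarrow\quad
\log\Phi(x) = \log\tfrac{1}{2} + \log\operatorname{erfcx}\!\left(-\tfrac{x}{\sqrt{2}}\right) - \tfrac{x^2}{2}.
```

For x = -40 the factor exp(-800) is far below the smallest positive double, which is why the erfc version returns -Inf while the erfcx version keeps the -x^2/2 term outside the logarithm.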

This can also be used to define a log_erf function (valid for x > 0):

// log(erf(x))

real log_erf (real x) {
 return log_diff_exp(0, log(erfcx(x)) - x^2);
}
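
Why this works, as a short sketch for x > 0 (using Stan's log_diff_exp(a, b) = log(exp(a) - exp(b))):

```latex
\log\operatorname{erf}(x)
  = \log\bigl(1 - \operatorname{erfc}(x)\bigr)
  = \log\Bigl(e^{0} - e^{\log\operatorname{erfcx}(x) - x^2}\Bigr)
  = \operatorname{log\_diff\_exp}\bigl(0,\ \log\operatorname{erfcx}(x) - x^2\bigr).
```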
spinkney commented 2 years ago

Here's log_Phi expanded. It matches std_normal_lcdf for positive x (it allows values of x up to 38) and extends the usable range for negative x. Also see https://github.com/stan-dev/math/issues/2470.

log_Phi <- function(x) {
  y <- log(0.5) + log(my_erfcx(-(1/sqrt(2)) * x)) - x^2 * 0.5
  if (x < 0) {
    return(y)
  } else {
    y2 <- log(0.5) + log(my_erfcx((1/sqrt(2)) * x)) - x^2 * 0.5
    return(log1mexp(-y2))
  }
}
log_Phi <- Vectorize(log_Phi)
Phi_stan <- Vectorize(Phi_stan)

v <- 0:20
data.table(k = v, log_phi = log_Phi(v), stan_log_phi = log(Phi_stan(v)))

     k       log_phi  stan_log_phi
 1:  0 -6.931472e-01 -6.931472e-01
 2:  1 -1.727538e-01 -1.727538e-01
 3:  2 -2.301291e-02 -2.301291e-02
 4:  3 -1.350810e-03 -1.350810e-03
 5:  4 -3.167174e-05 -3.167174e-05
 6:  5 -2.866516e-07 -2.866516e-07
 7:  6 -9.865876e-10 -9.865877e-10
 8:  7 -1.279813e-12 -1.279865e-12
 9:  8 -6.220961e-16 -6.661338e-16
10:  9 -1.128588e-19  0.000000e+00
11: 10 -7.619853e-24  0.000000e+00
12: 11 -1.910660e-28  0.000000e+00
13: 12 -1.776482e-33  0.000000e+00
14: 13 -6.117164e-39  0.000000e+00
15: 14 -7.793537e-45  0.000000e+00
16: 15 -3.670966e-51  0.000000e+00
17: 16 -6.388754e-58  0.000000e+00
18: 17 -4.105996e-65  0.000000e+00
19: 18 -9.740949e-73  0.000000e+00
20: 19 -8.527224e-81  0.000000e+00
21: 20 -2.753624e-89  0.000000e+00

> v <- -100:-30
> data.table(k = v, log_phi = log_Phi(v), stan_log_phi = log(Phi_stan(v)))
       k    log_phi stan_log_phi
 1: -100 -5005.5242         -Inf
 2:  -99 -4906.0142         -Inf
 3:  -98 -4807.5040         -Inf
 4:  -97 -4709.9938         -Inf
 5:  -96 -4613.4834         -Inf
 6:  -95 -4517.9729         -Inf
 7:  -94 -4423.4623         -Inf
 8:  -93 -4329.9517         -Inf
 9:  -92 -4237.4408         -Inf
10:  -91 -4145.9299         -Inf
11:  -90 -4055.4189         -Inf
12:  -89 -3965.9077         -Inf
13:  -88 -3877.3964         -Inf
14:  -87 -3789.8850         -Inf
15:  -86 -3703.3734         -Inf
16:  -85 -3617.8617         -Inf
17:  -84 -3533.3499         -Inf
18:  -83 -3449.8379         -Inf
19:  -82 -3367.3258         -Inf
20:  -81 -3285.8135         -Inf
21:  -80 -3205.3011         -Inf
22:  -79 -3125.7885         -Inf
23:  -78 -3047.2758         -Inf
24:  -77 -2969.7629         -Inf
25:  -76 -2893.2498         -Inf
26:  -75 -2817.7366         -Inf
27:  -74 -2743.2232         -Inf
28:  -73 -2669.7096         -Inf
29:  -72 -2597.1958         -Inf
30:  -71 -2525.6818         -Inf
31:  -70 -2455.1676         -Inf
32:  -69 -2385.6533         -Inf
33:  -68 -2317.1387         -Inf
34:  -67 -2249.6239         -Inf
35:  -66 -2183.1088         -Inf
36:  -65 -2117.5936         -Inf
37:  -64 -2053.0781         -Inf
38:  -63 -1989.5623         -Inf
39:  -62 -1927.0463         -Inf
40:  -61 -1865.5301         -Inf
41:  -60 -1805.0136         -Inf
42:  -59 -1745.4968         -Inf
43:  -58 -1686.9797         -Inf
44:  -57 -1629.4623         -Inf
45:  -56 -1572.9446         -Inf
46:  -55 -1517.4266         -Inf
47:  -54 -1462.9083         -Inf
48:  -53 -1409.3896         -Inf
49:  -52 -1356.8706         -Inf
50:  -51 -1305.3511         -Inf
51:  -50 -1254.8314         -Inf
52:  -49 -1205.3112         -Inf
53:  -48 -1156.7906         -Inf
54:  -47 -1109.2695         -Inf
55:  -46 -1062.7481         -Inf
56:  -45 -1017.2261         -Inf
57:  -44  -972.7036         -Inf
58:  -43  -929.1807         -Inf
59:  -42  -886.6572         -Inf
60:  -41  -845.1331         -Inf
61:  -40  -804.6084         -Inf
62:  -39  -765.0832         -Inf
63:  -38  -726.5572         -Inf
64:  -37  -689.0306    -689.0306
65:  -36  -652.5032    -652.5032
66:  -35  -616.9751    -616.9751
67:  -34  -582.4462    -582.4462
68:  -33  -548.9164    -548.9164
69:  -32  -516.3856    -516.3856
70:  -31  -484.8540    -484.8540
71:  -30  -454.3212    -454.3212
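
The log_Phi construction above can also be sketched in C++. This is a hedged illustration, not Stan Math code: `erfcx_sketch`, `log1m_exp_sketch`, and `log_Phi_sketch` are hypothetical names, and `erfcx_sketch` is a simplified stand-in (direct product below z = 20, truncated asymptotic series above) rather than the polynomial `my_erfcx` used in the R code. Unlike the R version, it only evaluates erfcx at |x| / sqrt(2), so the argument is never negative.

```cpp
#include <cassert>
#include <cmath>

// Simplified stand-in for erfcx(z) = exp(z^2) * erfc(z), for z >= 0.
inline double erfcx_sketch(double z) {
  if (z < 20.0) {
    return std::exp(z * z) * std::erfc(z);
  }
  // 1/(z*sqrt(pi)) * (1 - 1/(2z^2) + 3/(4z^4) - 15/(8z^6))
  const double inv_sqrt_pi = 0.5641895835477563;
  const double w = 1.0 / (2.0 * z * z);
  return inv_sqrt_pi / z * (1.0 + w * (-1.0 + w * (3.0 - 15.0 * w)));
}

// log(1 - exp(a)) for a <= 0, switching between expm1 and log1p at -log(2).
inline double log1m_exp_sketch(double a) {
  assert(a <= 0.0);
  return (a > -0.6931471805599453) ? std::log(-std::expm1(a))
                                   : std::log1p(-std::exp(a));
}

// log(Phi(x)) built from the lower tail log(Phi(-|x|)).
inline double log_Phi_sketch(double x) {
  const double log_half = -0.6931471805599453;
  // log(Phi(-|x|)) = log(1/2) + log(erfcx(|x| / sqrt(2))) - x^2 / 2
  const double lower = log_half
                       + std::log(erfcx_sketch(std::fabs(x) / std::sqrt(2.0)))
                       - 0.5 * x * x;
  // For x >= 0 use Phi(x) = 1 - Phi(-x).
  return (x < 0.0) ? lower : log1m_exp_sketch(lower);
}
```

Under these assumptions, `log_Phi_sketch(-40.0)` should land near the -804.6084 shown in the table, and `log_Phi_sketch(10.0)` near -7.619853e-24 instead of rounding to 0.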