stan-dev / stanc3

The Stan transpiler (from Stan to C++ and beyond).
BSD 3-Clause "New" or "Revised" License

Allow truncation syntax in vectorized sampling statements #1263

Closed: WardBrian closed this 1 year ago

WardBrian commented 1 year ago

Closes #788. This allows statements like

  y ~ normal(mu1, sigma1) T[L, U];

in place of

  for (n in 1 : N) {
    y[n] ~ normal(mu2, sigma2) T[L, U];
  }

We previously translated the loop into (MIR syntax):

for(n in 1:N) {
    if((y[n] < L)) 
        target += FnNegInf__(); 
    else if((y[n] > U)) 
        target += FnNegInf__(); 
    else 
        target += PMinus__(log_diff_exp(normal_cdf_log(U, mu2, sigma2), normal_cdf_log(L, mu2, sigma2)));
    target += normal_lupdf(y[n], mu2, sigma2);
}

We now turn the vectorized version into

if((min(y) < L)) 
    target += FnNegInf__(); 
else if((max(y) > U)) 
    target += FnNegInf__(); 
else 
    target += PMinus__((log_diff_exp(normal_cdf_log(U, mu1, sigma1), normal_cdf_log(L, mu1, sigma1)) * size(y)));

target += normal_lupdf(y, mu1, sigma1);

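Note the calls to min and max and the multiplication by size(y). This formulation is based on #788 and the linked thread, and it also seems to me necessarily equivalent to the for loop: every y[n] lies in [L, U] exactly when min(y) >= L and max(y) <= U, and with scalar mu1 and sigma1 the normalizing term is identical for each of the size(y) elements, so it can be computed once and multiplied.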
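The equivalence can be checked numerically. The following is a minimal sketch in plain Python, not code generated by stanc3: the helper names `normal_lpdf`, `normal_lcdf`, `log_diff_exp`, `loop_form`, and `vectorized_form` are illustrative, `FnNegInf__()` is modeled as negative infinity, `PMinus__` as unary minus, and the full normal log density stands in for `normal_lupdf` (which may drop constants). It assumes scalar `mu`, `sigma`, `L`, and `U`.

```python
import math


def normal_lpdf(y, mu, sigma):
    """Log density of Normal(mu, sigma) at y, including constants."""
    z = (y - mu) / sigma
    return -0.5 * math.log(2 * math.pi) - math.log(sigma) - 0.5 * z * z


def normal_lcdf(x, mu, sigma):
    """Log of the Normal(mu, sigma) CDF at x."""
    return math.log(0.5 * math.erfc(-(x - mu) / (sigma * math.sqrt(2))))


def log_diff_exp(a, b):
    """log(exp(a) - exp(b)); requires a > b."""
    return a + math.log1p(-math.exp(b - a))


def loop_form(y, mu, sigma, L, U):
    """Mirrors the per-element loop: one bounds check and one
    normalizing term per observation."""
    target = 0.0
    for yn in y:
        if yn < L or yn > U:
            target += -math.inf
        else:
            target -= log_diff_exp(normal_lcdf(U, mu, sigma),
                                   normal_lcdf(L, mu, sigma))
        target += normal_lpdf(yn, mu, sigma)
    return target


def vectorized_form(y, mu, sigma, L, U):
    """Mirrors the vectorized translation: min/max bounds checks and a
    single normalizing term multiplied by the number of observations."""
    target = 0.0
    if min(y) < L or max(y) > U:
        target += -math.inf
    else:
        target -= len(y) * log_diff_exp(normal_lcdf(U, mu, sigma),
                                        normal_lcdf(L, mu, sigma))
    target += sum(normal_lpdf(yn, mu, sigma) for yn in y)
    return target
```

Both functions should agree up to floating-point rounding when every observation is inside the bounds, and both should return negative infinity as soon as one observation falls outside them.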

Submission Checklist

Release notes

Vectorized sampling statements can now be used with the truncation syntax T[ , ]

Copyright and Licensing

By submitting this pull request, the copyright holder is agreeing to license the submitted work under the BSD 3-clause license (https://opensource.org/licenses/BSD-3-Clause)

bob-carpenter commented 1 year ago

I agree this is going to be super useful. And thanks for catching the alternative vectorization cases, @nhuurre.

WardBrian commented 1 year ago

@nhuurre great catches!

My latest commits:

  1. Multiply the truncation term by size(y) only if none of the later arguments are vectors.
  2. Generate a for loop in the T[L, U] vector case.
  3. Move the lupdf statement above this logic. I believe this is sound (they're both just accumulation statements). My reasoning is that item 2 could generate a confusing error if size(mu) != size(sigma), whereas the normal_lupdf call necessarily checks this and produces a clear error when the sizes differ.
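To see why the multiplication by size(y) is only valid when the later arguments are scalars, here is a rough Python sketch of the arithmetic. It is an assumption-laden illustration rather than the code stanc3 emits: `truncation_term` and `as_list` are hypothetical helpers, and the single element-wise loop below covers both vector parameters and vector bounds, which the compiler may handle separately.

```python
import math


def normal_lcdf(x, mu, sigma):
    """Log of the Normal(mu, sigma) CDF at x."""
    return math.log(0.5 * math.erfc(-(x - mu) / (sigma * math.sqrt(2))))


def log_diff_exp(a, b):
    """log(exp(a) - exp(b)); requires a > b."""
    return a + math.log1p(-math.exp(b - a))


def as_list(x, n):
    """Broadcast a scalar to length n; a list must already have length n."""
    if isinstance(x, (int, float)):
        return [float(x)] * n
    if len(x) != n:
        raise ValueError(f"size mismatch: expected {n}, got {len(x)}")
    return list(x)


def truncation_term(y, mu, sigma, L, U):
    """Contribution of the truncation T[L, U] to the log density
    (the normal density itself is accumulated separately)."""
    n = len(y)
    all_scalar = all(isinstance(a, (int, float)) for a in (mu, sigma, L, U))
    if all_scalar:
        # One normalizing term shared by all elements: compute it once
        # and multiply by the number of observations.
        if min(y) < L or max(y) > U:
            return -math.inf
        return -n * log_diff_exp(normal_lcdf(U, mu, sigma),
                                 normal_lcdf(L, mu, sigma))
    # Some argument is a vector: the normalizing term differs per element,
    # so it has to be accumulated element by element.
    mus, sigmas = as_list(mu, n), as_list(sigma, n)
    los, his = as_list(L, n), as_list(U, n)
    total = 0.0
    for yn, m, s, lo, hi in zip(y, mus, sigmas, los, his):
        if yn < lo or yn > hi:
            return -math.inf
        total -= log_diff_exp(normal_lcdf(hi, m, s), normal_lcdf(lo, m, s))
    return total
```

In this sketch a size mismatch surfaces as a `ValueError` from `as_list`, which loosely plays the role that the size check inside `normal_lupdf` plays in item 3.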

WardBrian commented 1 year ago

Thanks @nhuurre. I just circled back to improve the error message shown when a vectorized lcdf is not available, and added a test for it.

I'll merge after the tests pass.