stan-dev / math

The Stan Math Library is a C++ template library for automatic differentiation of any order using forward, reverse, and mixed modes. It includes a range of built-in functions for probabilistic modeling, linear algebra, and equation solving.
https://mc-stan.org
BSD 3-Clause "New" or "Revised" License

Add special case for derivative of modified_bessel_function(0,x) to greatly improve model estimation speed #3009

Open venpopov opened 9 months ago

venpopov commented 9 months ago

Description

As described here, evaluating the derivative of modified_bessel_first_kind(0, x) can be sped up dramatically with a simple change. Issue #3008 describes how to do that for von_mises_lpdf, where the derivative is hand-coded, but that solution does not apply to custom models that call modified_bessel_first_kind(0, x) directly.

In the forward and reverse passes, the derivative of modified_bessel_first_kind(v, x) for general order v is calculated as:

$$ \frac{\partial I_v(x)}{\partial x} = I_{v-1}(x) - \frac{v}{x} I_v(x) $$

For $I_0(x)$, this results in the calculation:

$$ \frac{\partial I_0(x)}{\partial x} = I_{-1}(x) - \frac{0}{x} I_0(x) $$

Since $I_{-1}(x) = I_{1}(x)$ and the second term is 0, we have (see DLMF 10.29.3):

$$ \frac{\partial I_0(x)}{\partial x} = I_{1}(x) $$
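The identity can be checked numerically without Stan. The sketch below is only an illustration: `bessel_i_series` and `d_bessel_i0_numeric` are hypothetical helpers (not part of Stan Math) that sum the power series for integer order $n \ge 0$ and compare a central finite difference of $I_0$ against $I_1$.

```cpp
#include <cmath>

// Hypothetical helper, not Stan Math code: I_n(x) for integer n >= 0 from the
// power series I_n(x) = sum_k (x/2)^(2k+n) / (k! (n+k)!).
double bessel_i_series(int n, double x) {
  const double half_x = 0.5 * x;
  // Leading term (x/2)^n / n!
  double term = 1.0;
  for (int j = 1; j <= n; ++j) {
    term *= half_x / j;
  }
  double sum = term;
  // Each later term is the previous one times (x/2)^2 / (k (n + k)).
  for (int k = 1; k < 500; ++k) {
    term *= half_x * half_x / (static_cast<double>(k) * (n + k));
    sum += term;
    if (std::fabs(term) <= 1e-17 * std::fabs(sum)) {
      break;  // remaining terms no longer change the sum
    }
  }
  return sum;
}

// Central finite-difference estimate of dI_0/dx at x.
double d_bessel_i0_numeric(double x, double h = 1e-5) {
  return (bessel_i_series(0, x + h) - bessel_i_series(0, x - h)) / (2.0 * h);
}
```

For example, `d_bessel_i0_numeric(2.5)` and `bessel_i_series(1, 2.5)` should agree to within the finite-difference error, so no negative-order evaluation is needed for order 0.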

As described here, calculating modified_bessel_first_kind(1, x) is about 10 times faster than calculating modified_bessel_first_kind(-1, x). Thus the above formula, while applicable for any order, results in a very inefficient calculation for models that use modified_bessel_first_kind(0, x), which is the most common order (at least in my field). This is because it unnecessarily calculates $I_0(x)$, even though this term disappears, and it calculates $I_{-1}(x)$ instead of $I_1(x)$.

Example

In a model I'm currently building, which has the likelihood:

$$ f(\theta, c, k) = \exp\bigg(\frac{c\ \exp(y \cos(\theta))}{2 \pi I_0(y)}\bigg) / Z(c, y) $$

after many other optimizations, now 90% of the time is spent in calculating $I_0(y)$. E.g., using the profile function of the cmdstanr package:

functions {
  real sdm_lpdf(vector y, vector mu, vector kappa) {
    // declared outside the profile block so it stays in scope afterwards
    vector[rows(kappa)] be;
    profile("lpdf_be") {
      be = modified_bessel_first_kind(0, kappa);
    }
    // code for calculating the rest of the likelihood
  }
}

// other code

model {
  // other code
  profile("model_lpdf_total") {
    target += sdm_lpdf(Y | mu, kappa);
  }
  // other code
}

shows

                     name thread_id total_time forward_time reverse_time chain_stack no_chain_stack autodiff_calls no_autodiff_calls
1        model_lpdf_total         1   1233.100   76.9433000  1.15616e+03  1462836268     1462164925          60918                 1
2                 lpdf_be         1   1201.490   52.2222000  1.14927e+03   731053200              0          60918                 1

and the vast majority of that time is spent in the reverse autodiff pass (about 1149 of the 1201 total for lpdf_be).

Requested change

I envision two possibilities:

1) add a conditional statement to the fwd and rev passes that handles the derivative for the special case modified_bessel_first_kind(0, x)

2) replace the derivative formula with

$$ \frac{\partial I_v(x)}{\partial x} = I_{v+1}(x) + \frac{v}{x} I_v(x) $$

which is equivalent (see DLMF 10.29.2) to the current formula, but avoids the inefficient calculation for negative order. The downside of this option is twofold. First, it still calculates an extra Bessel function, even though it will be canceled by multiplication by 0 (is this correct? I'm not sure how autodiff handles such cases). Second, it will make the derivative less efficient for models that use $I_1(x)$ instead.
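To make the trade-off concrete, here is a small sketch in plain double arithmetic (not Stan's autodiff types) that evaluates both forms and counts Bessel evaluations. The names `bessel_i_series`, `derivative_lower`, `derivative_upper`, and the counter `g_bessel_calls` are hypothetical and exist only for this illustration. It assumes integer order and x != 0.

```cpp
#include <cmath>
#include <cstdlib>

// Counts calls to the helper below so the cost of each form is visible.
static int g_bessel_calls = 0;

// Hypothetical helper, not Stan Math code: I_n(x) for integer n from the
// power series, using I_{-n}(x) = I_n(x) for negative integer order.
double bessel_i_series(int n, double x) {
  ++g_bessel_calls;
  n = std::abs(n);
  const double half_x = 0.5 * x;
  double term = 1.0;
  for (int j = 1; j <= n; ++j) {
    term *= half_x / j;  // builds (x/2)^n / n!
  }
  double sum = term;
  for (int k = 1; k < 500; ++k) {
    term *= half_x * half_x / (static_cast<double>(k) * (n + k));
    sum += term;
    if (std::fabs(term) <= 1e-17 * std::fabs(sum)) {
      break;
    }
  }
  return sum;
}

// Current form: I_{v-1}(x) - (v/x) I_v(x)
double derivative_lower(int v, double x) {
  return bessel_i_series(v - 1, x) - v * bessel_i_series(v, x) / x;
}

// Option 2: I_{v+1}(x) + (v/x) I_v(x)
double derivative_upper(int v, double x) {
  return bessel_i_series(v + 1, x) + v * bessel_i_series(v, x) / x;
}
```

In this sketch both forms return the same value for v = 0, 1, 2, and both make two Bessel calls even when v is 0, because C++ evaluates the second call before multiplying by zero. Only an explicit branch on the order skips that call.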

When I rerun the model after manually patching my Stan installation with option 1, I see ~4 times faster estimation of the model (e.g. from 11h down to 3h!).

Expected Output

For option 1), a possible implementation is to change the following in rev

    bvi_->adj_
        += adj_
           * (-ad_ * modified_bessel_first_kind(ad_, bvi_->val_) / bvi_->val_
              + modified_bessel_first_kind(ad_ - 1, bvi_->val_));

to (written here as a simple if/else; I don't know whether that is the most efficient way to code the condition):

    if (ad_ == 0) {
      // dI_0/dx = I_1(x): avoids the negative-order evaluation
      bvi_->adj_ += adj_ * modified_bessel_first_kind(1, bvi_->val_);
    } else {
      bvi_->adj_
          += adj_
             * (-ad_ * modified_bessel_first_kind(ad_, bvi_->val_) / bvi_->val_
                + modified_bessel_first_kind(ad_ - 1, bvi_->val_));
    }

And make the corresponding change in fwd:

    if (v == 0) {
      // keep the value I_0(z) and use I_1(z) for the tangent
      return fvar<T>(modified_bessel_first_kind_z,
                     z.d_ * modified_bessel_first_kind(1, z.val_));
    }
    return fvar<T>(modified_bessel_first_kind_z,
                   -v * z.d_ * modified_bessel_first_kind_z / z.val_
                       + z.d_ * modified_bessel_first_kind(v - 1, z.val_));
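For anyone who wants to try the branch outside of Stan, the following is a minimal standalone sketch of the forward-mode idea. `Dual` is a hypothetical stand-in for `fvar<T>`, and `bessel_i_series` and `modified_bessel_i_dual` are hypothetical helpers, so this is not the library implementation. It assumes integer order, and z.val != 0 when v != 0.

```cpp
#include <cmath>
#include <cstdlib>

// Minimal stand-in for a forward-mode number; NOT Stan's fvar<T>.
struct Dual {
  double val;  // function value
  double d;    // tangent (derivative with respect to the input)
};

// Hypothetical helper, not Stan Math code: I_n(x) for integer n from the
// power series, using I_{-n}(x) = I_n(x) for negative integer order.
double bessel_i_series(int n, double x) {
  n = std::abs(n);
  const double half_x = 0.5 * x;
  double term = 1.0;
  for (int j = 1; j <= n; ++j) {
    term *= half_x / j;  // builds (x/2)^n / n!
  }
  double sum = term;
  for (int k = 1; k < 500; ++k) {
    term *= half_x * half_x / (static_cast<double>(k) * (n + k));
    sum += term;
    if (std::fabs(term) <= 1e-17 * std::fabs(sum)) {
      break;
    }
  }
  return sum;
}

// Forward-mode I_v(z) with the order-0 special case from option 1.
Dual modified_bessel_i_dual(int v, const Dual& z) {
  const double value = bessel_i_series(v, z.val);
  if (v == 0) {
    // dI_0/dx = I_1(x): one extra evaluation, no negative order
    return {value, z.d * bessel_i_series(1, z.val)};
  }
  // General case: I_{v-1}(x) - (v/x) I_v(x), reusing the value computed above
  return {value,
          -v * z.d * value / z.val + z.d * bessel_i_series(v - 1, z.val)};
}
```

Calling `modified_bessel_i_dual(0, Dual{1.0, 1.0})` should return $I_0(1)$ as the value and $I_1(1)$ as the tangent. Note that the value has to be passed through in both branches; returning only the tangent in the order-0 branch would lose $I_0(z)$ itself.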

Current Version:

v4.8.0

venpopov commented 9 months ago

@andrjohns I can try to implement this after our discussion in the other issue. Do you think that the conditional approach checking if the order of the bessel function is 0 is appropriate?