avehtari / casestudies


new diagSPD_periodic != old diagSPD_periodic #3

Closed: mike-lawrence closed this issue 1 year ago.

mike-lawrence commented 1 year ago

The older diagSPD_periodic is expressed as:

```stan
  vector diagSPD_periodic(real gpscale, real lscale, int M) {
    real a = 1/lscale^2;
    int one_to_M[M];
    for (m in 1:M) one_to_M[m] = m;
    vector[M] q = sqrt(gpscale^2 * 2 / exp(a) * to_vector(modified_bessel_first_kind(one_to_M, a)));
    return append_row(q,q);
  }
```
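Written as a formula (a transcription of the code above, with $\alpha$ = gpscale, $\ell$ = lscale, and $I_m$ the modified Bessel function of the first kind computed by modified_bessel_first_kind), the function returns the length-$2M$ vector $(q_1, \dots, q_M, q_1, \dots, q_M)$ with:

$$ a = \frac{1}{\ell^2}, \qquad q_m = \sqrt{\alpha^2 \times \frac{2}{\exp(a)} \times I_m(a)}, \qquad m = 1, \dots, M. $$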

which, with one_to_M constructed in place via linspaced_int_array, becomes:

```stan
  vector diagSPD_periodic(real gpscale, real lscale, int M) {
    real a = 1/lscale^2;
    vector[M] q = sqrt(gpscale^2 * 2 / exp(a) * to_vector(modified_bessel_first_kind(linspaced_int_array(M,1,M), a)));
    return append_row(q,q);
  }
```

And the new expression is:

```stan
vector diagSPD_periodic(real alpha, real rho, int M) {
  real a = 1/rho^2;
  vector[M] q = exp(log(alpha) + 0.5 * (log(2) - a + to_vector(log_modified_bessel_first_kind(linspaced_int_array(M, 1, M), a))));
  return append_row(q,q);
}
```

Renaming the arguments to common names (alpha, rho) and assigning to_vector(modified_bessel_first_kind(one_to_M, a)) to a variable c, the core difference lies in q:

```stan
// old:
vector[M] q = sqrt(alpha^2 * 2 / exp(a) * c) ;
// new:
vector[M] q = exp(log(alpha) + 0.5 * (log(2) - a + log(c))) ; //n.b. log(c) = to_vector(log_modified_bessel_first_kind(...
```
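One way to check the two forms of q numerically is a quick standalone script. The sketch below is in Python rather than Stan so it runs on its own, and it rests on an assumption: log_bessel_i is a hypothetical helper that evaluates log I_m(a) from the power series with a log-sum-exp, standing in for Stan's log_modified_bessel_first_kind (and its exponential stands in for modified_bessel_first_kind). The values of alpha, rho, and M are arbitrary.

```python
import math


def log_bessel_i(m, a):
    """Hypothetical stand-in for Stan's log_modified_bessel_first_kind(m, a).

    Sums the power series I_m(a) = sum_k (a/2)^(2k+m) / (k! (k+m)!)
    on the log scale (log-sum-exp), for integer m >= 0 and a > 0.
    """
    n_terms = int(2 * a) + 50  # by k ~ 2a the terms have decayed to nothing
    log_half_a = math.log(a / 2)
    log_terms = [
        (2 * k + m) * log_half_a - math.lgamma(k + 1) - math.lgamma(k + m + 1)
        for k in range(n_terms)
    ]
    peak = max(log_terms)
    return peak + math.log(sum(math.exp(t - peak) for t in log_terms))


def q_old(alpha, rho, M):
    """Old form: sqrt(alpha^2 * 2 / exp(a) * c), with c_m = I_m(a)."""
    a = 1 / rho**2
    c = [math.exp(log_bessel_i(m, a)) for m in range(1, M + 1)]
    return [math.sqrt(alpha**2 * 2 / math.exp(a) * c_m) for c_m in c]


def q_new(alpha, rho, M):
    """New form: exp(log(alpha) + 0.5 * (log(2) - a + log(c)))."""
    a = 1 / rho**2
    log_c = [log_bessel_i(m, a) for m in range(1, M + 1)]
    return [math.exp(math.log(alpha) + 0.5 * (math.log(2) - a + lc)) for lc in log_c]


if __name__ == "__main__":
    old = q_old(alpha=1.3, rho=0.7, M=5)
    new = q_new(alpha=1.3, rho=0.7, M=5)
    for m, (o, n) in enumerate(zip(old, new), start=1):
        print(m, o, n, abs(o - n))  # differences are at rounding level
```

With positive alpha and c, the two columns should agree to within floating-point rounding.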

However, so far as my admittedly amateur mathematical expertise can discern, these are not equivalent expressions.

I asked ChatGPT-4 to verify my analysis:

Simplifying the First Expression

The first expression is:

$$ \exp(\log(b) + 0.5 \times (\log(2) - a + c)) $$

Step 1: Distribute the $0.5$ inside the parentheses.

$$ \exp(\log(b) + 0.5 \log(2) - 0.5a + 0.5c) $$

Step 2: Combine the exponential and logarithmic functions.

$$ \exp(\log(b)) \times \exp(0.5 \log(2)) \times \exp(-0.5a) \times \exp(0.5c) $$

Step 3: Further simplify.

$$ b \times \sqrt{2} \times \frac{1}{\sqrt{\exp(a)}} \times \sqrt{\exp(c)} $$

Simplifying the Second Expression

The second expression is:

$$ \sqrt{b^2 \times \frac{2}{\exp(a)} \times \log(c)} $$

Step 1: Take the square root of each term inside the square root.

$$ \sqrt{b^2} \times \sqrt{\frac{2}{\exp(a)}} \times \sqrt{\log(c)} $$

Step 2: Further simplify.

$$ b \times \sqrt{\frac{2}{\exp(a)}} \times \sqrt{\log(c)} $$

Step 3: Write $\sqrt{\frac{2}{\exp(a)}}$ as $\frac{1}{\sqrt{\exp(a)}} \times \sqrt{2}$.

$$ b \times \frac{1}{\sqrt{\exp(a)}} \times \sqrt{2} \times \sqrt{\log(c)} $$

Comparison

After simplification, the first expression becomes:

$$ b \times \sqrt{2} \times \frac{1}{\sqrt{\exp(a)}} \times \sqrt{\exp(c)} $$

And the second expression becomes:

$$ b \times \frac{1}{\sqrt{\exp(a)}} \times \sqrt{2} \times \sqrt{\log(c)} $$

The two expressions are not equivalent because the terms $\sqrt{\exp(c)}$ and $\sqrt{\log(c)}$ are not the same for all values of $c$.

mike-lawrence commented 1 year ago

I think the proper translation of the old form into a numerically stable one is:

```stan
vector[M] q = exp( 0.5 * (2*log(alpha) + log(2) + log(c) - a) ) ;
```

which in full form is:

```stan
vector diagSPD_periodic(real alpha, real rho, int M) {
    real a = 1/square(rho);
    vector[M] q = exp(
        0.5 * (
            2*log(alpha)
            + log(2)
            + to_vector(log_modified_bessel_first_kind(linspaced_int_array(M, 1, M), a))
            - a
        )
    ) ;
    return append_row(q,q);
}
```

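Why the log scale matters shows up at short length scales. The Python sketch below uses rho = 0.03, so a = 1/rho^2 is about 1111, and exp(a) on its own is far beyond the largest double (about e^709.8). The old form, which needs exp(a) and I_m(a) on the natural scale, cannot be evaluated directly. The log-scale form only exponentiates a moderate number. Assumptions: log_bessel_i is a hypothetical series-based stand-in for Stan's log_modified_bessel_first_kind, and Python's math.exp raises OverflowError where a C++ double computation would overflow to infinity instead.

```python
import math


def log_bessel_i(m, a):
    """Hypothetical stand-in for log_modified_bessel_first_kind(m, a):
    log-sum-exp over the power series of I_m(a), integer m >= 0, a > 0."""
    log_half_a = math.log(a / 2)
    log_terms = [
        (2 * k + m) * log_half_a - math.lgamma(k + 1) - math.lgamma(k + m + 1)
        for k in range(int(2 * a) + 50)
    ]
    peak = max(log_terms)
    return peak + math.log(sum(math.exp(t - peak) for t in log_terms))


def diag_spd_periodic_log(alpha, rho, M):
    """Log-scale form from this thread:
    q = exp(0.5 * (2*log(alpha) + log(2) + log(c) - a)), returned as append_row(q, q)."""
    a = 1 / rho**2
    q = [
        math.exp(0.5 * (2 * math.log(alpha) + math.log(2) + log_bessel_i(m, a) - a))
        for m in range(1, M + 1)
    ]
    return q + q  # mirrors append_row(q, q)


alpha, rho, M = 1.0, 0.03, 4
a = 1 / rho**2

try:
    math.exp(a)  # the old form needs exp(a) on the natural scale
    print("exp(a) is representable")
except OverflowError:
    print(f"exp({a:.1f}) overflows a double")

print(diag_spd_periodic_log(alpha, rho, M))  # finite, positive values
```

The subtraction log(c) - a happens before exponentiation, so the two huge quantities (I_m(a) and exp(a)) cancel on the log scale instead of overflowing separately.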
nsiccha commented 1 year ago

Interesting. Wolfram Alpha says the expressions are identical (under our assumptions). What did you ask ChatGPT?

mike-lawrence commented 1 year ago

Ah, apologies. Under the assumption that a, alpha, and c are positive (which holds as they are used here), the two expressions do indeed seem to be equivalent:

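```r
# draw random positive values for a, alpha, and c
a = exp(rnorm(1))
alpha = exp(rnorm(1))
c = exp(rnorm(1))

x = exp(log(alpha) + 0.5 * (log(2) - a + log(c)))  # new (log-scale) form
y = sqrt(alpha^2 * 2 / exp(a) * c)                  # old form
x - y  # ~0, up to floating-point rounding
```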

Sorry for the distraction!

mike-lawrence commented 1 year ago

Oh! I see now why ChatGPT gave me the wrong answer: I gave it the wrong input 🤦

I mixed up which had log_bessel...() and which had bessel...(), and therefore which had log(c) at the end and which had c at the end.
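For reference, here is the simplification with the inputs the right way round, i.e. c (not log(c)) inside the old square root and log(c) in the new exponent. It follows the same steps as the ChatGPT derivation and assumes $\alpha > 0$ and $c > 0$ (the last step uses $\sqrt{\alpha^2} = \alpha$):

$$
\begin{aligned}
\exp\left(\log(\alpha) + 0.5 \times (\log(2) - a + \log(c))\right)
&= \exp(\log(\alpha)) \times \exp(0.5 \log(2)) \times \exp(-0.5a) \times \exp(0.5 \log(c)) \\
&= \alpha \times \sqrt{2} \times \frac{1}{\sqrt{\exp(a)}} \times \sqrt{c} \\
&= \sqrt{\alpha^2 \times \frac{2}{\exp(a)} \times c}
\end{aligned}
$$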