pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

Use Cholesky decomp instead of inverting kernel #1688

Closed DanWaxman closed 10 months ago

DanWaxman commented 10 months ago

For the GP example, the training covariance matrix is currently inverted; it's faster and more numerically stable to use the Cholesky decomposition instead. It's a rather small change, but people tend to copy-paste examples, so I think a proper implementation is important here.

In the example, I've checked that the results are numerically similar (slightly different, I think in part due to sampling):

vmap_args = (
    random.split(rng_key_predict, samples["kernel_var"].shape[0]),
    samples["kernel_var"],
    samples["kernel_length"],
    samples["kernel_noise"],
)
means, predictions, K = vmap(
    lambda rng_key, var, length, noise: predict(
        rng_key, X, Y, X_test, var, length, noise
    )
)(*vmap_args)

means_2, predictions_2, K_2 = vmap(
    lambda rng_key, var, length, noise: predict_2(
        rng_key, X, Y, X_test, var, length, noise
    )
)(*vmap_args)

mean_prediction = np.mean(means, axis=0)
mean_prediction_2 = np.mean(means_2, axis=0)
percentiles = np.percentile(predictions, [5.0, 95.0], axis=0)
percentiles_2 = np.percentile(predictions_2, [5.0, 95.0], axis=0)
print(
    "means np.allclose?",
    np.allclose(mean_prediction, mean_prediction_2),
    " | Max difference:",
    np.max(np.abs(mean_prediction - mean_prediction_2)),
)
print("Max difference in percentiles:", np.max(np.abs(percentiles - percentiles_2)))

gives

>>> means close? False  | Max difference: 4.529953e-06
>>> Max difference in percentiles: 0.02630336433649072

fehiepsi commented 10 months ago

Thanks for the contribution, @DanWaxman! Yes, in practice, it is better to use Cholesky. Could you add an arg use_cholesky=False/True to reflect the two implementations? (The Cholesky implementation rarely appears in GP textbooks.)
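
For readers who have only seen the textbook form, here is a minimal sketch of the two ways to compute the GP posterior mean that this thread contrasts. It is not the code from the PR: the helper names (rbf_kernel, posterior_mean_inverse, posterior_mean_cholesky), the toy data, and the use of plain NumPy instead of JAX are all assumptions made for illustration.

```python
import numpy as np


def rbf_kernel(xa, xb, var=1.0, length=1.0):
    """Squared-exponential kernel between two 1-D input arrays (illustrative)."""
    sq_dist = (xa[:, None] - xb[None, :]) ** 2
    return var * np.exp(-0.5 * sq_dist / length**2)


def posterior_mean_inverse(X, Y, X_test, noise=0.1):
    """Textbook form: explicitly invert the training covariance matrix."""
    K = rbf_kernel(X, X) + noise * np.eye(X.shape[0])
    K_star = rbf_kernel(X_test, X)
    return K_star @ np.linalg.inv(K) @ Y


def posterior_mean_cholesky(X, Y, X_test, noise=0.1):
    """Factor K = L L^T once, then solve two triangular systems instead of inverting."""
    K = rbf_kernel(X, X) + noise * np.eye(X.shape[0])
    K_star = rbf_kernel(X_test, X)
    L = np.linalg.cholesky(K)  # lower-triangular factor
    # alpha = K^{-1} Y, obtained from L z = Y followed by L^T alpha = z.
    # np.linalg.solve is a general solver, used here only to keep the sketch
    # NumPy-only; a dedicated triangular solver would exploit the structure of L.
    z = np.linalg.solve(L, Y)
    alpha = np.linalg.solve(L.T, z)
    return K_star @ alpha


# Toy data (hypothetical, not the data from the GP example).
rng = np.random.default_rng(0)
X = np.linspace(-1.0, 1.0, 20)
Y = np.sin(3.0 * X) + 0.1 * rng.standard_normal(20)
X_test = np.linspace(-1.2, 1.2, 50)

mean_inv = posterior_mean_inverse(X, Y, X_test)
mean_chol = posterior_mean_cholesky(X, Y, X_test)
print("Max difference:", np.max(np.abs(mean_inv - mean_chol)))
```

On a small, well-conditioned problem like this one the two results should agree closely. The case for Cholesky is about cost and stability as the kernel matrix becomes larger or closer to singular.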

DanWaxman commented 10 months ago

Done, thanks @fehiepsi! I made use_cholesky=True the default and exposed it through a --no_cholesky command-line option, because of the well-known issue with parsing bools in argparse.
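
The argparse issue mentioned above can be sketched as follows. This is an illustration only, not the code merged in the PR: the --naive_flag option is hypothetical and exists just to show the pitfall, and routing --no_cholesky to the use_cholesky attribute through dest= is an assumption about how such a flag might be wired.

```python
import argparse


def build_parser():
    parser = argparse.ArgumentParser(description="Illustrative GP example flags")
    # Pitfall: type=bool calls bool() on the raw string, and any non-empty
    # string (including "False") is truthy.
    parser.add_argument("--naive_flag", type=bool, default=True)
    # store_false defaults to True and flips to False when the flag is given.
    parser.add_argument("--no_cholesky", dest="use_cholesky", action="store_false")
    return parser


parser = build_parser()
default_args = parser.parse_args([])
disabled_args = parser.parse_args(["--no_cholesky"])
naive_args = parser.parse_args(["--naive_flag", "False"])

print("default use_cholesky:", default_args.use_cholesky)
print("with --no_cholesky:", disabled_args.use_cholesky)
print("--naive_flag False parses as:", naive_args.naive_flag)
```

A negative flag with action="store_false" keeps the recommended behaviour as the default while still letting users opt out from the command line.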