tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow
https://www.tensorflow.org/probability/
Apache License 2.0

Implementation of Cholesky rank-1 update is orders of magnitude slower than the naive approach? #1710

Open ltiao opened 1 year ago

ltiao commented 1 year ago

A common problem is to compute the Cholesky factor of A + u @ u.T, given a positive-definite (PD) matrix A (shape n x n) and a rank-1 update vector u (shape n). The obvious, naive way is to compute the Cholesky factor of A + u @ u.T directly, which costs O(n^3). However, if we already have the Cholesky factor L (shape n x n) of A, we can use it to compute the Cholesky factor of A + u @ u.T in O(n^2) time.
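For reference, the O(n^2) approach can be sketched with the textbook rotation-based rank-1 update. This is a minimal NumPy sketch, not the code behind tfp.math.cholesky_update; the function name cholesky_rank1_update is hypothetical, and it assumes L is lower triangular with a positive diagonal.

```python
import numpy as np

def cholesky_rank1_update(L, x):
    """Return the lower Cholesky factor of L @ L.T + outer(x, x) in O(n^2) time.

    Hypothetical helper: textbook rotation-based rank-1 update. Assumes L is
    lower triangular with a positive diagonal. The inputs are not modified.
    """
    L = np.array(L, dtype=float, copy=True)
    x = np.array(x, dtype=float, copy=True)
    n = x.shape[0]
    for k in range(n):
        # New diagonal entry: sqrt(L[k, k]^2 + x[k]^2), computed stably.
        r = np.hypot(L[k, k], x[k])
        c = r / L[k, k]
        s = x[k] / L[k, k]
        L[k, k] = r
        # Update the rest of column k, then fold it back into the tail of x.
        # Step k+1 needs the x produced here, so the loop is inherently sequential.
        L[k + 1:, k] = (L[k + 1:, k] + s * x[k + 1:]) / c
        x[k + 1:] = c * x[k + 1:] - s * L[k + 1:, k]
    return L

# Usage: compare against refactoring A + u u^T from scratch (the naive O(n^3) route).
rng = np.random.default_rng(0)
n = 5
F = rng.standard_normal((n, n))
A = F @ F.T + n * np.eye(n)  # well-conditioned PD matrix
u = rng.standard_normal(n)
L = np.linalg.cholesky(A)
L_new = cholesky_rank1_update(L, u)
print(np.allclose(L_new, np.linalg.cholesky(A + np.outer(u, u))))  # True
```

Each pass touches only column k and the tail of x, so the total work is O(n^2), but the n column steps must run one after another.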

My understanding is that this is what tfp.math.cholesky_update is supposed to implement. However, a simple benchmark (a batch of t = 3 matrices with n = 1024) shows that the supposedly optimized approach is about 85 times slower than the obvious naive approach (1.38 s vs. 16.3 ms per loop)!

The optimized approach using tfp.math.cholesky_update:

$ python -m timeit -s "import numpy as np; import tensorflow as tf; import tensorflow_probability as tfp; rng = np.random.RandomState(42); t = 3; n = 1024; factor = rng.randn(t, n, n); u = rng.randn(t, n); a = tf.linalg.matmul(factor, factor, transpose_b=True); a_scale = tf.linalg.cholesky(a)" -n 3 "tfp.math.cholesky_update(a_scale, u)"
3 loops, best of 5: 1.38 sec per loop

The obvious naive approach:

$ python -m timeit -s "import numpy as np; import tensorflow as tf; import tensorflow_probability as tfp; rng = np.random.RandomState(42); t = 3; n = 1024; factor = rng.randn(t, n, n); u = rng.randn(t, n); a = tf.linalg.matmul(factor, factor, transpose_b=True); a_scale = tf.linalg.cholesky(a)" -n 3 "b = a + tf.linalg.matmul(u[..., tf.newaxis], u[..., tf.newaxis], transpose_b=True); tf.linalg.cholesky(b)"
3 loops, best of 5: 16.3 msec per loop
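The naive benchmark relies on batch semantics: u[..., tf.newaxis] has shape (t, n, 1), so the matmul with transpose_b=True produces t outer products of shape (n, n), and tf.linalg.cholesky then factors each of the t matrices. A NumPy sketch of the same baseline, with a hypothetical function name and a smaller n only to keep it quick, could look like this:

```python
import numpy as np

def naive_cholesky_of_update(a, u):
    """Hypothetical O(n^3) baseline: form a + u u^T per batch member, then refactor."""
    # u[..., :, None] @ u[..., None, :] is a batch of (n, n) outer products, the
    # NumPy counterpart of matmul(u[..., newaxis], u[..., newaxis], transpose_b=True).
    b = a + u[..., :, None] @ u[..., None, :]
    # np.linalg.cholesky factors every (n, n) matrix in the leading batch dimensions.
    return np.linalg.cholesky(b)

rng = np.random.default_rng(42)
t, n = 3, 64  # the benchmark above uses n = 1024
factor = rng.standard_normal((t, n, n))
u = rng.standard_normal((t, n))
a = factor @ np.swapaxes(factor, -1, -2)  # batch of PD matrices, like a in the benchmark
chol = naive_cholesky_of_update(a, u)
print(chol.shape)  # (3, 64, 64)
```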
csuter commented 1 year ago

Wrapping the call in tf.function with jit_compile=True (i.e., compiling it with XLA) speeds up cholesky_update at this problem size from the 1.38 s you saw to around 40-60 ms on my machine (nothing fancy).

When I increase the problem size from n = 1024 to n = 2048, the XLA-compiled cholesky_update takes about 240 ms (4-6x longer, which is about right for quadratic scaling, since doubling n should roughly quadruple the work). The XLA-compiled naive approach at this size takes up to about 1.8 s for me.

Main takeaway: always jit compile (or, at least, always try with and without and do whatever is fastest!), and also big O hides constants that matter! Depending on your problem size, the naive method may be better -- but asymptotically, quadratic scaling will beat cubic :)
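To put rough numbers on the "constants" point: a common estimate is about n^3/3 flops for a dense Cholesky factorization, versus roughly 3n^2 for the rotation-based update (assuming about 6 flops per below-diagonal entry). These are idealized counts under that assumption, not measurements:

```python
def cholesky_flops(n):
    # Standard estimate for a dense Cholesky factorization: ~n^3 / 3 flops.
    return n ** 3 / 3

def rank1_update_flops(n):
    # Rotation-based update (assumed model): ~6 flops per below-diagonal
    # entry, ~n^2 / 2 such entries in total.
    return 3 * n ** 2

for n in (1024, 2048):
    ratio = cholesky_flops(n) / rank1_update_flops(n)  # equals n / 9
    print(f"n={n}: naive ~{cholesky_flops(n):.2e} flops, "
          f"update ~{rank1_update_flops(n):.2e} flops, ratio ~{ratio:.0f}x")
```

By this estimate the update needs roughly 100x fewer flops at n = 1024, yet it was about 85x slower in the eager benchmark, so raw arithmetic is not what dominated there. A column-by-column update has n dependent steps, so any fixed per-step cost (for example, dispatching small ops one at a time) is paid about n times; compiling with XLA is one way to cut that kind of overhead. Doubling n multiplies the naive count by 8 and the update count by 4, which is where the asymptotic advantage eventually wins.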