awni opened this issue 4 months ago
I've been using the CPU Cholesky recently, and unfortunately it's quite slow for large matrices.
A hybrid CPU/GPU Cholesky built from MLX ops is about 2.5-3.5x faster than the pure CPU version for N >= 4096:
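```python
import mlx.core as mx


def cholesky(A: mx.array, block_size: int = 512):
    N = A.shape[-1]
    L = mx.zeros_like(A)
    A = mx.array(A)  # copy so the in-place updates below don't modify the caller's array
    # For numerical stability: scale each matrix by its largest absolute entry.
    # Reduce over the last two axes so the scaled matrix stays symmetric.
    amax = A.abs().max(axis=(-2, -1), keepdims=True)
    A /= amax
    for k in range(0, N, block_size):
        end = min(k + block_size, N)
        # Unblocked Cholesky of the diagonal block, on the CPU
        L[..., k:end, k:end] = mx.linalg.cholesky(A[..., k:end, k:end], stream=mx.cpu)
        if end < N:
            # Inverse of the transposed diagonal factor, (L11^T)^{-1}, also on the CPU
            L_inv = mx.linalg.tri_inv(mx.swapaxes(L[..., k:end, k:end], -1, -2), upper=True, stream=mx.cpu)
            # Panel solve and trailing (Schur complement) update as matmuls on the default device
            L[..., end:N, k:end] = A[..., end:N, k:end] @ L_inv
            A[..., end:N, end:N] -= mx.matmul(L[..., end:N, k:end], mx.swapaxes(L[..., end:N, k:end], -1, -2))
    # Undo the scaling: if A / amax = L' L'^T, then A = (sqrt(amax) L')(sqrt(amax) L')^T
    L *= mx.sqrt(amax)
    return L
```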

| N | CPU only | Hybrid CPU/GPU | Speedup |
|---:|---:|---:|---:|
| 2048 | 8 ms | 10 ms | 0.8x |
| 4096 | 81 ms | 28 ms | 2.9x |
| 8192 | 693 ms | 197 ms | 3.5x |
| 16384 | 6010 ms | 2411 ms | 2.5x |
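The three steps inside the loop follow from partitioning the (scaled) matrix into the current diagonal block, the panel below it, and the trailing block. As a sketch of the reasoning, with $L_{11}$ the Cholesky factor of the diagonal block:

$$
\begin{pmatrix} A_{11} & A_{21}^{T} \\ A_{21} & A_{22} \end{pmatrix}
=
\begin{pmatrix} L_{11} & 0 \\ L_{21} & L_{22} \end{pmatrix}
\begin{pmatrix} L_{11}^{T} & L_{21}^{T} \\ 0 & L_{22}^{T} \end{pmatrix}
\quad\Longrightarrow\quad
\begin{aligned}
A_{11} &= L_{11} L_{11}^{T}, \\
L_{21} &= A_{21} \left(L_{11}^{T}\right)^{-1}, \\
L_{22} L_{22}^{T} &= A_{22} - L_{21} L_{21}^{T}.
\end{aligned}
$$

The first line is the `mx.linalg.cholesky` call on the diagonal block, the second is the `tri_inv` of the transposed block followed by a matmul, and the third is the in-place trailing update, after which the loop repeats on $A_{22} - L_{21} L_{21}^{T}$. The scaling step is valid because, for a scalar $s > 0$,

$$
\frac{A}{s} = L' L'^{T} \quad\Longrightarrow\quad A = \left(\sqrt{s}\,L'\right)\left(\sqrt{s}\,L'\right)^{T},
$$

so multiplying the factor of $A/s$ by $\sqrt{s}$ recovers $L$. This relies on $s$ being a single scalar per matrix; dividing each row by a different scale would make $A/s$ non-symmetric.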
It seems to be similarly numerically accurate, if anything slightly better than the CPU version, when you check $A = LL^{T}$ in float64.
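The $A = LL^{T}$ check mentioned above can be sketched in NumPy, together with the same right-looking blocked scheme run entirely on the CPU in float64, for anyone without Apple silicon who wants to experiment with the algorithm. This is a minimal sketch under a few assumptions: the names `blocked_cholesky` and `relative_residual` are hypothetical, the residual is measured as $\lVert A - LL^{T}\rVert_F / \lVert A\rVert_F$, and because NumPy has no triangular inverse, the panel step uses `np.linalg.solve` (a general LU solve) in place of `mx.linalg.tri_inv`. It illustrates the algorithm only, not MLX's numerics or speed.

```python
import numpy as np


def blocked_cholesky(A: np.ndarray, block_size: int = 512) -> np.ndarray:
    """Right-looking blocked Cholesky of a (..., N, N) SPD array, in float64."""
    A = np.array(A, dtype=np.float64)  # float64 working copy; the input is untouched
    N = A.shape[-1]
    L = np.zeros_like(A)
    # Scale each matrix by its largest absolute entry so A / s stays symmetric
    s = np.abs(A).max(axis=(-2, -1), keepdims=True)
    A /= s
    for k in range(0, N, block_size):
        end = min(k + block_size, N)
        # 1. Diagonal block: L11 L11^T = A11
        L11 = np.linalg.cholesky(A[..., k:end, k:end])
        L[..., k:end, k:end] = L11
        if end < N:
            # 2. Panel: L21 = A21 L11^{-T}, computed as (L11^{-1} A21^T)^T
            #    (general solve here; a triangular solver would be the natural choice)
            L21 = np.swapaxes(
                np.linalg.solve(L11, np.swapaxes(A[..., end:, k:end], -1, -2)), -1, -2
            )
            L[..., end:, k:end] = L21
            # 3. Trailing update (Schur complement): A22 <- A22 - L21 L21^T
            A[..., end:, end:] -= L21 @ np.swapaxes(L21, -1, -2)
    # Undo the scaling: A = (sqrt(s) L')(sqrt(s) L')^T
    return L * np.sqrt(s)


def relative_residual(A: np.ndarray, L: np.ndarray) -> np.ndarray:
    """Frobenius-norm ||A - L L^T|| / ||A|| for each matrix, evaluated in float64."""
    A = np.asarray(A, dtype=np.float64)
    L = np.asarray(L, dtype=np.float64)
    R = A - L @ np.swapaxes(L, -1, -2)
    return np.linalg.norm(R, axis=(-2, -1)) / np.linalg.norm(A, axis=(-2, -1))


rng = np.random.default_rng(0)
X = rng.standard_normal((300, 300))
A = X @ X.T + 300 * np.eye(300)  # well-conditioned SPD test matrix
L = blocked_cholesky(A, block_size=64)
print(relative_residual(A, L))
```

Computing `relative_residual(A, np.linalg.cholesky(A))` on the same matrix gives the CPU baseline to compare against, which is the kind of side-by-side check described above.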
We don't really have a great pattern for ops that run part of their computation on the CPU and part on the GPU, but maybe it's worth merging anyway?
It could be quite hard to write a faster GPU-only kernel for a single matrix, since the unblocked Cholesky that runs on the CPU above can't easily be parallelized. The batched version could probably be quite a bit faster, though.
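To make the parallelism point concrete, here is a minimal NumPy sketch of an unblocked, column-oriented (Cholesky-Crout) factorization. It is a textbook variant chosen for illustration, not necessarily the exact loop LAPACK or MLX's CPU backend runs, and the name `unblocked_cholesky` is hypothetical. Column `j` reads every column before it, so for a single matrix the loop over `j` is inherently sequential and each step is only a small amount of vector work. For a batch, every matrix performs the same step at the same time, which is the extra parallelism a batched kernel could exploit.

```python
import numpy as np


def unblocked_cholesky(A: np.ndarray) -> np.ndarray:
    """Column-by-column (Cholesky-Crout) factorization of a (..., N, N) SPD array."""
    A = np.asarray(A, dtype=np.float64)
    N = A.shape[-1]
    L = np.zeros_like(A)
    for j in range(N):  # sequential: column j needs columns 0..j-1 to be finished
        # Diagonal entry: L[j, j] = sqrt(A[j, j] - sum_{k<j} L[j, k]^2)
        d = A[..., j, j] - np.sum(L[..., j, :j] ** 2, axis=-1)
        L[..., j, j] = np.sqrt(d)
        # Below the diagonal: L[i, j] = (A[i, j] - sum_{k<j} L[i, k] L[j, k]) / L[j, j]
        dots = np.sum(L[..., j + 1:, :j] * L[..., j, None, :j], axis=-1)
        L[..., j + 1:, j] = (A[..., j + 1:, j] - dots) / L[..., j, j][..., None]
    return L


rng = np.random.default_rng(1)
X = rng.standard_normal((4, 50, 50))
A = X @ np.swapaxes(X, -1, -2) + 50 * np.eye(50)  # batch of 4 SPD matrices
# One sequential loop over columns, vectorized across the whole batch
L = unblocked_cholesky(A)
```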
That's pretty awesome that it's faster. A rare example of mixing CPU / GPU speeding things up!
I'm not sure what to do with it. On the one hand, it's a lot faster, which is nice. On the other hand, implementing this at the op level will kind of break a couple of patterns:
I think as a temporary speedup it's fine to add, and we probably should.
But it would be useful to know what a good long-term plan is for Cholesky and friends (heavy ops that are hard to parallelize in just one or two kernels). Is it feasible that we eventually replace it with our own fast kernel(s)?
The alternative is to rethink which of those patterns are worth being consistent about and which are not, and maybe come up with a consistent way of working around the latter.
I think it's feasible to write a GPU-only Cholesky that comes at least close to the performance of the above, so maybe we don't need to change the pattern.
Given that we'll likely want a batched version of all of these harder-to-parallelize ops, it is tempting to just keep things consistent and maybe sacrifice a little bit of performance.
Add Cholesky for the CPU using LAPACK.
For the GPU, MPS has a Cholesky (`MPSMatrixDecompositionCholesky`), which could be a good option to start with (following how we used to bind MPS matmul):