ml-explore / mlx

MLX: An array framework for Apple silicon
https://ml-explore.github.io/mlx/
MIT License

[Feature] Cholesky decomposition #1026

Open awni opened 4 months ago

awni commented 4 months ago

Add it for the CPU using LAPACK.

For the GPU, MPS has a Cholesky implementation which could be a good option to start with (following how we used to bind the MPS matmul).
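On the CPU path, that means binding LAPACK's Cholesky routine (`?potrf`). As a rough sketch of that routine's semantics, here it is called through SciPy's wrapper of the same LAPACK function; this is purely illustrative, not MLX code:

import numpy as np
from scipy.linalg.lapack import dpotrf

# Build a small symmetric positive definite matrix
A = np.random.randn(4, 4)
A = A @ A.T + 4 * np.eye(4)

# dpotrf factors A = L @ L.T; a nonzero info signals failure,
# and clean=1 zeroes out the unused upper triangle
L, info = dpotrf(A, lower=1, clean=1)
assert info == 0
assert np.allclose(L @ L.T, A)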

barronalex commented 1 month ago

I've been using the CPU Cholesky recently and unfortunately it's quite slow for large matrices.

A hybrid CPU/GPU Cholesky with MLX ops is about 2-3x faster than the pure CPU version for N >= 4096:

import mlx.core as mx

def cholesky(A: mx.array, block_size: int = 512):
    N = A.shape[-1]
    L = mx.zeros_like(A)
    A = mx.array(A)  # copy so the in-place updates below don't touch the caller's array

    # Scale each matrix by its largest absolute entry for numerical stability.
    # Reduce over the last two axes so every matrix (batched or not) gets a
    # single scalar scale, which keeps the input symmetric.
    amax = A.abs().max(axis=list(range(A.ndim - 2, A.ndim)), keepdims=True)
    A /= amax

    for k in range(0, N, block_size):
        end = min(k + block_size, N)

        # Factor the diagonal block A11 = L11 @ L11^T with the unblocked CPU Cholesky
        L[..., k:end, k:end] = mx.linalg.cholesky(A[..., k:end, k:end], stream=mx.cpu)

        if end < N:
            # Invert the transposed diagonal block on the CPU...
            L_inv = mx.linalg.tri_inv(mx.swapaxes(L[..., k:end, k:end], -1, -2), upper=True, stream=mx.cpu)
            # ...then run the two large matmuls on the default (GPU) stream:
            # the panel solve L21 = A21 @ L11^{-T} and the trailing update
            # A22 -= L21 @ L21^T
            L[..., end:N, k:end] = A[..., end:N, k:end] @ L_inv
            A[..., end:N, end:N] -= mx.matmul(L[..., end:N, k:end], mx.swapaxes(L[..., end:N, k:end], -1, -2))

    # Undo the scaling: if A = amax * A' and A' = L'L'^T, then L = sqrt(amax) * L'
    L *= mx.sqrt(amax)
    return L

N = 2048:  8ms    -> 10ms    speedup: 0.8x
N = 4096:  81ms   -> 28ms    speedup: 2.9x
N = 8192:  693ms  -> 197ms   speedup: 3.5x
N = 16384: 6010ms -> 2411ms  speedup: 2.5x
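For reference, a minimal sketch of how such a comparison can be timed, assuming the hybrid `cholesky` above is in scope. `mx.eval` is needed inside the timed region because MLX is lazy, and the exact numbers will of course vary by machine:

import time
import mlx.core as mx

N = 4096
X = mx.random.normal((N, N))
A = X @ X.T + N * mx.eye(N)  # well-conditioned SPD test matrix
mx.eval(A)

def bench(f):
    out = f()
    mx.eval(out)  # warm-up run
    tic = time.perf_counter()
    out = f()
    mx.eval(out)  # force the lazy graph to execute before stopping the clock
    return (time.perf_counter() - tic) * 1e3

cpu_ms = bench(lambda: mx.linalg.cholesky(A, stream=mx.cpu))
hybrid_ms = bench(lambda: cholesky(A))
print(f"cpu: {cpu_ms:.0f}ms  hybrid: {hybrid_ms:.0f}ms  speedup: {cpu_ms / hybrid_ms:.1f}x")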

It seems to be similarly numerically accurate, if anything slightly better than the CPU version, when you check $A = LL^{T}$ in float64.
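A hedged sketch of that check, promoting to float64 on the NumPy side since the factorization itself runs in float32 (again assuming the hybrid `cholesky` above is in scope):

import numpy as np
import mlx.core as mx

N = 4096
X = mx.random.normal((N, N))
A = X @ X.T + N * mx.eye(N)  # SPD test matrix in float32

L = cholesky(A)  # lower triangular by construction
mx.eval(L)

# Do the residual check in float64 so the comparison itself adds
# essentially no rounding error of its own
A64 = np.array(A, dtype=np.float64)
L64 = np.array(L, dtype=np.float64)
err = np.max(np.abs(L64 @ L64.T - A64)) / np.max(np.abs(A64))
print(f"relative reconstruction error: {err:.2e}")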

We don't really have a great pattern for ops that force some computation onto the CPU and some onto the GPU, but maybe it's worth merging anyway?

It could be quite hard to write a more performant GPU-only kernel for a single matrix, since the unblocked Cholesky that runs on the CPU above doesn't parallelize easily. The batched version could probably be quite a bit faster though.
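For what it's worth, the blocked routine above already accepts a batch through its leading axes, since all the indexing goes through `...`. A small illustrative example, assuming the batched inputs are supported end to end by the underlying CPU linalg ops:

import mlx.core as mx

B, N = 8, 1024
X = mx.random.normal((B, N, N))
A = X @ mx.swapaxes(X, -1, -2) + N * mx.eye(N)  # one SPD matrix per batch element

Ls = cholesky(A)
mx.eval(Ls)
print(Ls.shape)  # (8, 1024, 1024)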

awni commented 1 month ago

That's pretty awesome that it's faster. A rare example of mixing CPU / GPU speeding things up!

I'm not sure what to do with it. On the one hand, it's a lot faster, which is nice. On the other hand, implementing this at the op level will break a couple of existing patterns.

I think it's fine to add as a temporary speedup; in fact we probably should.

But it would be useful to know what a good long-term plan is for Cholesky and friends (heavy ops that are hard to parallelize in just one or two kernels). Is it feasible that we eventually replace it with our own fast kernel(s)?

Alternatively, maybe we should rethink which of those patterns are worth staying consistent about and which are not, and come up with a consistent way of working around the rest.

barronalex commented 1 month ago

I think it's feasible to write a GPU-only Cholesky that's at least close to as performant as the above, so maybe we don't need to change the pattern.

Given that we'll likely want a batched version of all of these harder-to-parallelize ops, it is tempting to keep it consistent and maybe sacrifice a little bit of performance.