ml-explore / mlx

MLX: An array framework for Apple silicon
https://ml-explore.github.io/mlx/
MIT License
15.7k stars, 893 forks

[Feature] Add support for mx.random.multivariate_normal() #502

Open tedwards2412 opened 5 months ago

tedwards2412 commented 5 months ago

@NNSSA and I are working on a sampling package for mlx (https://github.com/tedwards2412/samplex), and it would be extremely useful to have mx.diag() and mx.random.multivariate_normal() to do more generic sampling. The latter will involve adding more functionality to the core.linalg sub-package. Is this likely to come in a future update? Happy to help if needed!

awni commented 5 months ago

Let's do these as two issues as diag is much easier than multivariate normal. I assume for multivariate normal you need a non-diagonal covariance?

awni commented 5 months ago

Leaving this issue for mx.random.multivariate_normal and created #503 for diag

awni commented 5 months ago

FYI: for multivariate normal we probably 🤔 need matrix inversion, e.g. mx.linalg.inv, which will probably also help with other things.

tedwards2412 commented 5 months ago

Great, thanks! And yes, non-diagonal covariance would be essential for this.

awni commented 5 months ago

Cool package by the way! You should add a little quick start/usage guide (when it's ready for it).

NNSSA commented 5 months ago

Thanks! Will definitely do :-) Related, I think numpy uses a singular value decomposition to compute a multivariate normal.

awni commented 5 months ago

We have a PR out for QR (#310). I think SVD and Cholesky would go similarly. The main issue is that there are no Metal implementations for most of LAPACK, so a lot of this will be CPU only until we can get some kernels implemented.
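
For reference, here is a minimal sketch of why a factorization such as Cholesky is enough to sample from a multivariate normal with a non-diagonal covariance: if `cov = L @ L.T` and `z` is standard normal, then `mean + L @ z` has the desired mean and covariance. This is plain Python using only the standard library, not MLX code and not how NumPy or MLX implement it; the function names `cholesky` and `sample_multivariate_normal` are hypothetical and chosen only for illustration.

```python
import math
import random


def cholesky(cov):
    """Return lower-triangular L such that L @ L.T == cov.

    Assumes cov is a symmetric positive definite matrix given as a list of lists.
    """
    n = len(cov)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = sum(L[i][k] * L[j][k] for k in range(j))
            if i == j:
                d = cov[i][i] - s
                if d <= 0.0:
                    raise ValueError("covariance is not positive definite")
                L[i][j] = math.sqrt(d)
            else:
                L[i][j] = (cov[i][j] - s) / L[j][j]
    return L


def sample_multivariate_normal(mean, cov, num_samples, rng):
    """Draw num_samples vectors x = mean + L z, with z ~ N(0, I)."""
    L = cholesky(cov)  # factor once, reuse for every draw
    dim = len(mean)
    samples = []
    for _ in range(num_samples):
        z = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        x = [mean[i] + sum(L[i][k] * z[k] for k in range(i + 1)) for i in range(dim)]
        samples.append(x)
    return samples


if __name__ == "__main__":
    rng = random.Random(0)
    draws = sample_multivariate_normal([1.0, -1.0], [[4.0, 2.0], [2.0, 3.0]], 5, rng)
    for x in draws:
        print(x)
```

An SVD-based variant would replace `L` with `U * sqrt(S)`, which also tolerates covariances that are only positive semi-definite, at a higher cost.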

tedwards2412 commented 5 months ago

@awni We finally got round to adding the quick start you recommended on https://github.com/tedwards2412/samplex. The sampling seems to work well with mlx so far. Looking forward to working on this more in the future!

awni commented 5 months ago

That's awesome!! Out of curiosity, could you tell me a bit more about (some of) the intended uses for the package? I would love to point people to it, if you are ok with that, once I understand a bit more about which cases you are targeting.

tedwards2412 commented 5 months ago

Overall, the goal is to see if we can allow people to quickly run fairly large-scale MCMC sampling locally rather than having to run on a cluster. But it's also just a research project for us to see if there are particular sampling algorithms that are substantially better when you can switch between running on CPU and GPU without any overhead; I don't think this has been explored at all before.