Closed: x-tabdeveloping closed this issue 3 weeks ago.
GMM's slowness is partly addressed by #9.
I have also tried implementing GMM using the EM implementation I listed earlier. It did not work in high dimensions, and it diverged frequently. Then I tried writing out the likelihood explicitly in JAX and optimizing it with Optax and JaxOpt; with both libraries the optimization diverged in under ten iterations (NaNs all over the place). I also tried Bayeux with both its VI and NUTS implementations: VI gave some strange errors that seemed unrelated to my code, and NUTS was too slow to be useful and immediately crashed my computer. Unfortunately, I think creating and maintaining a good JAX implementation of multivariate Gaussian mixtures is beyond the scope of this project.
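One possible source of NaNs when maximizing a mixture likelihood directly is evaluating component densities in probability space, where they underflow to zero in high dimensions (so their log becomes -inf), or letting a covariance matrix lose positive definiteness during gradient steps. The sketch below is a minimal illustration under stated assumptions, not what was tried here: it assumes a full-covariance GMM with unconstrained parameters (softmax logits for the weights, a softplus-positive diagonal for each Cholesky factor), and `cholesky_from_raw` and `gmm_log_likelihood` are hypothetical helpers, not part of Optax, JaxOpt, or any existing package. The negative of the returned value is what would be handed to an Optax optimizer as the loss.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular
from jax.scipy.special import logsumexp


def cholesky_from_raw(raw):
    """Hypothetical helper: map an unconstrained (K, D, D) array to Cholesky factors.

    The strictly lower triangle is used as-is; the diagonal goes through
    softplus plus a small jitter so it stays strictly positive, which keeps
    every implied covariance L @ L.T positive definite during optimization.
    """
    lower = jnp.tril(raw, k=-1)
    diag = jax.nn.softplus(jnp.diagonal(raw, axis1=-2, axis2=-1)) + 1e-6
    return lower + jax.vmap(jnp.diag)(diag)


def gmm_log_likelihood(params, X):
    """Hypothetical helper: mean per-sample log-likelihood of X (N, D) under a K-component GMM."""
    log_w = jax.nn.log_softmax(params["logits"])  # (K,) log mixing weights
    chols = cholesky_from_raw(params["raw_chol"])  # (K, D, D)
    D = X.shape[1]

    def component_logpdf(mu, chol):
        # Whitened residuals z = L^{-1} (x - mu), computed without an explicit inverse.
        z = solve_triangular(chol, (X - mu).T, lower=True)  # (D, N)
        maha = jnp.sum(z**2, axis=0)  # squared Mahalanobis distances, (N,)
        log_det = 2.0 * jnp.sum(jnp.log(jnp.diagonal(chol)))
        return -0.5 * (D * jnp.log(2.0 * jnp.pi) + log_det + maha)

    log_p = jax.vmap(component_logpdf)(params["means"], chols)  # (K, N)
    # Combine components in log space: summing densities directly would
    # underflow to 0 in high dimensions and produce log(0) = -inf.
    return jnp.mean(logsumexp(log_w[:, None] + log_p, axis=0))


# Toy usage: 3 components in 50 dimensions.
k_data, k_means = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.normal(k_data, (500, 50))
params = {
    "logits": jnp.zeros(3),
    "means": jax.random.normal(k_means, (3, 50)),
    "raw_chol": jnp.zeros((3, 50, 50)),
}
loss, grads = jax.value_and_grad(lambda p: -gmm_log_likelihood(p, X))(params)
```

Whether this alone would avoid the divergences described above is untested; it only removes the two numerical failure modes mentioned.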
We could still try ICA and NMF, as those are probably somewhat more robust to divergences, but I'm growing more skeptical of this idea by the day.
PyMC's VI worked, but it is incredibly slow, so we would basically be back at square one.
Since this is probably not a good idea and well out of scope for this project, I'm closing the issue.
Rationale
Sklearn implementations of certain algorithms (e.g. GMM) are not very scalable and struggle with larger datasets and larger numbers of topics. Additionally, being able to run inference on GPUs and TPUs would be a very nice touch.
Proposal
Implement Gaussian mixtures (maximum likelihood via EM), NMF (coordinate descent), and FastICA in JAX, and ship them in a separate Python package that can serve as an optional dependency.
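As an illustration of the first part of the proposal, here is a minimal sketch of maximum-likelihood EM for a full-covariance Gaussian mixture in JAX. It assumes the caller supplies initial means (for example from k-means), keeps the E-step in log space via `logsumexp`, and adds a small ridge `REG_COVAR` to every covariance, similar in spirit to scikit-learn's `reg_covar` parameter. The names `component_log_probs`, `em_step`, and `fit_gmm` are hypothetical, and NMF and FastICA are not covered.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular
from jax.scipy.special import logsumexp

REG_COVAR = 1e-6  # assumed ridge added to each covariance to keep it positive definite


def component_log_probs(X, means, covs):
    """log N(x_n | mu_k, Sigma_k) for all points and components, shape (N, K)."""
    D = X.shape[1]
    chols = jnp.linalg.cholesky(covs)  # (K, D, D), batched over components

    def one_component(mu, chol):
        z = solve_triangular(chol, (X - mu).T, lower=True)  # (D, N) whitened residuals
        log_det = 2.0 * jnp.sum(jnp.log(jnp.diagonal(chol)))
        return -0.5 * (D * jnp.log(2.0 * jnp.pi) + log_det + jnp.sum(z**2, axis=0))

    return jax.vmap(one_component, out_axes=1)(means, chols)


@jax.jit
def em_step(X, weights, means, covs):
    """One EM iteration; also returns the mean log-likelihood of the *input* parameters."""
    N, D = X.shape
    # E-step: responsibilities computed in log space for numerical stability.
    log_joint = jnp.log(weights) + component_log_probs(X, means, covs)  # (N, K)
    log_norm = logsumexp(log_joint, axis=1, keepdims=True)  # (N, 1)
    resp = jnp.exp(log_joint - log_norm)  # (N, K), rows sum to 1
    # M-step: closed-form maximum-likelihood updates.
    nk = resp.sum(axis=0) + 10 * jnp.finfo(X.dtype).eps  # (K,) effective counts
    new_weights = nk / N
    new_means = (resp.T @ X) / nk[:, None]  # (K, D)
    diff = X[None, :, :] - new_means[:, None, :]  # (K, N, D)
    new_covs = jnp.einsum("nk,knd,kne->kde", resp, diff, diff) / nk[:, None, None]
    new_covs = new_covs + REG_COVAR * jnp.eye(D)
    return new_weights, new_means, new_covs, jnp.mean(log_norm)


def fit_gmm(X, init_means, n_iter=100):
    """Run n_iter EM steps from caller-supplied initial means (e.g. from k-means)."""
    K, D = init_means.shape
    weights = jnp.full(K, 1.0 / K)
    # Start every component from the same isotropic covariance at the data's scale.
    covs = jnp.tile(jnp.var(X, axis=0).mean() * jnp.eye(D), (K, 1, 1))
    means = init_means
    log_lik = -jnp.inf
    for _ in range(n_iter):
        weights, means, covs, log_lik = em_step(X, weights, means, covs)
    return weights, means, covs, log_lik


# Toy usage: two well-separated 2-D clusters, one initial mean taken from each.
k1, k2 = jax.random.split(jax.random.PRNGKey(42))
X = jnp.concatenate([jax.random.normal(k1, (200, 2)) + 5.0,
                     jax.random.normal(k2, (200, 2)) - 5.0])
weights, means, covs, log_lik = fit_gmm(X, X[jnp.array([0, 399])], n_iter=30)
```

Because `em_step` is jit-compiled, the loop runs on whichever backend (CPU, GPU, or TPU) JAX is configured to use. The `(K, N, D)` residual tensor in the M-step grows with all three dimensions, so a version aimed at the larger datasets mentioned in the rationale would likely need to process the data in batches.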
Notes
Here are some possible implementations that could be used: