x-tabdeveloping / turftopic

Robust and fast topic models with sentence-transformers.
https://x-tabdeveloping.github.io/turftopic/
MIT License

Implement algorithms in JAX #5

Open x-tabdeveloping opened 6 months ago

x-tabdeveloping commented 6 months ago

Rationale

Sklearn implementations of certain algorithms are not very scalable and struggle with larger datasets and higher numbers of topics (e.g. GMM). Additionally, being able to run inference on GPU and TPU would be a very nice touch.

Proposal

Implement Gaussian Mixtures (maximum likelihood - EM), NMF (coordinate descent) and FastICA in JAX, and package them as a separate Python package that can serve as an optional dependency.
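A minimal sketch of what the NMF part could look like in JAX. Note the assumptions: this uses multiplicative updates rather than the coordinate descent the proposal names (multiplicative updates map more cleanly onto jit-compiled array ops), and the function name and defaults are hypothetical, not from the codebase.

```python
# Hypothetical sketch: NMF in JAX via multiplicative updates.
# (The proposal mentions coordinate descent; multiplicative updates
# are shown here only because they vectorize trivially under jit/scan.)
import jax
import jax.numpy as jnp

def nmf_multiplicative(X, n_components, n_iter=200, eps=1e-8, seed=0):
    """Factorize a nonnegative X (n, m) as X ~ W @ H with W, H >= 0."""
    key_w, key_h = jax.random.split(jax.random.PRNGKey(seed))
    n, m = X.shape
    W = jax.random.uniform(key_w, (n, n_components))
    H = jax.random.uniform(key_h, (n_components, m))

    def step(carry, _):
        W, H = carry
        # Lee & Seung multiplicative updates; eps guards against 0/0.
        H = H * (W.T @ X) / (W.T @ W @ H + eps)
        W = W * (X @ H.T) / (W @ H @ H.T + eps)
        return (W, H), None

    (W, H), _ = jax.lax.scan(step, (W, H), None, length=n_iter)
    return W, H
```

Because the whole loop is a `lax.scan` over pure array ops, the same code runs unchanged on CPU, GPU, or TPU, which is the main motivation stated above.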

Notes

Here are some possible implementations that could be used:

x-tabdeveloping commented 6 months ago

GMM's slowness is in-part addressed by #9.

x-tabdeveloping commented 6 months ago

I have also tried implementing GMM using the EM implementation I listed earlier, but it did not work in high dimensions and produced a lot of divergences. I then tried explicitly writing out the likelihood in JAX and optimizing it with Optax and JaxOpt; both libraries diverged within ten iterations (NaNs all over the place). I also tried Bayeux: its VI implementation gave some weird errors that had nothing to do with my code, and its NUTS implementation was too slow to be useful and crashed my computer immediately. I unfortunately think that it is beyond the scope of the project to create and maintain a good JAX implementation of multivariate Gaussian Mixtures.
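For reference, divergences like these are often traceable to unstable likelihood terms. A hedged sketch of the kind of explicit GMM log-likelihood one would hand to Optax or JaxOpt, restricted to diagonal covariances and written with `logsumexp` and log-parameterized scales for stability (the function names and parameterization here are illustrative assumptions, not the code that was actually tried):

```python
# Illustrative sketch: explicit diagonal-covariance GMM log-likelihood
# in JAX, written to avoid the most common sources of NaNs
# (exp overflow, log of zero) via logsumexp and log-scale parameters.
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def gmm_log_likelihood(X, log_weights, means, log_scales):
    """X: (n, d); log_weights: (k,); means, log_scales: (k, d)."""
    scales = jnp.exp(log_scales)                      # (k, d)
    diff = X[:, None, :] - means[None, :, :]          # (n, k, d)
    # per-dimension Gaussian log-density, summed over d
    log_norm = (-0.5 * (diff / scales[None]) ** 2
                - log_scales[None]
                - 0.5 * jnp.log(2 * jnp.pi))
    comp_ll = log_norm.sum(-1) + log_weights[None]    # (n, k)
    return logsumexp(comp_ll, axis=1).sum()

# Gradient of the negative log-likelihood w.r.t. the unconstrained
# parameters, usable with plain gradient descent or any Optax optimizer.
neg_ll_grad = jax.grad(
    lambda params, X: -gmm_log_likelihood(X, *params), argnums=0
)
```

This does not contradict the conclusion above: even with a stable likelihood, full-covariance mixtures in high dimensions remain hard to optimize by gradient descent, which is presumably where the divergences came from.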

We could still try ICA and NMF, as those are probably a bit more robust to divergences, but I'm growing more skeptical of this idea by the day.
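To gauge how feasible the ICA half would be, here is a rough sketch of FastICA's fixed-point iteration (tanh nonlinearity, symmetric decorrelation) in JAX. Everything here is an assumption for illustration: the function name, the eigendecomposition-based whitening, and the defaults are not from turftopic or any existing JAX library.

```python
# Rough sketch: FastICA fixed-point iteration in JAX
# (tanh contrast function, symmetric decorrelation, eigh-based whitening).
import jax
import jax.numpy as jnp

def fast_ica(X, n_components, n_iter=100, seed=0):
    """X: (n_samples, n_features). Returns sources (n_samples, n_components)."""
    n, d = X.shape
    Xc = X - X.mean(0)
    # whiten via the eigendecomposition of the covariance matrix
    cov = Xc.T @ Xc / n
    evals, evecs = jnp.linalg.eigh(cov)
    K = (evecs / jnp.sqrt(evals)).T[-n_components:]   # top-k whitening rows
    Z = Xc @ K.T                                      # (n, k) whitened data

    def sym_decorrelate(W):
        # W <- (W W^T)^{-1/2} W, keeps the rows of W orthonormal
        s, u = jnp.linalg.eigh(W @ W.T)
        return (u / jnp.sqrt(s)) @ u.T @ W

    def step(W, _):
        g = jnp.tanh(Z @ W.T)                         # (n, k)
        g_prime = 1.0 - g ** 2
        W_new = g.T @ Z / n - jnp.diag(g_prime.mean(0)) @ W
        return sym_decorrelate(W_new), None

    W0 = sym_decorrelate(
        jax.random.normal(jax.random.PRNGKey(seed), (n_components, n_components))
    )
    W, _ = jax.lax.scan(step, W0, None, length=n_iter)
    return Z @ W.T
```

Unlike the GMM likelihood, every operation here is bounded (tanh, orthonormal W), which is consistent with the hunch above that ICA should be more robust to divergences.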

x-tabdeveloping commented 6 months ago

PyMC's VI worked, but it is incredibly slow, so we would just be back at square one basically.