motiwari / BanditPAM

BanditPAM C++ implementation and Python package
MIT License

BanditPAM 200x slower than quadratic algorithms at 10k MNIST #175

Open kno10 opened 2 years ago

kno10 commented 2 years ago

I've been comparing BanditPAM to FasterPAM on the first 10k instances of MNIST:

https://colab.research.google.com/drive/1-8fMll3QpsdNV5widn-PrPHa5SGXdAIW?usp=sharing

BanditPAM took 791684.18 ms; FasterPAM took 3971.87 ms, of which about 90% is the time needed to compute the pairwise distance matrix.

That is 200x slower. I will now try 20k instances.
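
For reference, the comparison has roughly this shape (a minimal sketch, not the exact notebook: the number of medoids k, the L2 loss, and other settings are assumptions):

```python
# Minimal benchmark sketch using the banditpam and kmedoids PyPI packages.
import time
import numpy as np
from sklearn.datasets import fetch_openml
from sklearn.metrics import pairwise_distances
from banditpam import KMedoids
import kmedoids

X, _ = fetch_openml("mnist_784", return_X_y=True, as_frame=False)
X = X[:10000].astype(np.float64)
k = 10  # assumption: the notebook's actual k may differ

t0 = time.time()
bp = KMedoids(n_medoids=k, algorithm="BanditPAM")
bp.fit(X, "L2")
print(f"BanditPAM: {(time.time() - t0) * 1000:.2f} ms")

t0 = time.time()
diss = pairwise_distances(X)   # dominates FasterPAM's total runtime
res = kmedoids.fasterpam(diss, k)
print(f"FasterPAM: {(time.time() - t0) * 1000:.2f} ms")
```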

kno10 commented 2 years ago

First numbers for 20k instances of MNIST (one run each only, on Colab): BanditPAM took 4390447.39 ms, FasterPAM 16993.11 ms. That is 258x slower, so the gap has widened despite the latter being O(n²). But with only one run each there is variance here; the ratio may well be similar.

motiwari commented 2 years ago

Thanks for the report, @kno10. I've requested access to the Colab notebooks from two of my personal email addresses; would you mind granting access so I can investigate?

kno10 commented 2 years ago

Sorry, I didn't click the right Colab buttons; the link was meant to be public. It should work now.

motiwari commented 2 years ago

Hi @kno10, thank you for filing this issue and providing an easily reproducible benchmark. It led us to discover a number of issues that are being worked on:

- the handling of the pairwise distance matrix
- the multithreading configuration
- the maximum-iteration setting
- the complexity of the swap step

The first three points above are all addressed in the slowdown branch. The fourth is implemented as well, but for some reason it is not working properly and often returns bad results. I'm going to continue working on this but need to put it on hold for some time due to my other commitments.

Thank you for reporting all of these bugs in easily reproducible ways. Please let me know if you have any other questions or comments while I continue to work on this.

kno10 commented 2 years ago

Distance matrix: Indeed, the distance computations are the main cost, but that is also the baseline any non-quadratic method needs to beat. Even pairwise distance computations can be vectorized (e.g., with AVX; MNIST should benefit from this), so I would not expect the benefit of actually materializing the matrix to be that large, unless you recompute the values very often. In my opinion, the benefits of vectorization at this level are often overestimated, because people tend to look at interpreted code, where "vectorized" also just means calling a compiled library function instead of running in the Python interpreter, regardless of whether the underlying code is actually vectorized.
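
As an illustration of that point, the whole Euclidean distance matrix can be computed with one BLAS-backed matrix product and no hand-written SIMD at all (a sketch; names are illustrative):

```python
# Sketch: full Euclidean matrix via ||a-b||^2 = ||a||^2 + ||b||^2 - 2 a.b.
# The X @ X.T product is handled by BLAS, which is where the AVX
# vectorization actually happens.
import numpy as np

def euclidean_matrix(X):
    sq = np.einsum("ij,ij->i", X, X)                  # squared row norms
    d2 = sq[:, None] + sq[None, :] - 2.0 * (X @ X.T)
    np.maximum(d2, 0.0, out=d2)                       # clip tiny negatives from rounding
    return np.sqrt(d2)
```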

Multithreading: The Colab sheet uses n_cpu=1 for FasterPAM, i.e., no parallelism. The wrapper then calls a non-multithreaded implementation; the parallelized version came much later and is not as parallel-efficient as we would like. I set both to use a single thread for a fairer comparison.

Max Iterations: It converges long before the maximum iteration count is reached; usually fewer than 10 iterations are enough. I added an "iter" counter to the Colab sheet, and it was just 3 iterations on average. I also set max_iter=1000, but it will not make a difference. How much quality do you lose in BanditPAM when reducing the iteration limit?
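
Concretely, the fairer single-threaded, capped-iteration call looks roughly like this (a sketch: the n_cpu and max_iter parameters appear above in this thread; the n_iter result field and BanditPAM's C++ core respecting OMP_NUM_THREADS are my assumptions):

```python
import os
os.environ["OMP_NUM_THREADS"] = "1"  # assumption: pins BanditPAM's thread count;
                                     # must be set before the extension starts threads

import numpy as np
import kmedoids
from sklearn.metrics import pairwise_distances

X = np.random.default_rng(0).random((2000, 16))   # small demo data
diss = pairwise_distances(X)
res = kmedoids.fasterpam(diss, 10, max_iter=1000, n_cpu=1)  # single-threaded
print(res.loss, res.n_iter)  # assumption: n_iter reports iterations until convergence
```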

Swap Complexity: Don't use the FastPAM1 version of the trick anymore; the FasterPAM version is theoretically better (guaranteed rather than merely expected gains; FastPAM1 still has a theoretical worst case of O(k)), better understood, and more elegant.
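
For readers following along, the core of this trick is that evaluating one candidate against all k medoids costs a single pass over the points. A minimal NumPy sketch of that swap evaluation (following the formulation in Schubert & Rousseeuw, "Fast and Eager k-Medoids Clustering", 2021; variable names are mine, not the library's API):

```python
import numpy as np

def swap_deltas(diss, nearest, d_nearest, d_second, j):
    """Change in total deviation if medoid i is replaced by point j.

    diss:      (n, n) pairwise dissimilarity matrix
    nearest:   for each point, the index (0..k-1) of its closest medoid
    d_nearest: distance of each point to its closest medoid
    d_second:  distance of each point to its second-closest medoid
    j:         candidate point index

    Returns delta of shape (k,); delta[i] < 0 means swapping medoid i
    for point j improves the clustering.
    """
    k = nearest.max() + 1  # every medoid is its own nearest, so all indices occur
    # Removal loss: points of medoid i fall back to their second-nearest.
    delta = np.zeros(k)
    np.add.at(delta, nearest, d_second - d_nearest)

    d_j = diss[:, j]
    closer = d_j < d_nearest
    # Shared gain: these points move to j no matter which medoid is removed.
    shared = np.sum(d_j[closer] - d_nearest[closer])
    # They no longer pay the removal penalty at their own medoid ...
    np.add.at(delta, nearest[closer], d_nearest[closer] - d_second[closer])
    # ... and points with d_nearest <= d(o, j) < d_second fall back to j,
    # not to their second-nearest medoid.
    between = ~closer & (d_j < d_second)
    np.add.at(delta, nearest[between], d_j[between] - d_second[between])

    return delta + shared
```

On top of this per-candidate evaluation, FasterPAM performs improving swaps eagerly: the first candidate with a negative delta is swapped in immediately, rather than scanning all candidates for the single best swap in each iteration.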