Marcello-Sega / pytim

a python package for the interfacial analysis of molecular simulations
https://marcello-sega.github.io/pytim/
GNU General Public License v3.0

`evaluate_pbc_fast()` is slow #441

Status: Open. a-ws-m opened this issue 3 days ago.

a-ws-m commented 3 days ago

I've been using WillardChandler a lot recently, and it has reached the point where my analysis code takes several hours to run for a series of long simulations. I've attached the cProfile output for the original code (run on short slices of my trajectories). You can see that each call to evaluate_pbc_fast() takes 2.8 seconds on average.
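For anyone who wants to reproduce this kind of measurement, here is a minimal sketch of profiling a single call with the standard library's cProfile and pstats modules. `profile_call` and `slow_sum` are hypothetical names used only for illustration; in practice you would pass your own analysis function in place of `slow_sum`.

```python
import cProfile
import io
import pstats


def profile_call(func, *args, sort_key="cumulative", limit=10, **kwargs):
    """Run func(*args, **kwargs) under cProfile and return (result, report).

    The report lists the `limit` most expensive entries, sorted by `sort_key`.
    """
    profiler = cProfile.Profile()
    result = profiler.runcall(func, *args, **kwargs)
    stream = io.StringIO()
    stats = pstats.Stats(profiler, stream=stream)
    stats.sort_stats(sort_key).print_stats(limit)
    return result, stream.getvalue()


def slow_sum(n):
    # Hypothetical stand-in for an expensive analysis step.
    return sum(i * i for i in range(n))


if __name__ == "__main__":
    value, report = profile_call(slow_sum, 100_000)
    print(report)
```

Sorting by cumulative time puts the functions that dominate the total runtime, such as a slow density evaluation, near the top of the report.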

JAX has a very fast gaussian_kde() implementation. To make full use of its evaluate() method, instead of applying the minimum image convention, we can generate a 3x3x3 supercell of periodic images (the original box plus its 26 neighbours) and use it as the dataset.
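A minimal sketch of the 3x3x3 supercell construction, assuming an orthorhombic box and positions already wrapped into it. `make_supercell` is a hypothetical helper, not part of pytim or JAX. Its output, transposed to shape (3, 27 * N), matches the (d, n) dataset layout that jax.scipy.stats.gaussian_kde expects; how the kernel width is mapped onto gaussian_kde's bw_method is not shown here.

```python
import itertools

import numpy as np


def make_supercell(positions, box):
    """Replicate positions into a 3x3x3 block of periodic images.

    positions: (N, 3) array inside an orthorhombic box.
    box: (3,) array with the box edge lengths.
    Returns a (27 * N, 3) array; the first N rows are the original positions.
    """
    positions = np.asarray(positions, dtype=float)
    box = np.asarray(box, dtype=float)
    # Shifts of -1, 0, +1 box lengths along each axis, with (0, 0, 0) first.
    shifts = [(0, 0, 0)] + [
        s for s in itertools.product((-1, 0, 1), repeat=3) if s != (0, 0, 0)
    ]
    images = [positions + np.array(s) * box for s in shifts]
    return np.concatenate(images, axis=0)
```

The dataset grows by a factor of 27, which is the price paid for letting a standard, non-periodic KDE see the neighbouring images.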

With this change, each call to gaussian_kde() takes around a millisecond, and the total runtime drops from 384 s to 15 s. With the default parameters, this results in a slightly larger interface being drawn, so you might need to adjust the density cutoff.

I'll submit a PR with this change so you can see the difference.

profiles.zip

Marcello-Sega commented 2 days ago

Hi, and thanks for sharing this; it sounds very promising!

I recall there being a tradeoff between using KDTree and scipy.stats.gaussian_kde in earlier versions of the code, depending on whether there were more grid nodes or atoms. If I remember correctly, the older code allowed switching between these methods, possibly guided by heuristics. Given the performance improvements, and since your patch doesn't change the default behavior unless jax is installed, I agree it's worth integrating.
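To illustrate the KDTree side of that tradeoff, here is a minimal sketch (not pytim's actual implementation) of a cutoff-based periodic KDE built on scipy.spatial.cKDTree, whose `boxsize` argument makes neighbour searches periodic. The function name `periodic_kde_kdtree`, the 3-sigma truncation, and the normalised isotropic Gaussian are assumptions for illustration, and the sketch assumes an orthorhombic box whose edges are all longer than twice the cutoff.

```python
import numpy as np
from scipy.spatial import cKDTree


def periodic_kde_kdtree(positions, grid, box, sigma, cutoff_sigmas=3.0):
    """Sum of isotropic Gaussians at each grid point, using minimum image.

    Contributions from atoms farther than cutoff_sigmas * sigma are neglected.
    """
    box = np.asarray(box, dtype=float)

    def wrap(x):
        # cKDTree(boxsize=...) requires coordinates in [0, box).
        x = np.mod(np.asarray(x, dtype=float), box)
        return np.where(x >= box, 0.0, x)  # guard against round-off giving exactly box

    pos = wrap(positions)
    pts = wrap(grid)
    tree = cKDTree(pos, boxsize=box)
    cutoff = cutoff_sigmas * sigma
    norm = (2.0 * np.pi * sigma**2) ** -1.5  # normalised 3D Gaussian
    density = np.zeros(len(pts))
    for i, neighbours in enumerate(tree.query_ball_point(pts, r=cutoff)):
        if not neighbours:
            continue
        d = pos[neighbours] - pts[i]
        d -= box * np.round(d / box)  # minimum image displacement
        r2 = np.einsum("ij,ij->i", d, d)
        density[i] = norm * np.exp(-0.5 * r2 / sigma**2).sum()
    return density
```

Apart from building the tree, the cost scales with the number of grid points times the number of atoms within the cutoff of each point, rather than with the full grid-by-atom product of a direct sum.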

That said, I’m a little concerned about your remark:

Using the default parameters, this results in a slightly larger interface being drawn.

Ideally, the results should be identical regardless of implementation. If there are discrepancies, it's worth investigating whether the issue lies in this implementation or in the existing one. Could you please help me look into this to ensure consistency?
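One way to make "identical results" checkable is to compare the two density fields on the same grid, for example the arrays returned by each kernel's evaluate(grid). The helper below is a hypothetical sketch; the cutoff value is whatever density cutoff is used to draw the interface.

```python
import numpy as np


def compare_density_fields(field_a, field_b, cutoff):
    """Quantify how much two density fields sampled on the same grid disagree.

    Returns the maximum absolute difference and the fraction of grid points
    that fall on different sides of the density cutoff in the two fields.
    A nonzero fraction means the two interfaces are drawn differently.
    """
    a = np.asarray(field_a, dtype=float)
    b = np.asarray(field_b, dtype=float)
    if a.shape != b.shape:
        raise ValueError("fields must be sampled on the same grid")
    max_abs_diff = float(np.max(np.abs(a - b)))
    flipped = np.count_nonzero((a > cutoff) != (b > cutoff))
    return max_abs_diff, flipped / a.size
```

The second number is the more relevant one here, since a small pointwise difference only matters if it moves grid points across the cutoff.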

Looking forward to reviewing your PR!

Marcello-Sega commented 2 days ago

So, I ran some tests with your version, modifying the code slightly so that I could switch between the two implementations. This is what I get with the WATER_GRO test case:

```
import MDAnalysis as mda
import pytim
from pytim.datafiles import WATER_GRO
from pytim.gaussian_kde_pbc import gaussian_kde_pbc

u = mda.Universe(WATER_GRO)
mesh = 2
# Grid compatible with both the box and the requested mesh spacing
ngrid, spacing = pytim.utilities.compute_compatible_mesh_params(mesh, box=u.dimensions[:3])
print('ngrid=', ngrid, 'spacing=', spacing)
grid = pytim.utilities.generate_grid_in_box(box=u.dimensions[:3], npoints=ngrid, order='xyz')
print('grid computed')

# Original pytim implementation
kernel = gaussian_kde_pbc(u.select_atoms('name OW').positions, box=u.dimensions[:3], sigma=2.0, use_jax=False)
print('kernel inited')
%timeit -n 1 -r 1 density_field = kernel.evaluate(grid)

# JAX implementation (the use_jax switch exists only on the faster-kde branch)
kernelJax = gaussian_kde_pbc(u.select_atoms('name OW').positions, box=u.dimensions[:3], sigma=2.0, use_jax=True)
print('Jax kernel inited')
%timeit -n 1 -r 1 density_fieldJax = kernelJax.evaluate(grid)
```

```
ngrid= [25 25 75] spacing= [2. 2. 2.]
grid computed
kernel inited
evaluating using pytim implementation
129 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
Jax kernel inited
evaluating using jax
10.7 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
```

Marcello-Sega commented 2 days ago

Can you post a configuration of one of the systems you are using, as well as the parameters that you are using to initialise the WillardChandler class?

The result above is obtained with the code of the branch https://github.com/Marcello-Sega/pytim/tree/faster-kde

a-ws-m commented 2 days ago

JAX uses JIT compilation, so the first call takes a relatively long time because it includes compiling the kernel. I think the speedup in my code came from analysing a whole trajectory, where that one-off cost is spread over many function calls. The grid size and having a GPU also seem to make a huge difference. I suppose it shouldn't be a drop-in replacement.
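Because a single timed run (as with `%timeit -n 1 -r 1`) includes compilation, a fairer benchmark times the first call separately from later ones. A minimal sketch with a hypothetical helper `time_cold_and_warm`; for JAX, the timed callable should end with `.block_until_ready()` on its result, since JAX dispatches work asynchronously and would otherwise return before the computation finishes.

```python
import statistics
import time


def time_cold_and_warm(func, *args, repeats=5, **kwargs):
    """Return (first_call_seconds, median_later_call_seconds) for func.

    With a JIT compiler the first call includes tracing and compilation,
    so it is reported separately from the steady-state calls.
    """
    start = time.perf_counter()
    func(*args, **kwargs)
    first = time.perf_counter() - start

    later = []
    for _ in range(repeats):
        start = time.perf_counter()
        func(*args, **kwargs)
        later.append(time.perf_counter() - start)
    return first, statistics.median(later)
```

For a long trajectory, the steady-state figure is the one that dominates the total runtime.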

As for the difference in the location of the surface, I'm confident that the issue lies with the method I'm implementing here. Implementing the minimum image convention requires changing the way the kernel is evaluated, which is what you did in the original code. To take advantage of JAX's fast evaluate(), however, I'm just augmenting the particle positions with additional periodic images, so every image in the supercell contributes to the density profile, not only the nearest one. The only way I can think of to make the two methods consistent is to skip the supercell images, but that would only work for phases that span less than half the unit cell, and only if they are centered first.
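A 1D toy sketch of the difference described above, assuming a single periodic axis of length `box` and a normalised Gaussian kernel (all names are hypothetical). With the minimum image convention each atom contributes once, through its nearest image; with the supercell each atom contributes through all three images, so the supercell density is never smaller, and the extra terms come from images at least half a box away.

```python
import numpy as np


def gaussian(r, sigma):
    # Normalised 1D Gaussian kernel
    return np.exp(-0.5 * (r / sigma) ** 2) / (np.sqrt(2.0 * np.pi) * sigma)


def density_min_image(x_atoms, x_grid, box, sigma):
    # Each atom contributes once, through its nearest periodic image.
    d = x_grid[:, None] - x_atoms[None, :]
    d -= box * np.round(d / box)
    return gaussian(d, sigma).sum(axis=1)


def density_supercell(x_atoms, x_grid, box, sigma):
    # Each atom contributes through three images, shifted by -box, 0 and +box.
    images = np.concatenate([x_atoms - box, x_atoms, x_atoms + box])
    d = x_grid[:, None] - images[None, :]
    return gaussian(d, sigma).sum(axis=1)
```

In this toy the extra terms are tiny when sigma is much smaller than the box and grow as sigma approaches it. If the observed shift persists for a small sigma, other differences between the two code paths, such as how gaussian_kde normalises the density and chooses its bandwidth, may also be worth checking.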

pytim.WillardChandler jax-gpu: 8.62e+01 s (100 iterations, mesh=2.0)
pytim.WillardChandler jax-gpu: 2.19e+02 s (100 iterations, mesh=1.5)

pytim.WillardChandler scipy: 9.32e+01 s (100 iterations, mesh=2.0)
pytim.WillardChandler scipy: 3.30e+02 s (100 iterations, mesh=1.5)

I think this is a useful alternative calculation for some use cases, like my own. Analysing all of my trajectories would have taken 8 hours with the original code, but this variation took less than 1 hour.

tests.zip