proteneer / timemachine

Differentiate all the things!

Improve performance of HREX swap moves #1294

Closed · mcwitt closed this 3 months ago

mcwitt commented 3 months ago

Investigations done as part of #1287 showed that the performance of HREX falls off rapidly for larger numbers of windows. This is particularly pronounced in vacuum, where the overhead of the HREX moves is large compared to sampling and computation of the $u_{kl}$ matrix. See the right-hand panel of the following figure:

The dropoff in vacuum performance is largely due to the rapid growth (as $L^3$, according to the heuristic we currently use) of the number of swap attempts per HREX iteration with $L$ windows. For example, with 48 windows we do $48^3 \approx 10^5$ swap attempts per iteration. Because a batch of swap attempts is currently implemented as an unoptimized pure-Python for loop, there should be significant room for performance improvement in the implementation.
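For reference, a single batch of neighbor swap attempts in this pure-Python setting looks roughly like the sketch below (names and structure are hypothetical, not the actual timemachine code): each attempt picks a random neighboring pair of states and applies the standard Metropolis acceptance test using the precomputed $u_{kl}$ matrix.

```python
import numpy as np

def attempt_neighbor_swaps(perm, u_kl, n_attempts, rng):
    """Unbatched sketch (assumed convention): `perm[state]` is the replica
    currently occupying `state`, and `u_kl[k, l]` is the reduced potential of
    replica k evaluated in state l."""
    n_states = len(perm)
    for _ in range(n_attempts):  # n_attempts ~ n_states**3 with the current heuristic
        i = rng.integers(n_states - 1)  # left index of a neighboring pair (i, i + 1)
        j = i + 1
        a, b = perm[i], perm[j]
        # Metropolis log-acceptance ratio for exchanging replicas a and b
        log_acc = (u_kl[a, i] + u_kl[b, j]) - (u_kl[a, j] + u_kl[b, i])
        if np.log(rng.uniform()) < min(0.0, log_acc):
            perm[i], perm[j] = b, a
    return perm
```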

This PR adds a specialized implementation of a batch of neighbor swap moves that generates random numbers in batches and leverages JAX's jit compilation for performance. Benchmarks show that this implementation is significantly faster (~50-100x) for large numbers of windows, although the initial compilation introduces some overhead when the number of windows is small.
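As a rough illustration of the approach (hypothetical names and structure; the actual implementation lives in the PR diff), the randomness for all attempts can be drawn up front and the sequential attempts expressed as a `jax.lax.scan`, so the whole batch compiles to a single jitted call:

```python
import jax
import jax.numpy as jnp
from functools import partial

@partial(jax.jit, static_argnums=2)
def attempt_neighbor_swaps_batch(perm, u_kl, n_attempts, key):
    """Batched sketch: the same move as above, but all random numbers are drawn
    up front and the attempt loop is a lax.scan inside a single jitted function."""
    n_states = perm.shape[0]
    key_pairs, key_unif = jax.random.split(key)
    # batch-generate all randomness needed for n_attempts swap attempts
    lefts = jax.random.randint(key_pairs, (n_attempts,), 0, n_states - 1)
    log_us = jnp.log(jax.random.uniform(key_unif, (n_attempts,)))

    def body(perm, inputs):
        i, log_u = inputs
        j = i + 1
        a, b = perm[i], perm[j]
        log_acc = (u_kl[a, i] + u_kl[b, j]) - (u_kl[a, j] + u_kl[b, i])
        accept = log_u < jnp.minimum(0.0, log_acc)
        perm = perm.at[i].set(jnp.where(accept, b, a))
        perm = perm.at[j].set(jnp.where(accept, a, b))
        return perm, None

    perm, _ = jax.lax.scan(body, perm, (lefts, log_us))
    return perm
```

With `n_attempts` marked static, recompilation is triggered only when the number of windows (and hence the heuristic attempt count) changes, consistent with the roughly half-second one-time compilation cost noted in the profiling section below.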

HREX MD timings

100 frames, no water sampling. As expected, the performance improvement is significant for vacuum and marginal for solvent, where sampling and computation of the $u_{kl}$ matrix account for most of the runtime.

[figure: benchmarks]

Profiling permutation moves

With 4 windows, the overhead of JIT compilation is significant:

With 48 windows, the optimized version is ~10x faster:

Time to run a single permutation move (consisting of $L^3$ swap attempts), for the current version ("ref") and the version in this PR ("fast"). The time to compute the $u_{kl}$ matrix ("u_kl") is shown for reference. Note: this does not include the initial JIT compilation time for the fast version, which only runs once for a given number of windows and takes about half a second.

[figure: timings]

Other changes

TODO

maxentile commented 3 months ago

> This PR adds a specialized implementation of a batch of neighbor swap moves that leverages JAX's jit compilation for performance. Benchmarks show that this implementation is significantly more performant (~10x) for large numbers of windows, although the initial compilation does introduce a small amount of overhead with a small number of windows.

Nice! And the attached profiling results at K=4 and K=48 are informative.

Could you please also attach to the PR description a plot of total time to draw an approximate sample from p(permutation | xs) as a function of K?

(x-axis: grid of K up to K=48, y-axis: time in seconds)

(y-axis currently expected to be (K^3 * avg_cost_per_swap_attempt) + (K^2 cost of computing u_kl matrix) + (any other per-HREX-iteration overheads), ignoring any one-time overheads such as JIT compilation... (unless JIT compiler needs to be invoked once per HREX iteration))
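(Restating that expected scaling as a single expression, with symbols introduced here only for illustration: $t_\mathrm{swap}$ the average cost of a swap attempt, $t_u$ the per-element cost of the $u_{kl}$ matrix, and $c$ any remaining per-iteration overhead:)

$$
T_\mathrm{iter}(K) \approx K^3\, t_\mathrm{swap} + K^2\, t_u + c
$$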

mcwitt commented 3 months ago


Good call. I added this plot to the PR description.

mcwitt commented 3 months ago

Generating random indices and uniform noise in batches (https://github.com/proteneer/timemachine/pull/1294/commits/bb9e075985c0830b5d24e14ef76aa3219bbed355) yields another 5-10x performance improvement for permutation moves, for a total of 50-100x improvement relative to master.
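For illustration (hypothetical names, not the code in the linked commit), the change amounts to replacing many tiny per-attempt draws with two vectorized draws per batch:

```python
import jax

def per_attempt_draws(key, n_attempts, n_states):
    # one key split plus two tiny draws per swap attempt (many small dispatches)
    samples = []
    for _ in range(n_attempts):
        key, k1, k2 = jax.random.split(key, 3)
        samples.append((jax.random.randint(k1, (), 0, n_states - 1),
                        jax.random.uniform(k2)))
    return samples

def batched_draws(key, n_attempts, n_states):
    # all randomness for the batch drawn in two vectorized calls
    k1, k2 = jax.random.split(key)
    return (jax.random.randint(k1, (n_attempts,), 0, n_states - 1),
            jax.random.uniform(k2, (n_attempts,)))
```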