choderalab / chiron

Differentiable Markov Chain Monte Carlo
https://github.com/choderalab/chiron/wiki
MIT License

Parallelism #36

Open chrisiacovella opened 5 months ago

chrisiacovella commented 5 months ago

In this codebase we will have two different levels of parallelism:

  1. parallelism of calculations within a given simulation (e.g., vector operations)
  2. parallel sampling algorithms (e.g., replica exchange)

The first will be handled automatically through our use of JAX functions for computation. At this point, I don't think there will be much benefit to having a single simulation span multiple GPUs (either on the same node or across nodes). We will probably take the approach of using a single GPU for a single simulation (although that may of course change; JAX has a number of utilities for further parallelizing calculations in this regard).
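As a rough illustration of the first level, here is a minimal sketch, assuming a hypothetical toy system (N particles tethered to reference positions by harmonic springs) rather than anything in chiron's actual API. The energy is written with whole-array jax.numpy operations, so XLA parallelizes the per-particle work on whichever single device holds the arrays; jax.jit compiles it and jax.grad supplies the forces. The names harmonic_energy, energy_fn, and force_fn are illustrative only.

```python
import jax
import jax.numpy as jnp


def harmonic_energy(positions, reference, k=10.0):
    """Hypothetical toy potential: U = 0.5 * k * sum_i |r_i - r_i^0|^2.

    Written with whole-array operations (no Python loop over particles),
    so XLA can parallelize the per-particle arithmetic on one device.
    """
    displacement = positions - reference  # shape (N, 3)
    return 0.5 * k * jnp.sum(displacement**2)


# jit compiles the energy once; grad differentiates it w.r.t. positions (argument 0).
energy_fn = jax.jit(harmonic_energy)
force_fn = jax.jit(lambda pos, ref: -jax.grad(harmonic_energy)(pos, ref))

# One simulation -> one device: the first available device (a GPU if present, else CPU).
device = jax.devices()[0]
n_particles = 1000
reference = jax.device_put(jnp.zeros((n_particles, 3)), device)
positions = jax.device_put(jnp.full((n_particles, 3), 0.1), device)

energy = energy_fn(positions, reference)  # about 0.5 * 10 * 3000 * 0.01 = 150
forces = force_fn(positions, reference)   # -k * (r - r0), approximately -1.0 per component
print(energy, forces.shape)
```

The same code runs unchanged on a CPU or a single GPU; splitting one simulation across several devices would need JAX's sharding machinery, which is the kind of further parallelism mentioned above.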

The second could potentially be handled via, say, pmap in JAX, as JAX supports multi-host computations (e.g., https://jax.readthedocs.io/en/latest/multi_process.html). However, it is not yet clear whether it would be more straightforward to simply use MPI, since what we are doing is not so much a single multi-host computation as a set of embarrassingly parallel computations that need to communicate periodically (e.g., to attempt swaps).

In some algorithms/cases, we may actually simply run computations entirely separately and combine the results at a later time (which would not require any changes to the code). While less "clean", this is sometimes necessary to get faster throughput: many short jobs that each request fewer resources can often move through queues faster, and they are less likely to be killed or to lose work on a queue that allows preemption.
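For the second level, here is a minimal sketch of what a JAX-only approach might look like on a single host, assuming a hypothetical harmonic potential and hypothetical helper names (potential_energy, mc_step, attempt_swap) that are not part of chiron. Each replica takes an independent Metropolis step; jax.vmap batches the replicas on one device, and jax.pmap maps the same per-replica function across local devices, one replica per device. The exchange between states with inverse temperatures beta_i and beta_j uses the standard replica-exchange criterion, accepting with probability min(1, exp[(beta_i - beta_j)(U(x_i) - U(x_j))]). The swap test is the periodic communication step; in an MPI-based design it is where ranks would exchange energies instead.

```python
import jax
import jax.numpy as jnp


def potential_energy(x, k=1.0):
    """Hypothetical toy potential: independent harmonic wells, U(x) = 0.5 * k * sum(x^2)."""
    return 0.5 * k * jnp.sum(x**2)


def mc_step(key, x, beta, step_size=0.1):
    """One Metropolis move for a single replica at inverse temperature beta."""
    key_prop, key_acc = jax.random.split(key)
    proposal = x + step_size * jax.random.normal(key_prop, x.shape)
    delta_u = potential_energy(proposal) - potential_energy(x)
    # Accept with probability min(1, exp(-beta * delta_u)).
    accept = jnp.log(jax.random.uniform(key_acc)) < -beta * delta_u
    return jnp.where(accept, proposal, x)


# Level 2 on one device: vmap runs mc_step over a leading "replica" axis.
batched_step = jax.jit(jax.vmap(mc_step))


def swap_log_acceptance(beta_i, beta_j, u_i, u_j):
    """Log of the replica-exchange acceptance ratio for states i and j."""
    return (beta_i - beta_j) * (u_i - u_j)


def attempt_swap(key, xs, betas, i, j):
    """Propose exchanging the configurations currently held by states i and j."""
    u = jax.vmap(potential_energy)(xs)
    log_acc = swap_log_acceptance(betas[i], betas[j], u[i], u[j])
    accept = jnp.log(jax.random.uniform(key)) < log_acc
    swapped = xs.at[i].set(xs[j]).at[j].set(xs[i])
    return jnp.where(accept, swapped, xs), accept


# Usage: 4 replicas with linearly spaced inverse temperatures, 3 coordinates each.
n_replicas, n_coords = 4, 3
betas = jnp.linspace(1.0, 0.25, n_replicas)
xs = jnp.zeros((n_replicas, n_coords))
key = jax.random.PRNGKey(0)
for _ in range(100):
    key, step_key, swap_key = jax.random.split(key, 3)
    xs = batched_step(jax.random.split(step_key, n_replicas), xs, betas)
    # For brevity, only the neighboring states 0 and 1 attempt a swap here.
    xs, _ = attempt_swap(swap_key, xs, betas, 0, 1)

# Multi-device variant: pmap places one replica on each local device.
n_dev = min(jax.local_device_count(), n_replicas)
pmapped_step = jax.pmap(mc_step)
xs_dev = pmapped_step(jax.random.split(key, n_dev), xs[:n_dev], betas[:n_dev])
```

Because the per-replica update and the swap test are separate functions, the transport around them (vmap, pmap, or MPI) could in principle be swapped without touching the sampling logic.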

This issue serves mostly as a placeholder for future discussion and exploration, as the approach we choose will impact design decisions after our initial release.