Open Lnaden opened 1 year ago
Note: After some more testing, I think I was instantiating the Solver classes every time a new MBAR was created, and that was what was slowing down my testing. I also had a native Python `sum` function exposed instead of the numpy/jax `.sum`, which also didn't help.
I think I can remove all of the static method generator functions and go back to just clean class methods, and do the pytree registration for good measure just in case. It will clean up the code and maintain speed. Something for me to check and benchmark when I can.
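The re-instancing slowdown mentioned in this note can be sketched in a few lines of JAX. This is a minimal illustration, not pymbar code: `make_solver`, `total`, and `trace_count` are hypothetical names. It relies on the fact that a Python side effect inside a jitted function runs only while JAX traces it, so the counter shows how often compilation work is repeated.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # incremented only while JAX traces the function


def make_solver():
    """Hypothetical factory: builds a fresh jitted function on every call."""

    def total(x):
        global trace_count
        trace_count += 1  # Python side effect, executed at trace time only
        return jnp.sum(x)  # array-native sum rather than Python's builtin sum

    return jax.jit(total)


x = jnp.arange(5.0)

# Reusing one jitted function: traced once, the second call hits the cache.
solver = make_solver()
solver(x)
solver(x)
traces_when_reused = trace_count

# Rebuilding the jitted function each time: every call traces again.
trace_count = 0
for _ in range(3):
    make_solver()(x)
traces_when_rebuilt = trace_count

print(traces_when_reused, traces_when_rebuilt)  # 1 3
```

Creating the solver once and reusing it keeps the compiled function cached, which is the behaviour the fix aims for.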
After some testing, I found the speed gain from having pure static generated methods is negligible. The slowdown in my earlier testing came from re-instancing the solver, and thus re-JIT'ing everything, each time a new MBAR was created (fixed in an earlier commit).
After testing, here are the results:
Timing the `test_protocols` test with static-generated methods as the relative baseline: with PyTree registration the test is 99% as fast on average; without PyTree registration it is 95% as fast on average.
So I've opted to use the JAX PyTree registration method and simplify the code substantially by moving all methods back into self methods.
So this version is more pythonic, easier to read, almost as fast as pure static methods, and overall implements the API listed in #496.
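For readers unfamiliar with the PyTree approach, here is a minimal sketch of registering a class so that its ordinary `self` methods can be jitted. It uses `jax.tree_util.register_pytree_node_class`; the `ToySolver` class and its `scale` attribute are hypothetical and only illustrate the pattern, not the actual `MBARSolver` implementation.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ToySolver:
    """Hypothetical solver-like class; not the pymbar MBARSolver."""

    def __init__(self, scale):
        self.scale = scale

    def tree_flatten(self):
        # Children are the dynamic values JAX should trace;
        # aux_data holds static configuration (none in this sketch).
        children = (self.scale,)
        aux_data = None
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuild the instance from the flattened pieces.
        return cls(*children)

    @jax.jit
    def scaled_sum(self, x):
        # `self` is a valid jit argument because the class is a PyTree.
        return self.scale * jnp.sum(x)


result = ToySolver(2.0).scaled_sum(jnp.arange(4.0))
print(result)  # 12.0
```

With registration in place, `self` is flattened into traced leaves and static data instead of being rejected as an unsupported argument type.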
This is ready for review
The only outstanding question I have is a naming convention: do we want to keep the name "accelerator", as I have in most places, or "solver", which I have in a few others? They are just solvers accelerated by the different libraries. The only API concern here is the keyword `accelerator=...` I added to the MBAR object. Whatever we set, we won't want to change until 5.0, so I want to set it now.
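To make the keyword question concrete, here is a rough sketch of how an `accelerator=...` argument might be resolved at instantiation. Everything in it is an assumption for illustration: `ToyMBAR`, `resolve_accelerator`, the registry contents, and the fall-back-to-NumPy behaviour are hypothetical and are not taken from the pymbar code.

```python
import importlib.util

import numpy as np

# Hypothetical registry: accelerator name -> module that must be importable.
_ACCELERATOR_MODULES = {"numpy": "numpy", "jax": "jax"}


def resolve_accelerator(name):
    """Return a usable accelerator name, falling back to NumPy (assumed policy)."""
    key = name.lower()
    if key not in _ACCELERATOR_MODULES:
        raise ValueError(
            f"Unknown accelerator {name!r}; expected one of {sorted(_ACCELERATOR_MODULES)}"
        )
    if importlib.util.find_spec(_ACCELERATOR_MODULES[key]) is None:
        return "numpy"  # requested library is not installed
    return key


class ToyMBAR:
    """Hypothetical stand-in for an MBAR-like object; not the pymbar class."""

    def __init__(self, u_kn, accelerator="numpy"):
        self.accelerator = resolve_accelerator(accelerator)
        self.u_kn = np.asarray(u_kn, dtype=float)

    def total_energy(self):
        return float(self.u_kn.sum())


mbar = ToyMBAR([[1.0, 2.0], [3.0, 4.0]], accelerator="numpy")
print(mbar.accelerator, mbar.total_energy())  # numpy 10.0
```

Because the keyword is part of the public constructor signature, renaming it later would be an API break, which is why the name needs to be settled before release.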
Also you can either fix this in your PR, or merge this one in https://github.com/choderalab/pymbar/pull/510 to fix the RTD builds
Supersedes #508
This PR overhauls how the accelerator logic is chosen, and gives that power to the MBAR instantiation process as well as to anyone using the functional solver library directly.
This is a rethinking of how to handle different libraries with identically functioning methods where we only want to select one or the other in Python. This also adds some future-proofing in case other accelerators are wanted later.
- `mbar_solvers` file has now been moved to its own module.
- `MBARSolver` general class as was suggested in #508 by @jchodera
- `mbar_solvers` import, so any attempt to `import mbar_solvers` or `from mbar_solvers import X,Y,Z,*` behaves identically to how the main branch currently does. Thus the API is preserved and we can keep the 4.y.z version scheme.

~The complicated part of this is that the actual functions (e.g. gradient, log W nk, precondition u_kn, etc.) have to be generated as static methods, not tied to the actual `MBARSolver` class. The way JIT for JAX works is it will serialize anything in the function, so any `object` (i.e. `class` and `self` in function) has every one of its methods/parameters serialized as well at compile time, resulting in a massive slowdown if I try to just leave each of the methods as class methods. I get around this with a number of `generate_static...` methods and replace the actual definitions in the `__init__` while still preserving the doc strings and API. In case you're wondering about the details, see the `MBARSolver.__init__` doc string~ Edit: After testing, with correct function definitions and PyTree assignments, this is not a problem and the implementation is 99% as fast as static methods on average.
Given the drastically different way `mbar_solvers` is loaded, and the massive implementation change, I would like to again formally request input from @mikemhenry and @mrshirts about this approach. @jchodera as you suggested the API idea in #508, I'd like to ask for your feedback as well, given the implementation was not as clean due to the JIT shenanigans.
@invemichele this is the full implementation of the outlined API and features in https://github.com/choderalab/pymbar/issues/496, so any input you have would also be appreciated.