choderalab / pymbar

Python implementation of the multistate Bennett acceptance ratio (MBAR)
http://pymbar.readthedocs.io
MIT License

Make using JAX (or any accelerator) an Instanced Python class (and toggle) #509

Open Lnaden opened 1 year ago

Lnaden commented 1 year ago

Supersedes #508

This PR overhauls how the accelerator is chosen and gives that control to the MBAR instantiation process, as well as to anyone using the functional solver library directly.

This rethinks how to handle different libraries with identically functioning methods, where we only want to select one or the other in Python. It also adds some future-proofing in case other accelerators are wanted later.
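A minimal sketch of the idea, assuming a design where each solver instance holds a reference to an array module (`numpy` or `jax.numpy`) and every method calls through it. The names `BackendSolver`, `ToyMBAR`, and `_ACCELERATORS` are hypothetical and are not the pymbar API; only the `accelerator=` keyword mirrors the one discussed in this PR. JAX is registered only if it can be imported.

```python
import numpy as np

# Hypothetical registry mapping accelerator names to array modules.
# NumPy is always available; JAX is registered only if it imports cleanly.
_ACCELERATORS = {"numpy": np}
try:
    import jax.numpy as jnp

    _ACCELERATORS["jax"] = jnp
except ImportError:
    pass


class BackendSolver:
    """Hypothetical solver whose array backend is chosen per instance."""

    def __init__(self, accelerator="numpy"):
        name = accelerator.lower()
        if name not in _ACCELERATORS:
            raise ValueError(
                f"Unknown accelerator {accelerator!r}; "
                f"available: {sorted(_ACCELERATORS)}"
            )
        self.accelerator = name
        self.xp = _ACCELERATORS[name]  # array module used by every method

    def logsumexp_rows(self, a):
        # Identical code for every backend: only self.xp differs.
        xp = self.xp
        a = xp.asarray(a)
        a_max = xp.max(a, axis=1, keepdims=True)
        return xp.log(xp.sum(xp.exp(a - a_max), axis=1)) + a_max[:, 0]


class ToyMBAR:
    """Hypothetical estimator that forwards its accelerator choice."""

    def __init__(self, u_kn, N_k, accelerator="numpy"):
        self.solver = BackendSolver(accelerator)
        self.u_kn = self.solver.xp.asarray(u_kn)
        self.N_k = self.solver.xp.asarray(N_k)
```

Because the backend is an attribute of the instance rather than a module-level switch, two estimators in the same session could use different accelerators, and nothing has to be decided at import time.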

~The complicated part of this is that the actual functions (e.g. gradient, log W_nk, precondition u_kn, etc.) have to be generated as static methods, not tied to the actual MBARSolver class. The way JIT works in JAX, it serializes anything referenced in the function, so any object (i.e. the class and `self` in the function) has every one of its methods and parameters serialized as well at compile time. That results in a massive slowdown if I just leave each of the methods as class methods. I get around this with a number of generate_static... methods that replace the actual definitions in `__init__` while still preserving the docstrings and API. If you're wondering about the details, see the MBARSolver.__init__ docstring.~ Edit: After testing with corrected function definitions and PyTree registration, this is not a problem, and the implementation is 99% as fast as static methods on average.
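A minimal sketch of the PyTree approach from the edit above, modeled on the pattern the JAX FAQ describes for jitting methods (requires JAX). `PyTreeSolver` and `weighted_sum` are hypothetical names, not pymbar's MBARSolver API. Array data goes in the PyTree children so it is traced, while hashable configuration goes in the auxiliary data, which lets `@jax.jit` be applied directly to a regular method instead of a generated static function.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class PyTreeSolver:
    """Hypothetical solver registered as a JAX PyTree.

    JAX knows how to flatten and unflatten this class, so `self` can be
    passed through a jitted method like any other argument.
    """

    def __init__(self, weights, scale):
        self.weights = weights  # array data: traced as a PyTree child
        self.scale = scale      # static config: stored as hashable aux data

    def tree_flatten(self):
        children = (self.weights,)
        aux_data = (self.scale,)
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        (weights,) = children
        (scale,) = aux_data
        return cls(weights, scale)

    @jax.jit
    def weighted_sum(self, x):
        # jnp.sum keeps the reduction vectorized inside the compiled program.
        return self.scale * jnp.sum(self.weights * x)
```

Changing `scale` changes the auxiliary data and therefore triggers a recompile; changing only the values in `weights` (same shape and dtype) reuses the cached compilation.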

Given the drastically different way mbar_solvers is loaded, and the massive implementation change, I would like to again formally request input from @mikemhenry and @mrshirts on this approach. @jchodera, since you suggested the API idea in #508, I'd like to ask for your feedback as well, given that the implementation was not as clean due to the JIT shenanigans.

@invemichele, this is the full implementation of the API and features outlined in https://github.com/choderalab/pymbar/issues/496, so any input you have would also be appreciated.

codecov[bot] commented 1 year ago

Codecov Report

Merging #509 (00a9cd7) into master (cfe49fc) will increase coverage by 0.88%. The diff coverage is 89.47%.

Lnaden commented 1 year ago

Note: After some more testing, I think I was instancing the Solver classes every time a new MBAR was created, and that was what was slowing down my tests. On top of that, I had the native Python sum function exposed instead of numpy/jax .sum, which also didn't help.
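To illustrate the second issue, here is a hedged sketch using NumPy only: Python's builtin `sum` iterates over the first axis of an array in a Python loop, whereas the array's own `.sum` performs the reduction in a single vectorized call. Under `jax.jit`, that Python loop is also traced one slice at a time into the compiled program. The array `u_kn` below is arbitrary illustrative data.

```python
import numpy as np

u_kn = np.arange(12, dtype=float).reshape(3, 4)  # arbitrary 3x4 example

# Builtin sum: a Python loop computing 0 + u_kn[0] + u_kn[1] + u_kn[2],
# i.e. a column-wise sum built one row at a time.
python_sum = sum(u_kn)

# Array method: the same column-wise reduction in one vectorized call.
numpy_sum = u_kn.sum(axis=0)

# With no axis, the array method reduces over every element instead.
total = u_kn.sum()
```

Note that the two forms only agree for `axis=0`; swapping `sum(x)` for a bare `x.sum()` on a 2D array would also change the result, not just the speed.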

I think I can remove all of the static method generator functions, go back to clean class methods, and do the PyTree registration for good measure. That will clean up the code and maintain speed. Something for me to check and benchmark when I can.

Lnaden commented 1 year ago

After some testing, I found that the speed gain from pure static generated methods is negligible. The slowdown in my earlier tests came from re-instancing the solver, and thus re-JIT'ing everything, each time a new MBAR was created (fixed in an earlier commit).
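A pure-Python toy model of the re-instancing problem, assuming (as in the earlier static-method design) that each solver instance builds its compiled functions itself. Compilation is simulated with a counter rather than real JIT work, and `ToyJITSolver` and `ToyEstimator` are hypothetical names.

```python
class ToyJITSolver:
    """Hypothetical stand-in for a solver that compiles its kernel per instance.

    The kernel is "compiled" on first use and cached on the instance;
    the counter simulates the cost of compilation.
    """

    compilations = 0  # class-wide counter, for illustration only

    def __init__(self):
        self._kernel = None

    def kernel(self):
        if self._kernel is None:
            ToyJITSolver.compilations += 1  # the expensive step
            self._kernel = lambda x: 2 * x
        return self._kernel


class ToyEstimator:
    """Hypothetical estimator that delegates work to a solver."""

    def __init__(self, solver):
        self.solver = solver

    def run(self, x):
        return self.solver.kernel()(x)


# Anti-pattern: a fresh solver per estimator recompiles every time.
for _ in range(5):
    ToyEstimator(ToyJITSolver()).run(1.0)
per_instance = ToyJITSolver.compilations

# Fix: build the solver once and share it across estimators.
ToyJITSolver.compilations = 0
shared = ToyJITSolver()
for _ in range(5):
    ToyEstimator(shared).run(1.0)
shared_count = ToyJITSolver.compilations
```

With five estimators, the per-instance pattern pays for five compilations, while sharing one solver pays for one.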

After testing, here are the results:

Timing the test_protocols test, with the static-generated methods as the relative baseline: the test is 99% as fast on average with PyTree registration, and 95% as fast on average without it.
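The comment does not say exactly how the percentages were computed. Assuming "X% as fast" means a ratio of runtimes against the baseline, the arithmetic is sketched below. Under that reading, 99% corresponds to about 1.0% longer runtime (1/0.99 ≈ 1.0101) and 95% to about 5.3% longer (1/0.95 ≈ 1.0526). `percent_as_fast` is a hypothetical helper, and the inputs are illustrative, not measurements from this PR.

```python
def percent_as_fast(baseline_seconds, candidate_seconds):
    """Express a candidate's speed relative to a baseline, in percent.

    Speed is the reciprocal of runtime, so "X% as fast" means
    X = 100 * baseline_time / candidate_time.
    """
    return 100.0 * baseline_seconds / candidate_seconds


# Illustrative numbers only: a 12.5 s run against a 10 s baseline.
print(percent_as_fast(10.0, 12.5))  # -> 80.0
```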

So I've opted to use JAX PyTree registration and simplify the code substantially by moving all methods back to regular instance (self) methods.

So this version is more Pythonic, easier to read, almost as fast as pure static methods, and overall implements the API listed in #496.

This is ready for review.

The only outstanding question I have is about naming: do we want to keep the name "accelerator", as I have in most places, or use "solver", as I have in a few others (they are just accelerated by the different libraries)? The only API concern here is the accelerator=... keyword I added to the MBAR object. Whatever we choose, we won't want to change it until 5.0, so I want to settle it now.

mikemhenry commented 1 year ago

Also, you can either fix this in your PR or merge https://github.com/choderalab/pymbar/pull/510 to fix the RTD builds.