choderalab / pymbar

Python implementation of the multistate Bennett acceptance ratio (MBAR)
http://pymbar.readthedocs.io
MIT License

Make using JAX (or any accelerator) a toggle #508

Closed Lnaden closed 1 year ago

Lnaden commented 1 year ago

This PR overhauls how the accelerator is chosen, and gives that choice both to the MBAR instantiation process and to anyone using the functional solver library directly.

This is a rethinking of how to handle different libraries with identically functioning methods, where we only want to select one or the other in Python. It also adds some future-proofing in case other accelerators are wanted later.
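A minimal sketch of what an "init_X + globals" toggle can look like, assuming each init function rebinds module-level names that the shared solver code looks up at call time. The names `init_python`, `init_numpy`, `set_accelerator` and `log_sum_exp` are hypothetical and not pymbar's API; pure Python and NumPy stand in for the real accelerators so the sketch runs without JAX.

```python
import math

# Module-level names the solver code calls; rebound by the init_* functions below.
exp = None
log = None
ACCELERATOR = None


def init_python():
    """Bind the pure-Python math routines as the active backend."""
    global exp, log, ACCELERATOR
    exp, log = math.exp, math.log
    ACCELERATOR = "python"


def init_numpy():
    """Bind NumPy routines; raises ImportError if NumPy is missing."""
    import numpy as np

    global exp, log, ACCELERATOR
    exp, log = np.exp, np.log
    ACCELERATOR = "numpy"


_INITS = {"python": init_python, "numpy": init_numpy}


def set_accelerator(name):
    """Switch every module-level routine to the named backend at once."""
    _INITS[name]()


def log_sum_exp(xs):
    """Solver-style code: finds exp/log in module globals at call time."""
    m = max(xs)
    return m + log(sum(exp(x - m) for x in xs))


init_python()  # default chosen at import time
```

The solver body is written once, but the active backend lives in module state, which is the trade-off weighed later in this thread.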

Given the drastically different way the mbar_solvers library is loaded relative to most other Python code, I would like to formally request input from @mikemhenry and @mrshirts about this approach, as well as from anyone else who wants to comment.

@invemichele this is the full implementation of the outlined API and features in #496, so any input you have would also be appreciated.

A possible future task is to split all the import logic out into a separate submodule/folder, just to make the mbar_solvers module simpler to read, especially if we get more accelerators in the future.
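A rough sketch of what a separate import-logic submodule might contain, assuming backends are probed in a fixed preference order. `CANDIDATES`, `available_backends` and `load_backend` are hypothetical names; only `importlib.util.find_spec` and `importlib.import_module` are real standard-library calls.

```python
import importlib
import importlib.util

# Hypothetical preference order for the accelerators (fastest first).
CANDIDATES = ("jax", "numpy")


def available_backends(candidates=CANDIDATES):
    """Names of candidate packages that are installed.

    importlib.util.find_spec locates a top-level package without importing it,
    so probing is cheap and does not run any backend's import-time setup.
    """
    return [name for name in candidates if importlib.util.find_spec(name) is not None]


def load_backend(preferred=None, candidates=CANDIDATES):
    """Import the requested backend, or else the first installed candidate."""
    order = [preferred] if preferred is not None else available_backends(candidates)
    for name in order:
        try:
            return name, importlib.import_module(name)
        except ImportError:
            continue  # installed but broken; try the next one
    raise ImportError(f"no usable backend among {list(order) or list(candidates)}")
```

Keeping the probing separate means the solver module only ever receives an already-imported backend.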

codecov[bot] commented 1 year ago

Codecov Report

Merging #508 (511d867) into master (cfe49fc) will increase coverage by 0.01%. The diff coverage is 98.30%.

Lnaden commented 1 year ago

There is a comparable and probably more Pythonic way to do this that doesn't involve all the globals and the init_X functions, but it obscures what's happening in other ways.

All the methods that have overlapping namespaces could be contained in their own modules and then imported as a single module of methods. E.g. accelerator_operations would be a module that imports pad, s_, jit, optimize, etc., and then we collect those modules into a dictionary in the main mbar_solvers library, keyed by accelerator name, so we can use them like

Hinvg = accelerator_operations["jax"].lstsq(H, g, rcond=-1)[0]

In code we would set something like

accelerator = "jax"
Hinvg = accelerator_operations[accelerator].lstsq(H, g, rcond=-1)[0]

and then whatever MBAR setting is chosen, it uses that accelerator. This still has the problem of the module having "state", as it were, unless all the functional calls were wrapped to accept an "accelerator" keyword or something. I don't yet know how to correctly set that up initially for the JIT'd functions, though, and would have to think about it.

Something to consider.
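A minimal sketch of the dictionary-of-namespaces idea with an explicit `accelerator` keyword, so the solver routine carries no module state. `types.SimpleNamespace` stands in for the per-accelerator modules, and pure Python and NumPy stand in for JAX; the table contents and `log_sum_exp` are hypothetical. How the JIT'd functions would be set up under this scheme is left open, as in the comment above.

```python
import math
from types import SimpleNamespace

# Hypothetical "accelerator_operations" table: each entry bundles the routines
# the solver needs under identical names, so solver code is written once.
accelerator_operations = {
    "python": SimpleNamespace(exp=math.exp, log=math.log, total=math.fsum),
}

try:
    import numpy as np

    accelerator_operations["numpy"] = SimpleNamespace(exp=np.exp, log=np.log, total=np.sum)
except ImportError:
    pass  # NumPy backend simply not registered


def log_sum_exp(xs, accelerator="python"):
    """Stateless solver routine: the backend is an explicit argument, not module state."""
    ops = accelerator_operations[accelerator]
    m = max(xs)
    return m + ops.log(ops.total([ops.exp(x - m) for x in xs]))
```

An unknown accelerator name fails loudly with a `KeyError` instead of silently using whatever was set last.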

jchodera commented 1 year ago

This is an incredibly naïve question, but is there a way that the accelerators could all subclass an MBARAccelerator base class with a simple API that exposes the major computations that the accelerator needs to perform?

Each accelerator could be implemented in a separate module (file), and the __init__.py that controls how the pymbar.accelerators subpackage is loaded could load all available accelerators and provide a way to retrieve the current "fastest" and available ones.

The complexities with jax are that the JIT requires some time to create the JITed methods, and that this is currently managed by various decorators at the file level. But these functions could be created and JITed in the `__init__` method by calling `jit(...)` directly.

Could this reduce the complexity in allowing each accelerator to simply be a subclass that manages its own specialized code, or fails at initialization with a common exception if the accelerator cannot be used in the current environment?
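One way the suggestion could look, as a sketch only: an abstract base class, one subclass per accelerator, JIT performed in `__init__` by calling `jax.jit` directly, and a common exception when a backend cannot be used. `MBARAccelerator` is the name from the comment; `AcceleratorUnavailable`, `REGISTRY`, `get_accelerator` and the single `logsumexp` operation are hypothetical. `jax.jit` and `jax.scipy.special.logsumexp` are real JAX functions; without JAX installed the registry falls back to the pure-Python class.

```python
import abc
import math


class AcceleratorUnavailable(RuntimeError):
    """Hypothetical common exception raised when a backend cannot be used here."""


class MBARAccelerator(abc.ABC):
    """Base class exposing the computations an accelerator must provide."""

    name = "base"

    @abc.abstractmethod
    def logsumexp(self, xs):
        """Return log(sum(exp(x) for x in xs)) in a numerically stable way."""


class PythonAccelerator(MBARAccelerator):
    name = "python"

    def logsumexp(self, xs):
        m = max(xs)
        return m + math.log(math.fsum(math.exp(x - m) for x in xs))


class JaxAccelerator(MBARAccelerator):
    name = "jax"

    def __init__(self):
        try:
            import jax
            import jax.numpy as jnp
            from jax.scipy.special import logsumexp
        except ImportError as err:
            raise AcceleratorUnavailable("JAX is not installed") from err
        self._jnp = jnp
        # JIT once, at construction time, instead of via module-level decorators.
        self._logsumexp = jax.jit(logsumexp)

    def logsumexp(self, xs):
        return float(self._logsumexp(self._jnp.asarray(xs)))


# Ordered by preference: the first class that initializes successfully wins.
REGISTRY = [JaxAccelerator, PythonAccelerator]


def get_accelerator(preferred=None):
    candidates = [cls for cls in REGISTRY if preferred in (None, cls.name)]
    for cls in candidates:
        try:
            return cls()
        except AcceleratorUnavailable:
            continue
    raise AcceleratorUnavailable(f"no usable accelerator (requested: {preferred})")
```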

Lnaden commented 1 year ago

> could all subclass an MBARAccelerator base class with a simple API that exposes the major computations

Yes, we could do that, and it would be a much more Pythonic way to do it that doesn't require global methods, which are very non-standard in most Python code I've seen. The main reason I didn't is that the mbar_solvers module is functional in its use, not class-based.

I had thought of putting all of the actual functional code into a single abstract class, so that each accelerator implements its own versions of methods like pad and s_ while the actual solver code stays identical. Keeping the solver code shared has been really useful for implementing JAX without having to write fully separate JAX and non-JAX solutions.
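A sketch of that layout, assuming a template-method design: the abstract class holds the solver code once and calls only backend primitives, which each accelerator subclass supplies. `exp`, `log` and `total` stand in for primitives like pad and s_; `SolverBackend`, `PurePythonBackend` and `normalize_log_weights` are hypothetical names.

```python
import abc
import math


class SolverBackend(abc.ABC):
    """Solver code lives here once; subclasses supply only the primitives."""

    # Primitives each accelerator must implement.
    @abc.abstractmethod
    def exp(self, x): ...

    @abc.abstractmethod
    def log(self, x): ...

    @abc.abstractmethod
    def total(self, xs): ...

    # Shared solver code, identical for every backend.
    def logsumexp(self, xs):
        m = max(xs)
        return m + self.log(self.total([self.exp(x - m) for x in xs]))

    def normalize_log_weights(self, log_w):
        """Shift log-weights so that the weights sum to one."""
        shift = self.logsumexp(log_w)
        return [lw - shift for lw in log_w]


class PurePythonBackend(SolverBackend):
    def exp(self, x):
        return math.exp(x)

    def log(self, x):
        return math.log(x)

    def total(self, xs):
        return math.fsum(xs)
```

A JAX subclass would override the same three primitives; the abstract class itself cannot be instantiated, so a backend that forgets a primitive fails at construction.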

However, doing so requires changing the API of mbar_solvers to be instance-based, which would require bumping the major version (i.e. pymbar 4.x -> pymbar 5.x) to keep with semantic versioning. In theory, we could set default (JAX-based) methods in the mbar_solvers module, which keeps the API intact and doesn't require a major version update; we could then add a DeprecationWarning to them if we really want to.
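A minimal sketch of keeping the functional API while moving the logic into a class, assuming one default instance created at import and thin module-level forwarders. `DefaultAccelerator`, `_expose` and the two method names are hypothetical; `warnings.warn` with `DeprecationWarning` is the standard-library mechanism.

```python
import math
import warnings


class DefaultAccelerator:
    """Hypothetical stand-in for whichever accelerator class is loaded by default."""

    def logsumexp(self, xs):
        m = max(xs)
        return m + math.log(math.fsum(math.exp(x - m) for x in xs))

    def normalize_log_weights(self, log_w):
        shift = self.logsumexp(log_w)
        return [lw - shift for lw in log_w]


# One default instance, created when the module is imported.
_default_accelerator = DefaultAccelerator()


def _expose(method_name, deprecated=False):
    """Build a module-level function that forwards to the default instance."""

    def forward(*args, **kwargs):
        if deprecated:
            warnings.warn(
                f"{method_name}() at module level is deprecated; "
                "call it on an accelerator instance instead",
                DeprecationWarning,
                stacklevel=2,
            )
        # Looked up at call time, so swapping _default_accelerator takes effect.
        return getattr(_default_accelerator, method_name)(*args, **kwargs)

    forward.__name__ = method_name
    return forward


# Existing functional API names keep working unchanged.
logsumexp = _expose("logsumexp")
# Example of opting one function into a DeprecationWarning.
normalize_log_weights = _expose("normalize_log_weights", deprecated=True)
```

Callers of the module-level functions see no change, which is what allows this to ship without a major version bump.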

I'd like to avoid this init_X method if possible, keep the same code paths, and not break the API, all at the same time, so I'm open to ideas.

> The complexities with jax are that the JIT requires some time to create the JITed methods...

Yes, it's slightly more complicated by the use case/issue I've been resolving in the PRs leading up to this one. The complication with JAX plus decorated (or module-load-time) JIT calls is the 64-bit precision setting. #496 has the use case where other code uses JAX with its default 32-bit precision while PyMBAR needs 64-bit mode, and enabling 64-bit mode breaks those other JAX uses due to incompatible JAX-based models. #497 and my own tests confirm we can't really get by with 32-bit in PyMBAR. The PRs I did before this one stagger the JIT and the setting of 64-bit mode correctly and safely, allowing pymbar to be loaded without breaking other JAX uses.

If we can embed the JIT in a class, we can still stagger it. With JAX caching the JIT'd functions, we should be okay.
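A sketch of embedding the JIT in a class so it can still be staggered: nothing is configured or compiled at import or construction, and the precision setup runs just before the first compile. `jit` and `configure` are injected so the ordering can be shown without JAX; with JAX they would be something like `jax.jit` and a call to `jax.config.update("jax_enable_x64", True)`, and the kernel would be written with `jax.numpy`. `LazyJitAccelerator` is hypothetical, and how the earlier PRs actually stagger these steps is not reproduced here.

```python
import functools
import math


class LazyJitAccelerator:
    """Compiles nothing until a kernel is first used."""

    def __init__(self, jit, configure):
        self._jit = jit
        self._configure = configure
        self._configured = False

    def _ensure_configured(self):
        # Precision must be settled before the first trace/compile.
        if not self._configured:
            self._configure()
            self._configured = True

    @functools.cached_property
    def logsumexp(self):
        self._ensure_configured()

        def kernel(xs):
            m = max(xs)
            return m + math.log(math.fsum(math.exp(x - m) for x in xs))

        return self._jit(kernel)  # compiled once, then cached on the instance
```

`functools.cached_property` keeps the compiled function on the instance, so later calls reuse it without configuring or compiling again.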

mikemhenry commented 1 year ago

I think this is a good approach to solve our current problem in a way that keeps the API the same, but I agree that for pymbar 5.x, we should rework this to not use globals.

Lnaden commented 1 year ago

> a good approach to solve our current problem

I think the current use case (#496) is mitigated by the earlier PRs (#504 and #505), so this is just an extra feature with a bit of future-proofing (i.e. support for other accelerators) that I don't know if we will use any time soon, if at all.

> in a way that keeps the API the same, but I agree that for pymbar 5.x, we should re-work this to not use globals.

I do think I can rework this PR to do what @jchodera suggested, making it much more Pythonic with an abstract-class-based system that preserves code paths to avoid duplication. I would then make the mbar_solvers module try to load a default accelerator like we do now, and expose its methods to preserve the API so we can keep this as pymbar 4.{Y+1}. I would much rather do that than this init_X + global approach if the plan is to scrap it anyway.

How does that sound?

Lnaden commented 1 year ago

Due to the dramatically different approach I went with to make this much more Pythonic, I am closing this in favor of a new PR where we can discuss it. I'll migrate relevant discussions from here to it.