Lnaden closed this 1 year ago.
There is a comparable and probably more Pythonic way to do this that doesn't involve all the globals and the init_X functions, but it obscures what's happening in other ways.
All the methods with overlapping namespaces could be contained in their own modules and then imported as a single module of methods. E.g. each accelerator would get a module that imports its own pad, s_, jit, optimize, etc., and the main mbar_solvers library would collect those modules into an accelerator_operations dictionary, so we can use them like
Hinvg = accelerator_operations["jax"].lstsq(H, g, rcond=-1)[0]
In code we would set something like
accelerator = "jax"
Hinvg = accelerator_operations[accelerator].lstsq(H, g, rcond=-1)[0]
and then whichever MBAR setting is chosen, it uses that accelerator. This still has the problem of the module having "state", as it were, unless all the functional calls were wrapped to accept an "accelerator" keyword or something similar. I don't yet know how to set that up correctly for the JIT'd functions at initialization and would have to think about it.
Something to consider.
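A minimal sketch of the dictionary-of-modules idea combined with the "accelerator" keyword wrapping. It assumes only NumPy is installed and registers JAX only if it imports; the SimpleNamespace entries stand in for the per-accelerator modules, and newton_step is a hypothetical helper, not a pymbar function.

```python
from types import SimpleNamespace

import numpy as np

# Each entry plays the role of one per-accelerator module of methods.
accelerator_operations = {
    "numpy": SimpleNamespace(
        lstsq=np.linalg.lstsq,
        pad=np.pad,
        jit=lambda f: f,  # no-op "JIT" for plain NumPy
    ),
}

try:  # register JAX only when it is importable
    import jax
    import jax.numpy as jnp

    accelerator_operations["jax"] = SimpleNamespace(
        lstsq=jnp.linalg.lstsq,
        pad=jnp.pad,
        jit=jax.jit,
    )
except ImportError:
    pass


def newton_step(H, g, accelerator="numpy"):
    """Hypothetical solver helper: the backend is chosen per call, not per module."""
    ops = accelerator_operations[accelerator]
    return ops.lstsq(H, g, rcond=-1)[0]


H = np.array([[2.0, 0.0], [0.0, 4.0]])
g = np.array([2.0, 8.0])
x = newton_step(H, g)  # approximately [1.0, 2.0]
```

Passing the accelerator explicitly avoids module-level state, at the cost of threading the keyword through every call; how pre-JIT'd functions would pick it up is still the open question raised above.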
This is an incredibly naïve question, but is there a way that the accelerators could all subclass an MBARAccelerator base class with a simple API that exposes the major computations that the accelerator needs to perform?
Each accelerator could be implemented in a separate module (file), and the __init__.py that controls how the pymbar.accelerators subpackage is loaded could load all available accelerators and provide a way to retrieve the current "fastest" and available ones.
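One way the subclass idea and the subpackage loader could fit together, as a sketch only: MBARAccelerator comes from the question above, while AcceleratorUnavailable, the priority attribute, available_accelerators and get_accelerator are hypothetical names, and only a NumPy backend is filled in.

```python
import abc

import numpy as np


class AcceleratorUnavailable(RuntimeError):
    """Common exception raised when a backend cannot run in this environment."""


class MBARAccelerator(abc.ABC):
    name = "base"
    priority = 0  # higher means preferred ("fastest")

    @abc.abstractmethod
    def logsumexp(self, a, axis=None):
        """Numerically stable log(sum(exp(a))) along ``axis``."""


class NumpyAccelerator(MBARAccelerator):
    name = "numpy"
    priority = 0

    def logsumexp(self, a, axis=None):
        a = np.asarray(a)
        a_max = np.max(a, axis=axis, keepdims=True)
        s = np.sum(np.exp(a - a_max), axis=axis)
        return np.log(s) + np.squeeze(a_max, axis=axis)


# A JAX subclass with a higher priority would be appended here when importable.
_REGISTRY = [NumpyAccelerator]


def available_accelerators():
    """Instantiate every registered backend, skipping those that cannot run."""
    found = []
    for cls in _REGISTRY:
        try:
            found.append(cls())
        except AcceleratorUnavailable:
            continue
    return found


def get_accelerator(name=None):
    """Return the named backend, or the highest-priority available one."""
    candidates = available_accelerators()
    if name is not None:
        candidates = [acc for acc in candidates if acc.name == name]
    if not candidates:
        raise AcceleratorUnavailable(f"no usable accelerator for {name!r}")
    return max(candidates, key=lambda acc: acc.priority)
```

Instantiation doubles as the availability check here, so a backend that cannot load simply drops out of the candidate list.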
The complexities with jax are that the JIT requires some time to create the JITed methods, and that this is currently managed by various decorators at the file level. But these functions could be created and JITed in the __init__ method by calling jit(...) directly.
Could this reduce the complexity by allowing each accelerator to simply be a subclass that manages its own specialized code, or fails at initialization with a common exception if the accelerator cannot be used in the current environment?
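A sketch of a JAX backend that builds its JIT'd functions in __init__ and turns an import failure into a common exception. JaxAccelerator, AcceleratorUnavailable and log_normalize are hypothetical names; the JAX calls used (jax.jit, jax.numpy.linalg.lstsq, jax.scipy.special.logsumexp) are standard JAX APIs. It is written standalone so it runs on its own.

```python
class AcceleratorUnavailable(RuntimeError):
    """Common exception for backends that cannot run in this environment."""


class JaxAccelerator:
    name = "jax"

    def __init__(self):
        try:
            import jax
            import jax.numpy as jnp
            from jax.scipy.special import logsumexp
        except ImportError as exc:
            raise AcceleratorUnavailable("JAX is not installed") from exc

        def log_normalize(log_w):
            # Shift log-weights so that exp(log_w) sums to one.
            return log_w - logsumexp(log_w)

        # JIT happens at construction time instead of via module-level decorators.
        self.log_normalize = jax.jit(log_normalize)
        self.lstsq = jnp.linalg.lstsq


try:
    backend = JaxAccelerator()
except AcceleratorUnavailable:
    backend = None  # a caller would fall back to, e.g., a NumPy backend here
```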
> could all subclass an MBARAccelerator base class with a simple API that exposes the major computations
Yes, we could do that, and it would be a much more Pythonic way to do it that doesn't require global methods, which are very non-standard in most Python code I've seen. The main reason I didn't is that the mbar_solvers module is functional in its use, not class-based.
I had thought of putting all of the actual functional code into a single abstract class, and then each accelerator could implement its own versions of methods like pad and s_, while still keeping the actual solver code identical. That shared solver code has been really useful for implementing JAX without having to write a fully separate solution for JAX and non-JAX.
However, doing so requires changing the API of mbar_solvers to be instance-based, which would require bumping the major API version (i.e. pymbar 4.x -> pymbar 5.x) to follow semantic versioning. We could in theory set default (jax-based) methods in the mbar_solvers module, which keeps the API intact and doesn't require a major version update, and then add a DeprecationWarning to them if we really want to.
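A minimal sketch of keeping the solver code in one abstract class while each accelerator supplies only primitives such as pad and s_. SolverBase, NumpySolver, prepend_zero and drop_reference are hypothetical names, not pymbar's API; a JAX subclass would override the same two primitives.

```python
import abc

import numpy as np


class SolverBase(abc.ABC):
    """Shared solver code; subclasses only supply array primitives."""

    # Primitives each accelerator must provide.
    @abc.abstractmethod
    def pad(self, array, pad_width):
        ...

    @property
    @abc.abstractmethod
    def s_(self):
        ...

    # Solver logic written once against the primitives.
    def prepend_zero(self, f_k):
        """Hypothetical routine: put a 0 in front of a free-energy vector."""
        return self.pad(f_k, (1, 0))

    def drop_reference(self, f_k):
        """Return every entry except the first, using the backend's index helper."""
        return f_k[self.s_[1:]]


class NumpySolver(SolverBase):
    def pad(self, array, pad_width):
        return np.pad(array, pad_width)  # constant (zero) padding by default

    @property
    def s_(self):
        return np.s_
```

The solver methods never name NumPy or JAX directly, which is what keeps one code path for every accelerator.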
I'd like to avoid this init_X method if possible, keep the same code paths, and not break the API, all at the same time, so I'm open to ideas.
> The complexities with jax are that the JIT requires some time to create the JITed methods...
Yes, it's slightly more complicated by the use case/issue I've been resolving in the PRs leading up to this: the 64-bit setting when JAX is combined with decorated (or module-load-time) JIT calls. #496 has the use case where other code uses JAX with its default 32-bit mode while PyMBAR needs 64-bit mode, which breaks the other JAX uses due to incompatible JAX-based models. #497 and my own tests confirm we can't really get by with 32-bit in PyMBAR. The PRs I have done before this one stagger the JIT and the setting of 64-bit mode correctly and safely, allowing pymbar to be loaded without breaking other JAX uses.
If we can embed the JIT in a class, we can still stagger it. With JAX caching the JIT'd functions, we should be okay.
I think this is a good approach to solve our current problem in a way that keeps the API the same, but I agree that for pymbar 5.x we should rework this to not use globals.
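On staggering the JIT and the 64-bit switch inside a class: one option is a lazily built, cached backend, so nothing JAX-related runs until the backend is first requested, and functools.lru_cache makes later requests reuse the same JIT'd functions. get_jax_backend and log_normalize are hypothetical names. jax.config.update("jax_enable_x64", True) is the standard JAX switch, but it is still process-wide once flipped, so this sketch only delays the switch until JAX is actually chosen; it does not isolate it.

```python
import functools


@functools.lru_cache(maxsize=None)
def get_jax_backend():
    """Build the JAX backend on first call; later calls return the cached object."""
    import jax  # deferred: importing this module never touches JAX
    import jax.numpy as jnp
    from jax.scipy.special import logsumexp

    # Enable 64-bit mode only when MBAR actually asks for JAX.
    jax.config.update("jax_enable_x64", True)

    def log_normalize(log_w):
        return log_w - logsumexp(log_w)

    return {"log_normalize": jax.jit(log_normalize), "lstsq": jnp.linalg.lstsq}
```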
> a good approach to solve our current problem
I think the current use case (#496) is mitigated by the earlier PRs (#504 and #505), so this is just an extra feature with a bit of future-proofing (i.e. support for other accelerators) that I don't know if we will even use any time soon, if at all.
> in a way that keeps the API the same, but I agree that for pymbar 5.x, we should re-work this to not use globals.
I do think I can rework this PR to do what @jchodera suggested and make this much more Pythonic with a class-based abstract-class system that preserves code paths to avoid duplication. I would then make the mbar_solvers module try to load a default accelerator like we do now, and then expose the methods to preserve the API, so we can keep it as pymbar 4.{Y+1}. I would much rather do that than the init_X + global approach I have here, if the plan is to scrap it anyway.
How does that sound?
Due to the dramatically different approach I went with to make this much more Pythonic, I am closing this in favor of a new PR, where the discussion can continue. I'll migrate relevant discussions from here to it.
This PR overhauls how the accelerator logic is chosen, and gives that control to the MBAR instantiation process as well as to anyone using the functional solver library directly.
This is a rethinking of how to handle different libraries with identically functioning methods where we only want to select one or the other in Python. It also adds some future-proofing in case other accelerators are wanted later.
- Each accelerator has an init_accelerator method with a matching name (i.e. init_numpy or init_jax).
- The functions in the mbar_solvers namespace are set through Python's global keyword in the init_X method and are therefore cast up to the full mbar_solvers namespace.
- The mbar_solvers module now has module-wide state and exists as ONE OR THE OTHER at any given time, depending on which accelerator was set last. I.e. you cannot have one MBAR object set as numpy and another set as JAX in the same code and expect them to operate with different libraries.

Given the drastically different way the mbar_solvers library is loaded relative to most other Python code, I would like to formally request input from @mikemhenry and @mrshirts about this approach, as well as from anyone else who wants to comment. @invemichele this is the full implementation of the API and features outlined in #496, so any input you have would also be appreciated.
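For readers weighing the approach, a stripped-down sketch of the init_X plus global pattern. Only init_numpy and init_jax come from the PR text; ACCELERATOR, solver_step and _numpy_logsumexp are hypothetical, and the real module rebinds many more functions. It shows why the module has a single, process-wide state: solver_step uses whichever backend was initialized last, for every MBAR object.

```python
import numpy as np

# Module-level names the solver code calls; rebound by the init_* functions.
exp = None
logsumexp = None
ACCELERATOR = None


def _numpy_logsumexp(a):
    a = np.asarray(a)
    a_max = np.max(a)
    return a_max + np.log(np.sum(np.exp(a - a_max)))


def init_numpy():
    """Point the module-wide names at NumPy implementations."""
    global exp, logsumexp, ACCELERATOR
    exp = np.exp
    logsumexp = _numpy_logsumexp
    ACCELERATOR = "numpy"


def init_jax():
    """Point the same names at JAX implementations (requires JAX)."""
    global exp, logsumexp, ACCELERATOR
    import jax
    import jax.numpy as jnp
    from jax.scipy.special import logsumexp as jax_logsumexp

    exp = jnp.exp
    logsumexp = jax.jit(jax_logsumexp)
    ACCELERATOR = "jax"


def solver_step(log_w):
    # Whichever init_* ran last decides the library, for every caller.
    return log_w - logsumexp(log_w)


init_numpy()  # default chosen at import time
```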
A possible future task is to split all the import logic out into a separate submodule/folder, just to make the mbar_solvers module simpler to read, especially if we get more accelerators in the future.