choderalab / pymbar

Python implementation of the multistate Bennett acceptance ratio (MBAR)
http://pymbar.readthedocs.io
MIT License

Optional JAX dependency #500

Closed Lnaden closed 1 year ago

Lnaden commented 1 year ago

@SimonBoothroyd on the conda-forge feedstock has requested a pymbar where JAX is optional, giving a leaner install for use cases (like theirs) where the fully accelerated MBAR solver is not needed. I think this can be done with no modifications to the code, only to the build, since all the JAX calls are optional and checked at runtime.

This would require a Conda and Pip solution at the same time.

The Conda-forge/build solution is fairly easy with the outputs directive allowing multiple builds. I've worked on the QCFractal package which does that and have the example here; it matches this use case pretty precisely.

The Pip side could be solved by making JAX a feature, installed through an optional dependency, instead of a default. The naive solution would make jax an extras_require entry, resulting in two commands: pip install pymbar (without JAX) and pip install pymbar[jax] (with JAX).

I personally am against this approach because I think we want JAX installed by default, but I would like other opinions as well. I don't know of an extras_require (setup.py variant) or a project.optional-dependencies (pyproject.toml variant) mechanism that blocks or omits primary dependencies. There are other options, like having multiple build projects, but given the increased build/maintenance complexity I don't want to explore them if the above is acceptable to everyone else.
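A minimal sketch of the opt-in layout discussed above, written as the keyword arguments one might pass to setuptools.setup. The dependency lists are placeholders rather than pymbar's actual requirements, and build_setup_kwargs is a hypothetical helper used only so the snippet runs without building anything.

```python
def build_setup_kwargs():
    """Return illustrative keyword arguments for setuptools.setup()."""
    return {
        "name": "pymbar",
        # Always installed (placeholder list, not pymbar's real one).
        "install_requires": ["numpy", "scipy"],
        # Installed only when the user asks for the "jax" extra,
        # i.e. `pip install pymbar[jax]`.
        "extras_require": {"jax": ["jax"]},
    }


kwargs = build_setup_kwargs()
print(kwargs["extras_require"])
```

Extras are purely additive: they can add requirements on top of install_requires but cannot take one away, which is why a "JAX by default, opt out on request" scheme does not fit this mechanism.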

mikemhenry commented 1 year ago

+1 on a pymbar-core or something that doesn't have jax on conda forge, and yes I don't think there is a way to do it with pip without doing a hack like setting PYMBAR_NO_JAX=TRUE as an envar and then picking that up in setup.py, but I don't think you can hack that in with a pyproject.toml
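For illustration only, the environment-variable hack mentioned above might look roughly like this inside a setup.py. The base dependency list and the accepted "true" spellings are assumptions, and install_requires here is a hypothetical helper, not code from pymbar.

```python
import os


def install_requires(environ=None):
    """Build the dependency list, dropping JAX if PYMBAR_NO_JAX is set."""
    environ = os.environ if environ is None else environ
    requirements = ["numpy", "scipy"]  # placeholder base dependencies
    # Which spellings count as "true" is an assumption of this sketch.
    flag = environ.get("PYMBAR_NO_JAX", "").strip().lower()
    if flag not in {"1", "true", "yes"}:
        requirements.append("jax")
    return requirements


print(install_requires({}))
print(install_requires({"PYMBAR_NO_JAX": "TRUE"}))
```

One caveat with this pattern: setup.py only runs when building from source, so a pre-built wheel would carry whatever dependency list was computed on the build machine, regardless of the installing user's environment.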

I think I could live with the docs saying "we recommend using conda-forge; if you need to use pip, install it like this: pip install pymbar[jax]". We could also print a warning when a user installs it with plain pip install pymbar and say something like "Warning, installed without jax support, [steps on how to reinstall with jax support]".

Lnaden commented 1 year ago

> We could also print a warning

I don't know how to issue a message at all during a pip install, but it would be easy to add a warning/info message when pymbar is invoked. I'd be in favor of this.

> a hack like setting PYMBAR_NO_JAX=TRUE as an envar

I'd like to avoid hacky solutions when clean solutions exist, even if not for the particular paradigm I suggested (an optional flag that removes dependencies). And I would very much like to avoid anything involving environment variables.

mrshirts commented 1 year ago

I know that alchemlyb is a little leery of jax; @xiki-tempula could you maybe speak up on this and your thoughts? . . .

mrshirts commented 1 year ago

We certainly could have a warning at run time that pymbar was installed without JAX and might be slower, with information on where to get JAX.

xiki-tempula commented 1 year ago

For pymbar 4, I wonder: without JAX, can one not run MBAR at all, or is it just slower?

I think JAX is more of a problem for pip, as it is installed every time one sets up a new env.

For conda, this is much less of an issue, as conda would only need to install JAX once and mamba is much faster than pip. I guess pymbar-core could be one solution; one could also have pymbar=*=*JAX, where a build variant installs JAX as a conda dependency, and pymbar=*=*core, which would be another build that doesn't install JAX. But I feel that for conda, JAX is much less of an issue.

SimonBoothroyd commented 1 year ago

If pip is a sticking point (it seems like there are a few several year old feature requests to support this kind of thing!) could we maybe just punt on that for now, and just split on conda?

That way both ecosystems would still have an identical pymbar package available, and then if people don't want jax for now they'll need to either install themselves or get from conda?

Lnaden commented 1 year ago

> For pymbar 4, I wonder: without JAX, can one not run MBAR at all, or is it just slower?

It just runs slower. We use JAX's JIT for all the acceleration and just provide a passthrough in code the user never sees if JAX is unavailable.
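A minimal sketch of the passthrough pattern described above, assuming nothing about pymbar's actual internals: if JAX imports, its real jit is used; otherwise a stand-in with the same decorator shape returns the function unchanged. The names HAVE_JAX and scaled_sum are hypothetical.

```python
try:
    from jax import jit  # real JIT compiler when JAX is installed

    HAVE_JAX = True
except ImportError:
    HAVE_JAX = False

    def jit(fn, *args, **kwargs):
        """Passthrough stand-in: hand back the function unchanged."""
        return fn


@jit
def scaled_sum(a, b):
    # Same source either way; only the decorator's behavior differs.
    return 2.0 * a + b


print(HAVE_JAX, float(scaled_sum(1.0, 3.0)))
```

Because the decorated function has the same call signature in both branches, calling code never needs to know which one it got, which matches the "code the user never sees" description.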

> could we maybe just punt on that for now, and just split on conda?

Possibly yes-ish. The main problem is that we have dependencies specified in the setup.py file, and if we don't very carefully tell conda how to build it, conda-build will yell at us for installing dependencies through pip that were not resolved through conda-build. I think it's doable without too much trouble, though.

I also think several people seem to be leaning towards optional or opt-in style JAX based on the comments here, so it may be worth doing both anyway.

xiki-tempula commented 1 year ago

> It just runs slower. We use JAX's JIT for all the acceleration and just provide a passthrough in code the user never sees if JAX is unavailable.

If this is the case, I don't think there is a need for a different version? I was worried that without JAX, one could not use MBAR. But if it is just a speed issue, printing a warning when running without JAX should be fine. When deploying the package, shipping both the pip and conda versions without JAX and just documenting it should be fine?

mikemhenry commented 1 year ago

Hmmm, I do think that maybe at run time we do just issue a warning if JAX isn't installed and say things could be faster, but then have a pymbar-jax conda-forge package and a pip install pymbar[jax] setup.

I do like it when we make it easy for users to do the right thing, but perhaps in this case since no functionality is lost, a warning is enough?

mrshirts commented 1 year ago

I would say that mbar w/o jax is fully functional, so issuing a warning saying "Jax not installed, could be faster" or something like that is a reasonable compromise. It should be a warning that is issued every time - I don't want people publishing papers about how pymbar is bad because it's not accelerated (it happens).

I don't ever see a continual forking of packages - I think that there might be a "takes more to install but is faster" package and a default, but never more than that.

mikemhenry commented 1 year ago

> I would say that mbar w/o jax is fully functional, so issuing a warning saying "Jax not installed, could be faster" or something like that is a reasonable compromise. It should be a warning that is issued every time - I don't want people publishing papers about how pymbar is bad because it's not accelerated (it happens).

I would envision the warning being printed anytime we hit a code path in the lib that could be accelerated.
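As a rough sketch of that per-code-path idea (not anything taken from pymbar), a decorator can emit a warning on each call of a function that JAX could have accelerated. The names accelerable and log_sum are hypothetical, and the accelerated flag is forced off here purely to demonstrate the warning.

```python
import functools
import warnings


def accelerable(accelerated):
    """Mark a function as one that an accelerator could speed up."""

    def decorate(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            if not accelerated:
                warnings.warn(
                    f"{fn.__name__} is running without JAX and could be faster",
                    RuntimeWarning,
                    stacklevel=2,  # attribute the warning to the caller
                )
            return fn(*args, **kwargs)

        return wrapper

    return decorate


@accelerable(accelerated=False)  # forced off to show the warning
def log_sum(values):
    return sum(values)


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # record every occurrence
    log_sum([1, 2])
    log_sum([3, 4])
print(len(caught))
```

Python's default warning filter reports a given warning once per call site, so a strict "every time" policy would need an "always" filter like the one used here, or a direct print.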

Lnaden commented 1 year ago

I've made #503 to turn JAX into an optional dependency.

A warning is issued every time the import block is called from the mbar_solvers code and should be fairly large and visible, but I cannot say if it would be "obvious" enough. I decided to only issue on import rather than every time we could JIT since I think that would be far too annoying to users. Feedback on this idea is welcome.
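A minimal sketch of the import-time variant, under the assumption that the check lives at module level in the solver code. The banner wording and the check_accelerator helper are hypothetical and not the text pymbar actually prints.

```python
import importlib.util
import warnings


def check_accelerator(module_name="jax"):
    """Return True if the accelerator can be imported; warn otherwise."""
    if importlib.util.find_spec(module_name) is not None:
        return True
    warnings.warn(
        "\n"  # leading newline so the banner starts on its own line
        "******* PyMBAR can run faster with JAX *******\n"
        f"'{module_name}' was not found; falling back to the slower path.\n"
        "Install it (e.g. pip install pymbar[jax]) to enable acceleration.",
        RuntimeWarning,
        stacklevel=2,
    )
    return False


# Module-level call: executed when the solver module is first imported.
HAVE_ACCELERATOR = check_accelerator()
print(HAVE_ACCELERATOR)
```

Since Python caches imported modules, a module-level check like this runs once per process rather than once per solver call, which is the trade-off between visibility and annoyance described above.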

I went with the pymbar[jax] and pymbar (no jax) scheme for pip.

For conda-forge, I should be able to handle that fully inside its feedstock, outside of this repo. The main question is the naming scheme. How does this sound?

  • Conda pymbar-core maps to PyPI's pymbar (no JAX)
  • Conda pymbar maps to PyPI's pymbar[jax] (yes JAX)

It does create the small conflation point where pymbar is technically a slightly different dependency set on Conda than on PyPI. But I feel that's okay. Thoughts?

Lnaden commented 1 year ago

The warning looks like this: https://github.com/choderalab/pymbar/actions/runs/5245393269/jobs/9472772645?pr=503#step:6:199

I am going to add a new line at the start so it aligns a bit better.

xiki-tempula commented 1 year ago

For conda-forge, can we do pymbar=*=*core to get the no jax build? pymbar=*=*jax will be the default and gets you the JAX build?

An example would be GROMACS (https://anaconda.org/conda-forge/gromacs/files), where one can do gromacs=*=*cuda* to get the CUDA build and gromacs=*=*openmpi* to get the OpenMPI build.

mrshirts commented 1 year ago

The warning looks fine. I'll defer to others on the best way to install.

Lnaden commented 1 year ago

> For conda-forge, can we do...

Yes we could, and we would have to specify a build string more exactly, like they do in the gromacs feedstock. I would argue that the package specifier {package name}={version}={build selector}, like you have for pymbar=*=*core, is far less common for most users than just {package name}. I also don't know how to specify an actual default when you have multiple build versions, unless the user really knows what they want/have, like they might for GROMACS where they'll know if they have GPU/CPU options. The advantage is that everything is under one package, e.g. "conda-forge/gromacs", on the cloud website and in search.
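To make the {package name}={version}={build selector} form concrete, here is a small sketch that mimics only its glob-style matching with the standard library. It is not conda's implementation (conda's real version matching is richer), and the build strings below are made up for illustration.

```python
from fnmatch import fnmatchcase


def matches(spec, name, version, build):
    """Check a 'name=version=build' spec against one concrete package build."""
    parts = spec.split("=")
    # Missing fields default to "*", so a bare name matches every build.
    want_name, want_version, want_build = (parts + ["*", "*"])[:3]
    return (
        want_name == name
        and fnmatchcase(version, want_version)
        and fnmatchcase(build, want_build)
    )


# Hypothetical build strings for two variants of the same version.
print(matches("pymbar=*=*core", "pymbar", "4.0.2", "py311_core"))
print(matches("pymbar=*=*core", "pymbar", "4.0.2", "py311_jax"))
print(matches("pymbar", "pymbar", "4.0.2", "py311_jax"))
```

A bare package name matches every variant, so which build a plain conda install pymbar would pick is left to the solver. That is the "how to specify an actual default" concern raised above.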

Since we only have standard MBAR and accelerated MBAR, I think adding the burden of users to know {package name}={version}={build selector} and then what selector they want might be more than we're looking for, especially if we want users to just default to conda install pymbar and get the accelerated version.

The way I am proposing is to have multiple outputs, like the feedstock here, where we (maintainers) don't have to think about build strings or getting conda-forge to build all the versions through specification, and the users don't have to think at all unless they specifically don't want JAX. We can even make it so that pymbar requires pymbar-core, so users who don't know there is a lighter-weight install will see that and at least have a breadcrumb. The downside is that it creates two separate packages in the cloud and in search rather than one.

If people have strong opinions on either choice, I'm happy to implement either.

xiki-tempula commented 1 year ago

I don't have a strong preference. I'm happy with pymbar-core/pymbar. I kind of like the build variant approach, but if pymbar-core/pymbar is easy for the users and developers, I have no objection to that.

SimonBoothroyd commented 1 year ago

Thanks @Lnaden, the warning also LGTM!

> How does this sound?
>
>   • Conda pymbar-core maps to PyPI's pymbar (no JAX)
>   • Conda pymbar maps to PyPI's pymbar[jax] (yes JAX)

+1 to this approach as it seems to be a reasonably common pattern on conda when omitting chunky dependencies (e.g. matplotlib-base)

xiki-tempula commented 1 year ago

I think the warning looks good as well.

Lnaden commented 1 year ago

@xiki-tempula I thought more about it and there's nothing saying we won't wind up with both in the long run. If we get more/different accelerators as options in the future, having a pymbar-core which has no accelerators and then pymbar=*=*{accelerator}-like syntax might be something that comes up.

xiki-tempula commented 1 year ago

I think this is a really good idea.

mrshirts commented 1 year ago

Though my instinct would probably be to have a no-accelerator conda version and a single accelerated version that uses relatively common, not-weird accelerators. We could potentially have alternate accelerators just on pip.

Lnaden commented 1 year ago

I've merged the PR. I'm going to keep this issue open until we have the Conda-Forge release cut. I'm also going to resolve some other quick items as separate PRs before a new release.

SimonBoothroyd commented 1 year ago

@Lnaden is there anything I can do to help unblock this?

Lnaden commented 1 year ago

I was hoping to get #509 in before this, but that's blocked as I want other people's reviews for it. I think I'll just cut a 4.0.2 for optional JAX and move forward with the other parts later. Let me work on that, I should be able to get something up today at least to PyPI

Lnaden commented 1 year ago

Actually, there is one fragment from #509 I want to get in to fix a SciPy >=1.9 bug, then I'll release. It should be pretty quick.

Lnaden commented 1 year ago

Okay. Pymbar 4.0.2 is on PyPI.

Without JAX: pip install pymbar
With JAX: pip install pymbar[jax]

Working through Conda-Forge version now.

Lnaden commented 1 year ago

@SimonBoothroyd The Conda-Forge version has been merged and is propagating out. I'll keep an eye on it and keep this issue open until I'm sure it's fully out there.

That said, I think this issue is resolved for all intents and purposes. If you're using the PyPI version, it's out there now and ready to go if you want a lighter version of PyMBAR.

Lnaden commented 1 year ago

And with that Conda-forge has:

Without JAX: conda install pymbar-core
With JAX: conda install pymbar

I think that fully resolves this issue! I'm going to close it but can always reopen if there is a need.

Thank you everyone for your contributions to this!

SimonBoothroyd commented 1 year ago

Amazing - thanks so much @Lnaden !