Closed Lnaden closed 1 year ago
+1 on a pymbar-core or something that doesn't have jax on conda forge, and yes I don't think there is a way to do it with pip without doing a hack like setting PYMBAR_NO_JAX=TRUE as an envar and then picking that up in setup.py, but I don't think you can hack that in with a pyproject.toml.
I think I could live with the docs saying "we recommend using conda-forge; if you need to use pip, install it like this: pip install pymbar[jax]". We could also print a warning when a user installs it like pip install pymbar and say something like "Warning, installed without jax support, [steps on how to reinstall with jax support]"
We could also print a warning
I don't know how to issue a message at all during a pip install, but it would be easy to add a warning/info message on pymbar invocation. I'd be in favor of this.
a hack like setting PYMBAR_NO_JAX=TRUE as an envar
I'd like to avoid hacky solutions when there are solutions, just not for the particular paradigm I suggested (optional thing removes dependencies). And I very much would like to avoid envar anything.
I know that alchemlyb is a little leery of jax; @xiki-tempula could you maybe speak up on this and your thoughts? . . .
We certainly could have a warning at run that it's downloaded without jax, and might be slower, with information on where to get jax.
I wonder, for pymbar 4: without JAX, can one not run mbar at all, or is it just slower?
I think JAX is more of a problem for pip, as it is installed every time one sets up a new env.
For conda, this is much less of an issue, as conda would only need to install JAX once and mamba is much faster than pip.
I guess pymbar-core could be one solution, and one could also have it as pymbar=*=*JAX, where a build variant installs JAX as a conda dependency, and pymbar=*=*core, which would be another build that doesn't install JAX. But I feel that for conda, JAX is much less of an issue.
If pip is a sticking point (it seems like there are a few several-year-old feature requests to support this kind of thing!) could we maybe just punt on that for now, and just split on conda?
That way both ecosystems would still have an identical pymbar package available, and then if people don't want jax for now they'll need to either install themselves or get from conda?
I wonder, for pymbar 4: without JAX, can one not run mbar at all, or is it just slower?
It just runs slower. We use JAX's JIT for all the acceleration and just provide a passthrough in code the user never sees if JAX is unavailable.
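The passthrough described above might look roughly like the following. This is only a sketch of the general pattern, not pymbar's actual source: the names HAVE_JAX and sum_of_squares are hypothetical, and the only real API used is jax.jit.

```python
# Sketch only: not pymbar's actual code. HAVE_JAX and sum_of_squares are
# hypothetical names chosen for illustration.
try:
    from jax import jit  # the real JIT decorator when JAX is installed

    HAVE_JAX = True
except ImportError:
    HAVE_JAX = False

    def jit(func):
        """Passthrough stand-in: hand the function back unchanged."""
        return func


@jit
def sum_of_squares(values):
    """Toy numeric kernel that runs with or without JAX."""
    total = 0.0
    for v in values:
        total = total + v * v
    return total
```

Callers use sum_of_squares the same way in both cases; only the speed differs, which is why no functionality is lost when JAX is absent.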
could we maybe just punt on that for now, and just split on conda?
Possibly yes-ish. The main problem is we have dependencies specified in the setup.py file, and if we don't very carefully tell conda how to build it, conda-build will yell at us for installing dependencies through pip that were not resolved through conda-build. I think it's doable without too much issue, though.
I also think that there seems to be several people leaning towards optional or opt-in style JAX based on the comments here, so it may be worth doing both anyways.
It just runs slower. We use JAX's JIT for all the acceleration and just provide a passthrough in code the user never sees if JAX is unavailable.
If this is the case, I don't think there is a need for a different version? I was worried that without JAX, one cannot use mbar. But if it is just a speed issue, printing a warning when it is used without JAX should be fine. When deploying the package, shipping both pip and conda without JAX and just documenting it should be fine?
Hmmm, I do think that maybe at run time we do just issue a warning if JAX isn't installed and say things could be faster, but then have a pymbar-jax conda-forge package and a pip install pymbar[jax] setup.
I do like it when we make it easy for users to do the right thing, but perhaps in this case since no functionality is lost, a warning is enough?
I would say that mbar w/o jax is fully functional, so issuing a warning saying "Jax not installed, could be faster" or something like that is a reasonable compromise. It should be a warning that is issued every time - I don't want people publishing papers about how pymbar is bad because it's not accelerated (it happens).
I don't ever see a continual forking of packages - I think that there might be a "takes more to install but is faster" package and a default, but never more than that.
I would say that mbar w/o jax is fully functional, so issuing a warning saying "Jax not installed, could be faster" or something like that is a reasonable compromise. It should be a warning that is issued every time - I don't want people publishing papers about how pymbar is bad because it's not accelerated (it happens).
I would envision the warning being printed anytime we hit a code path in the lib that could be accelerated.
I've made #503 to turn JAX into an optional dependency.
A warning is issued every time the import block is called from the mbar_solvers code and should be fairly large and visible, but I cannot say if it would be "obvious" enough. I decided to only issue on import rather than every time we could JIT since I think that would be far too annoying to users. Feedback on this idea is welcome.
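An import-time warning of this kind could be sketched as below. This is an assumption-laden illustration rather than the code from #503: the function name warn_if_no_jax and the message wording are hypothetical, and the real warning text is the one shown in the CI log linked in this thread.

```python
import importlib.util
import warnings


def warn_if_no_jax(module_name="jax"):
    """Warn once if the accelerator module is missing; return True if found.

    Hypothetical helper: the name and the message text are illustrative only.
    """
    if importlib.util.find_spec(module_name) is not None:
        return True
    warnings.warn(
        "\nJAX not detected: using unaccelerated code paths, which may be slower."
        "\nInstall with `pip install pymbar[jax]` to enable acceleration.",
        UserWarning,
        stacklevel=2,
    )
    return False


# Evaluated once, when the module is imported, not on every accelerable call.
HAVE_JAX = warn_if_no_jax()
```

Running the check at module level means users see the message once per import instead of once per call that could have been JIT-compiled.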
I went with the pymbar[jax] and pymbar (no jax) scheme for pip.
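For readers unfamiliar with extras, the pip scheme could be expressed in a setup.py along these lines. The dependency lists here are assumptions for illustration and are not copied from pymbar's real setup.py; extras_require and install_requires are the standard setuptools keywords.

```python
# setup.py fragment (sketch). The dependency lists are illustrative
# assumptions, not pymbar's actual requirements.
CORE_REQUIRES = ["numpy", "scipy"]  # assumed always-installed dependencies
JAX_REQUIRES = ["jax", "jaxlib"]    # assumed contents of the "jax" extra

SETUP_KWARGS = dict(
    name="pymbar",
    install_requires=CORE_REQUIRES,        # pulled in by `pip install pymbar`
    extras_require={"jax": JAX_REQUIRES},  # added by `pip install pymbar[jax]`
)

# In a real setup.py these would be passed on to setuptools:
#     from setuptools import setup
#     setup(**SETUP_KWARGS)
```

Note that extras can only add dependencies; there is no keyword that removes an entry from install_requires, which is why JAX has to leave the default set for this to work.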
For conda-forge, I should be able to handle that fully inside its feedstock outside of this repo. The main question is the naming scheme. How does this sound?
- pymbar-core maps to PyPI's pymbar (no JAX)
- pymbar maps to PyPI's pymbar[jax] (yes JAX)
It does create the small conflation point where pymbar is technically a slightly different dependency set on Conda instead of PyPI. But I feel that's okay. Thoughts?
The warning looks like this: https://github.com/choderalab/pymbar/actions/runs/5245393269/jobs/9472772645?pr=503#step:6:199
I am going to add a new line at the start so it aligns a bit better.
For conda-forge, can we do pymbar=*=*core to get the no jax build? pymbar=*=*jax will be the default and gets you the JAX build?
An example would be Gromacs, https://anaconda.org/conda-forge/gromacs/files, where one can do gromacs=*=*cuda* to get the cuda build and gromacs=*=*openmpi* to get the openmpi build.
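To illustrate how the third field of such a spec acts as a glob over build strings, here is a small sketch using Python's fnmatch. It is not how conda itself is implemented, and the build strings are hypothetical; a real feedstock would define its own.

```python
from fnmatch import fnmatchcase

# Hypothetical build strings for a single pymbar version.
BUILDS = ["py_0_core", "py_0_jax"]


def select_builds(build_glob, builds=BUILDS):
    """Return the build strings matched by the glob in name=version=build."""
    return [b for b in builds if fnmatchcase(b, build_glob)]


core_only = select_builds("*core")  # as in pymbar=*=*core
jax_only = select_builds("*jax")    # as in pymbar=*=*jax
anything = select_builds("*")       # no build constraint
```

With no build constraint both variants match, so the solver has to pick one by some other rule, which is the "default" question raised in this thread.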
The warning looks fine. I'll defer to others on the best way to install.
For conda-forge, can we do...
Yes we could, and we would have to specify a build string more exactly like they do in the gromacs feedstock. I would argue that using the package specifier {package name}={version}={build selector}, like you have for pymbar=*=*core, is far less common for most users than just {package name}. I also don't know how to specify an actual default when you have multiple build versions unless the user really knows what they want/have, like they might for GROMACS where they'll know if they have gpu/cpu options. The advantage to this is everything is all under one package, e.g. "conda-forge/gromacs", on the cloud website and search.
Since we only have standard MBAR and accelerated MBAR, I think adding the burden of users having to know {package name}={version}={build selector} and then what selector they want might be more than we're looking for, especially if we want users to just default to conda install pymbar and get the accelerated version.
The way I am proposing is to have multiple outputs like the feedstock here, where we (maintainers) don't have to think about build strings or getting conda-forge to build all the versions through specification, and then the users don't have to think at all unless they specifically don't want JAX. We can even make it so that pymbar requires pymbar-core so the users who don't know there is a lighter weight install will see that and at least have a breadcrumb. The downside to this is it does create two separate packages in the cloud and search rather than one.
If people have strong opinions on either choice, I'm happy to implement either.
I don't have a strong preference. I'm happy with pymbar-core/pymbar. I kind of like the build variant approach, but if pymbar-core/pymbar is easy for the user and developers, I don't have any objection to that.
Thanks @Lnaden, the warning also LGTM!
How does this sound?
- Conda pymbar-core maps to PyPI's pymbar (no JAX)
- Conda pymbar maps to PyPI's pymbar[jax] (yes JAX)
+1 to this approach as it seems to be a reasonably common pattern on conda when omitting chunky dependencies (e.g. matplotlib-base)
I think the warning looks good as well.
@xiki-tempula I thought more about it and there's nothing saying we won't wind up with both in the long run. If we get more/different accelerators as options in the future, having a pymbar-core which has no accelerators and then pymbar=*=*{accelerator}-like syntax might be something that comes up.
I think this is a really good idea.
Though probably my instinct would be to have a no-accelerator conda version, and a single accelerator version that uses relatively common and not-weird accelerators. We could potentially have alternate accelerators just on pip.
I've merged the PR. I'm going to keep this issue open until we have the Conda-Forge release cut. I'm also going to resolve some other quick items before a new release as separate PRs.
@Lnaden is there anything I can do to help unblock this?
I was hoping to get #509 in before this, but that's blocked as I want other people's reviews for it. I think I'll just cut a 4.0.2 for optional JAX and move forward with the other parts later. Let me work on that; I should be able to get something up today, at least to PyPI.
Actually, there is one fragment I want to get in to fix a SciPy >=1.9 bug from #509, then I'll release. Should be pretty quick.
Okay. Pymbar 4.0.2 is on PyPI.
Without JAX: pip install pymbar
With JAX: pip install pymbar[jax]
Working through Conda-Forge version now.
@SimonBoothroyd The Conda-Forge version has been merged and is propagating out. I'll keep an eye on it and keep this issue open until I'm sure it's fully out there.
That said, I think this issue is resolved for all intents and purposes. If you're using the PyPI version, it's out there now and ready to go if you want a lighter version of PyMBAR.
And with that Conda-forge has:
Without JAX: conda install pymbar-core
With JAX: conda install pymbar
I think that fully resolves this issue! I'm going to close it but can always reopen if there is a need.
Thank you everyone for your contributions to this!
Amazing - thanks so much @Lnaden !
@SimonBoothroyd on the conda-forge feedstock has requested a pymbar where JAX is optional for a leaner install where the full MBAR solver is not needed (for their use case). This can be done with no modifications to the code, only the build, since all the JAX calls are optional and checked at runtime, I think.
This would require a Conda and Pip solution at the same time.
The Conda-forge/build solution is fairly easy with the outputs directive allowing multiple builds. I've worked on the QCFractal package which does that and have the example here; it matches this use case pretty precisely.
The Pip solution could be solved by making JAX a feature instead of default installed through an Optional Dependency. The naive solution would have jax be an extras_require keyword, resulting in two commands like:
pip install pymbar
pip install pymbar[jax]
both being valid, the latter one getting JAX.
I personally am against this approach because I think the idea is we want JAX installed by default, but would like other opinions as well. I don't know of an extras_require (setup.py variant) or a project.optional-dependencies (pyproject.toml variant) which blocks/omits primary dependencies. There are other options like having multiple build projects, but I don't want to explore those options if the above is acceptable to everyone else due to increased build/maintenance complexity.