choderalab / pymbar

Python implementation of the multistate Bennett acceptance ratio (MBAR)
http://pymbar.readthedocs.io
MIT License

Silently changing JAX default config #496

Open invemichele opened 1 year ago

invemichele commented 1 year ago

I ran into an error while trying to load a JAX neural-network model, and it took me a while to realize the problem was caused by importing pymbar. Here the global JAX default is switched to 64-bit (x64) mode, which was incompatible with my stored model.

I solved it by setting force_no_jax = True here, but it would probably be nice to have a warning about this global config change, or to mention it in the docs.

mikemhenry commented 1 year ago

Thanks for opening this bug report @invemichele ❤️ @mrshirts or @Lnaden, you are both more familiar with JAX than I am: do we need to enable 64-bit support?

(At first I thought this was about CPU architecture, but it is about 64-bit floats: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision)

Testing what happens in our CI if I remove 64-bit floats: https://github.com/choderalab/pymbar/pull/497

If we do need them, we should document this behavior and also print a warning. This is important because I am not sure what happens if someone has already started JAX and we then try to change the config, since the docs say:

To use double-precision numbers, you need to set the jax_enable_x64 configuration variable at startup.

mikemhenry commented 1 year ago

Looks like we do need 64-bit floats (which is what I thought).

@invemichele I can add a warning and improve the documentation; is that sufficient? Unfortunately, it doesn't look like we can dynamically change JAX global configs after startup.

invemichele commented 1 year ago

Yes, that would be useful info for people who want to use JAX and pymbar in the same script.

Lnaden commented 1 year ago

I've been tinkering with this some more, and I don't think there is a reasonable way to expect PyMBAR to operate in 32-bit mode. Implementing a 32-bit mode setting is a bit tricky, but I have an implementation. The problem is that there is no guarantee of useful outputs: you can get results, but they are not reliably accurate or converged.

@invemichele (and others) Is a loud warning sufficient for your use case here, or would it be extremely useful to be able to force 32-bit mode for PyMBAR, on the understanding that results are not guaranteed to converge? (If so, I would also issue a very loud warning whenever PyMBAR runs with 32-bit floats.)
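To make the accuracy concern concrete, here is a generic NumPy illustration (not pymbar code). float32 carries only about 7 significant decimal digits, so a unit increment on a value of order 1e8 is rounded away entirely, and a tight relative tolerance such as 1e-12 (an illustrative value, not necessarily pymbar's default) lies below float32's machine epsilon, so relative changes of that size cannot even be represented in float32.

```python
import numpy as np


def can_resolve(rel_tol, dtype):
    """True if `dtype` can represent relative changes as small as `rel_tol`."""
    return bool(rel_tol >= np.finfo(dtype).eps)


big = 1.0e8  # an arbitrary large magnitude, chosen only for illustration
for dtype in (np.float32, np.float64):
    x = dtype(big)
    lost = bool(x + dtype(1.0) == x)  # is the +1.0 rounded away?
    print(dtype.__name__, "increment lost:", lost,
          "| resolves 1e-12:", can_resolve(1e-12, dtype))
```

The float64 machine epsilon (about 2.2e-16) leaves ample headroom for such tolerances, which is consistent with the CI result that 64-bit floats are needed.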

invemichele commented 1 year ago

To be clear, I am not interested in pymbar using float32. The issue is that, by changing the global JAX default, pymbar is in practice incompatible with other JAX code. My notebook with a JAX neural network stopped working as soon as I imported pymbar, and since the error came from one of my own JAX lines, it was not clear to me that pymbar was the cause until I went through its source code.

A warning about the changed JAX global setting would be very useful for debugging. One solution would be to make it possible to set force_no_jax=True at import time, instead of having to modify pymbar itself. Another possible solution would be to automatically fall back to the non-JAX implementation if JAX has already been imported outside of pymbar. This would be a safety-first approach since, according to the JAX documentation, jax_enable_x64 should be set at startup, and changing it later could create problems.
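Both suggestions could be combined into a single backend check. The sketch below is illustrative only: PYMBAR_DISABLE_JAX is a hypothetical environment variable (the existing force_no_jax flag mentioned above requires editing the source), and should_use_jax is a made-up helper. It uses importlib.util.find_spec so that merely checking for JAX does not import it.

```python
import importlib.util
import os
import sys

# Hypothetical opt-out switch; NOT an existing pymbar setting.
DISABLE_ENV_VAR = "PYMBAR_DISABLE_JAX"


def should_use_jax(environ=None, modules=None):
    """Return True if a library like pymbar should enable its JAX backend.

    `environ` and `modules` default to os.environ and sys.modules; they are
    parameters only so the logic can be exercised without touching globals.
    """
    environ = os.environ if environ is None else environ
    modules = sys.modules if modules is None else modules

    # 1. Explicit opt-out chosen by the user before importing the library.
    if environ.get(DISABLE_ENV_VAR, "").strip().lower() in {"1", "true", "yes"}:
        return False

    # 2. Safety first: JAX was already imported elsewhere, so its global
    #    config may already be in use; fall back to the non-JAX path.
    if "jax" in modules:
        return False

    # 3. Otherwise use JAX only if it is installed (checked without importing it).
    return importlib.util.find_spec("jax") is not None
```

Under this sketch, a user would opt out with PYMBAR_DISABLE_JAX=1 python script.py, or by setting os.environ before the import, with no source edits.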

Lnaden commented 1 year ago

Another possible solution could be to automatically fall back to the non-jax implementation if JAX had already been imported outside of pymbar

That would be a viable option. I think we can expand it into the safest and most user-controllable approach. From what I can tell, as long as the JIT'd functions haven't been called yet, we can still enable 64-bit mode. So how does this sound:

Lnaden commented 1 year ago

@invemichele I've added the warning for this in #504. Functionally, the JAX config is not set until right before the first JIT call, and PyMBAR will issue this pair of warnings:

On import (if 32-bit JAX):

****** PyMBAR will use 64-bit JAX! *******
* JAX is currently set to 32-bit bitsize *
* which is its default.                  *
*                                        *
* PyMBAR requires 64-bit mode and WILL   *
* enable JAX's 64-bit mode when called.  *
*                                        *
* This MAY cause problems with other     *
* Uses of JAX in the same code.          *
******************************************

On change to 64-bit mode:

******* JAX 64-bit mode is now on! *******
*     JAX is now set to 64-bit mode!     *
*   This MAY cause problems with other   *
*      uses of JAX in the same code.     *
******************************************

I realize I still haven't added the API call, but I wanted this PR to cover only the warnings first, so that I don't break the API in testing on top of changing the import logic.
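One way to make notices like these easy to act on is to route them through Python's standard warnings module with a dedicated category, so a script that mixes JAX and pymbar can silence them or turn them into hard errors. This is a minimal sketch, not the code from #504 (which may log them differently); PyMBARJAXWarning, warn_on_import and switch_to_x64 are hypothetical names.

```python
import warnings


class PyMBARJAXWarning(UserWarning):
    """Hypothetical warning category for pymbar's JAX precision notices."""


_x64_switched = False  # has this sketch already flipped JAX to 64-bit mode?


def warn_on_import(x64_enabled):
    """First notice: JAX is still 32-bit but will be switched when called."""
    if not x64_enabled:
        warnings.warn(
            "PyMBAR requires 64-bit mode and WILL enable JAX's 64-bit mode when "
            "called. This MAY cause problems with other uses of JAX in the same code.",
            PyMBARJAXWarning,
            stacklevel=2,
        )


def switch_to_x64(update_config):
    """Second notice: run `update_config()` once, then announce the switch."""
    global _x64_switched
    if _x64_switched:
        return
    update_config()  # e.g. lambda: jax.config.update("jax_enable_x64", True)
    _x64_switched = True
    warnings.warn(
        "JAX is now set to 64-bit mode! This MAY cause problems with other "
        "uses of JAX in the same code.",
        PyMBARJAXWarning,
        stacklevel=2,
    )
```

With a dedicated category, a user who would rather fail fast than debug a dtype mismatch later could call warnings.simplefilter("error", PyMBARJAXWarning).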

Lnaden commented 1 year ago

A magic word in the PR closed this issue automatically; my mistake. It's not ready to close until the API to toggle JAX is in.

Lnaden commented 1 year ago

While developing the API side of this, I realized this warning doesn't do any real good: because of how decorators work, the JIT decorators all run on import, before any of the actual functions are called. I can disable JIT-compiling the functions with a global parameter, but I don't know how to check each function when it's called, set the x64 flag, and then JIT it. I need to delay the decorator's actions until execution. Rethinking the code now.

Even though the currently merged version doesn't prevent 64-bit mode from being set on import, at least for now it will warn you very loudly.
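The "delay the decorator until execution" idea can be sketched with a plain Python decorator that does nothing at import time and performs both steps (the config change and the JIT wrapping) on the first call, caching the result. This is a sketch, not the code from #505; delayed_jit is a hypothetical name, and the setup and jit callables are injected so the example runs without JAX installed.

```python
import functools


def delayed_jit(setup, jit):
    """Build a decorator that defers `setup()` and `jit(func)` to the first call.

    Nothing happens at decoration (import) time. On the first call, `setup()`
    runs (e.g. enabling JAX's x64 mode), then `jit(func)` is built and cached.
    """
    def decorator(func):
        compiled = None  # filled in lazily on the first call

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            nonlocal compiled
            if compiled is None:
                setup()
                compiled = jit(func)
            return compiled(*args, **kwargs)

        return wrapper

    return decorator


# With JAX installed, one might write (not executed here):
#   import jax
#   jit_x64 = delayed_jit(lambda: jax.config.update("jax_enable_x64", True), jax.jit)
#
#   @jit_x64
#   def double(x):
#       return x * 2.0
```

Here setup runs once per decorated function; since setting jax_enable_x64 to the same value again should be harmless, that is acceptable, though a shared module-level flag would avoid the repeated calls.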

Lnaden commented 1 year ago

Got a fix in #505. Once it's in, I can carry this over to the API.