choderalab / pymbar

Python implementation of the multistate Bennett acceptance ratio (MBAR)
http://pymbar.readthedocs.io
MIT License

jax initialization issues #488

Open mrshirts opened 1 year ago

mrshirts commented 1 year ago

Split from #487: pymbar 4.0 uses jax.numpy, and jax is not well configured for newer Macs (M1). I also tried using it on an Intel i7 CPU, but it is not getting initialized.

mrshirts commented 1 year ago

@karanv99 can you leave more information here so we can address this?

karanv99 commented 1 year ago

Yup sure,

On Mac M1

For this diagnosis I am running pymbar 4.0.1 on an M1 MacBook Air (2020, macOS Monterey).

I will attach a screenshot of the successful installation of pymbar 4.0 for your diagnosis.

image

This is the error that I get when trying to import pymbar 4.0:

image

I tried reading about it and went through your source code to figure out what is wrong. The most I could come up with, by comparing the pymbar 3.0 and pymbar 4.0 source code, was that from 4.0 onwards you started using jax.numpy instead of numpy arrays (probably for efficiency or smoothness). I also found that jax has issues on Apple silicon chips because of the different architecture and is not well configured there.

On Intel i7 PC

Chip info

image

Prompt showing successful installation of pymbar 4.0

image

This is where initializing mbar using pymbar 4.0 gets stuck on the i7 chip.

image

I waited for more than an hour, hoping it would get initialized, but it was unsuccessful. I compared mbar_solvers.py between 3.0 and 4.0: in pymbar 3.0, adaptive is the default solver, while in 4.0 it is hybr. Even though the prompt does claim to try the next method, the code gets stuck as shown in image 5.

Disclaimer

I am new to this whole thing and am trying to figure things out myself, so I might be wrong on many levels in what I have mentioned above. If I am doing something silly, I really do apologise.

Thank you!! :)

mrshirts commented 1 year ago

So, there is no installation issue on Intel, which is good. "No GPU/TPU found, falling back to CPU" is exactly what you want to happen.
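To confirm which backend jax actually selected, a quick check along these lines may help. This is only a minimal sketch and not part of pymbar: the helper name `jax_backend` is made up for illustration, and it assumes a jax version that provides `jax.default_backend()`.

```python
def jax_backend():
    """Return the name of jax's default backend, or None if jax fails to import."""
    try:
        import jax  # imported lazily so an import failure is reported, not raised
    except (ImportError, AttributeError) as exc:
        print(f"jax could not be imported: {exc!r}")
        return None
    # "cpu" is the expected answer on a machine without a supported GPU/TPU.
    return jax.default_backend()


if __name__ == "__main__":
    print("jax backend:", jax_backend())
```

If this prints `cpu`, the fallback message is harmless and jax itself is working.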

For pymbar 4.0 on Intel, try using the 'robust' solver and setting the output to 'verbose' (see some of the examples), watch what happens, and let me know what is going on. If there is poor overlap in the collected data, then convergence can take some playing around.
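As a rough illustration of the suggestion above, the sketch below builds a small synthetic data set and passes it to `MBAR` with the robust solver and verbose output. The harmonic-oscillator data and both helper names (`make_harmonic_u_kn`, `fit_mbar_robust`) are invented for this example. The keyword arguments `solver_protocol="robust"` and `verbose=True` reflect my understanding of the pymbar 4 `MBAR` constructor, so please check them against the documentation for your installed version.

```python
import numpy as np


def make_harmonic_u_kn(n_states=3, n_per_state=200, seed=0):
    """Synthetic reduced potentials for unit-width harmonic wells (illustrative data only)."""
    rng = np.random.default_rng(seed)
    centers = np.arange(n_states, dtype=float)  # well centers at 0, 1, 2, ...
    N_k = np.full(n_states, n_per_state)        # number of samples drawn from each state
    # State k has u_k(x) = 0.5 * (x - c_k)**2, so its samples follow Normal(c_k, 1).
    x_n = np.concatenate([rng.normal(c, 1.0, n_per_state) for c in centers])
    # u_kn[k, n]: reduced potential of sample n evaluated in state k.
    u_kn = 0.5 * (x_n[None, :] - centers[:, None]) ** 2
    return u_kn, N_k


def fit_mbar_robust(u_kn, N_k):
    """Fit MBAR with the robust solver protocol and verbose output (assumed pymbar 4 API)."""
    from pymbar import MBAR  # lazy import, so the data helper works without pymbar

    return MBAR(u_kn, N_k, solver_protocol="robust", verbose=True)


if __name__ == "__main__":
    u_kn, N_k = make_harmonic_u_kn()
    mbar = fit_mbar_robust(u_kn, N_k)
```

With well-overlapping data like this, the solver should finish quickly. If the same call hangs on real data, the verbose output is the place to look for which solver step is stalling.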

For the problems installing on M1, conda installation is the suggested route (on all applications). It could be that the pip install is getting confused. Can you try to install via conda instead? Install anaconda or miniconda, set the channels as described in the documentation, and conda install pymbar there.

Lnaden commented 1 year ago

I don't have an M1 to test against directly, but from what I am reading, other people hitting the "partially initialized" error had jax and jaxlib installs cached from before those packages had better M1 support. Can you try uninstalling, clearing the cached versions of jax and jaxlib, then reinstalling, and see what happens?
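Before and after reinstalling, it may help to record which versions are actually present in the environment. The following is a small sketch using only the Python standard library. The function name `environment_report` is hypothetical, and the package list is just the three packages discussed in this thread.

```python
import platform
import sys
from importlib import metadata


def environment_report(packages=("jax", "jaxlib", "pymbar")):
    """Collect interpreter version, CPU architecture, and installed package versions."""
    report = {
        "python": sys.version.split()[0],
        "machine": platform.machine(),  # e.g. "arm64" for a native Apple silicon Python
    }
    for name in packages:
        try:
            report[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = None  # not installed in this environment
    return report


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key}: {value}")
```

Pasting this output into the issue makes it easier to tell whether an old jax or jaxlib is still being picked up.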

The other main issue I am seeing is that Homebrew jax and jaxlib are a bit more unreliable on M1 than pip and conda.

karanv99 commented 1 year ago

Hello, everyone. I have an update:

I am trying to figure out how pymbar 4.0 works (its syntax, etc.) on the Intel PC. I will get back once I have figured it out.

As for the M1, as @mrshirts suggested, I think the issue was with the pip install. I installed pymbar 4.0 using conda; it installed the required jax libraries, and pymbar now imports. I am still figuring out how generate_fes() works in pymbar. Once I do, I will get back.

Thank you all for your help!!

Note!!

Could you please check out the output repository for parallel tempering? I think it has not been updated with the new parallel tempering code. It has the same output from the previous repo.