choderalab / pymbar

Python implementation of the multistate Bennett acceptance ratio (MBAR)
http://pymbar.readthedocs.io
MIT License

pip seems to not be so happy with JAX #501

Open xiki-tempula opened 1 year ago

xiki-tempula commented 1 year ago

I tried to install pymbar via pip in a conda env, and there seems to be a problem with JAX. The steps to reproduce it are below. Installing with conda works fine, though.

Create the env

conda create -n test ipython
conda activate test

Install pymbar

pip install pymbar

Test it

ipython
>>> from pymbar import mbar

This gives:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[1], line 1
----> 1 from pymbar import mbar

File ~/mambaforge/envs/test/lib/python3.11/site-packages/pymbar/__init__.py:32
     29 __email__ = "levi.naden@choderalab.org,jaime.rodriguez-guerra@choderalab.org,michael.shirts@colorado.edu,john.chodera@choderalab.org"
     31 from . import timeseries, testsystems, confidenceintervals
---> 32 from .mbar import MBAR
     33 from .other_estimators import bar, bar_overlap, bar_zero, exp, exp_gauss
     34 from .fes import FES

File ~/mambaforge/envs/test/lib/python3.11/site-packages/pymbar/mbar.py:44
     42 import numpy as np
     43 import numpy.linalg as linalg
---> 44 from pymbar import mbar_solvers
     45 from pymbar.utils import (
     46     kln_to_kn,
     47     kn_to_n,
   (...)
     51     check_w_normalized,
     52 )
     54 logger = logging.getLogger(__name__)

File ~/mambaforge/envs/test/lib/python3.11/site-packages/pymbar/mbar_solvers.py:16
     14 if force_no_jax:
     15     raise ImportError("Jax disabled by force_no_jax in mbar_solvers.py")
---> 16 from jax.config import config
     18 config.update("jax_enable_x64", True)
     20 from jax.numpy import exp, sum, newaxis, diag, dot, s_

File ~/mambaforge/envs/test/lib/python3.11/site-packages/jax/__init__.py:35
     30 del _cloud_tpu_init
     32 # Confusingly there are two things named "config": the module and the class.
     33 # We want the exported object to be the class, so we first import the module
     34 # to make sure a later import doesn't overwrite the class.
---> 35 from jax import config as _config_module
     36 del _config_module
     38 # Force early import, allowing use of `jax.core` after importing `jax`.

File ~/mambaforge/envs/test/lib/python3.11/site-packages/jax/config.py:17
      1 # Copyright 2018 The JAX Authors.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     14 
     15 # TODO(phawkins): fix users of this alias and delete this file.
---> 17 from jax._src.config import config  # noqa: F401

File ~/mambaforge/envs/test/lib/python3.11/site-packages/jax/_src/config.py:24
     21 import threading
     22 from typing import Any, List, Callable, Hashable, NamedTuple, Iterator, Optional
---> 24 from jax._src import lib
     25 from jax._src.lib import jax_jit
     26 from jax._src.lib import transfer_guard_lib

File ~/mambaforge/envs/test/lib/python3.11/site-packages/jax/_src/lib/__init__.py:74
     70   return _jaxlib_version
     72 version_str = jaxlib.version.__version__
     73 version = check_jaxlib_version(
---> 74   jax_version=jax.version.__version__,
     75   jaxlib_version=jaxlib.version.__version__,
     76   minimum_jaxlib_version=jax.version._minimum_jaxlib_version)
     80 # Before importing any C compiled modules from jaxlib, first import the CPU
     81 # feature guard module to verify that jaxlib was compiled in a way that only
     82 # uses instructions that are present on this machine.
     83 import jaxlib.cpu_feature_guard as cpu_feature_guard

AttributeError: partially initialized module 'jax' has no attribute 'version' (most likely due to a circular import)
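Because `import jax` is itself what fails here, it is hard to get version details from inside Python. One option is to read the installed versions from package metadata, which never imports the package, and to run the failing import in a fresh interpreter. The helper below is a hypothetical sketch and not part of pymbar. Checking `jax` against `jaxlib` is a guess about what is worth ruling out; the thread does not confirm that a version mismatch is the cause.

```python
# Hypothetical diagnostic helper (not part of pymbar): report installed
# versions without importing the packages, since importing jax is what fails.
import subprocess
import sys
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("pymbar", "jax", "jaxlib", "numpy")):
    """Return {distribution name: version string, or None if not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


def imports_cleanly(statement):
    """Run an import statement in a fresh interpreter; True if it exits with code 0."""
    result = subprocess.run(
        [sys.executable, "-c", statement],
        capture_output=True,
        text=True,
    )
    return result.returncode == 0


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}: {ver or 'not installed'}")
    print("from pymbar import mbar OK:", imports_cleanly("from pymbar import mbar"))
```

Running it inside the broken env and the working conda env, then pasting both outputs, makes the two setups easy to compare.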
mikemhenry commented 1 year ago

This should help: https://github.com/choderalab/pymbar/pull/503

I am not sure when we will make another release, but JAX will be an optional dependency soon :tm:
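Making JAX optional typically means attempting the import and falling back to NumPy if it fails. The sketch below illustrates that pattern under its own assumptions. It is not the code from the linked pull request, and `logsumexp` is only a stand-in for the kind of array routine an MBAR solver uses. One detail the traceback above motivates: the failure surfaced as an `AttributeError` raised while importing `jax`, not as an `ImportError`, so a guard that catches only `ImportError` would still let it through.

```python
# Illustrative optional-dependency pattern (not pymbar's actual implementation).
# Try JAX first; on *any* import-time failure, fall back to NumPy.
import numpy as np

try:
    import jax
    import jax.numpy as xnp

    # Enable 64-bit floats through the config object on the top-level jax module.
    jax.config.update("jax_enable_x64", True)
    BACKEND = "jax"
except Exception:  # ImportError, or errors such as the AttributeError above
    xnp = np
    BACKEND = "numpy"


def logsumexp(a):
    """Numerically stable log(sum(exp(a))) using whichever backend loaded."""
    a = xnp.asarray(a)
    a_max = xnp.max(a)
    return a_max + xnp.log(xnp.sum(xnp.exp(a - a_max)))


if __name__ == "__main__":
    print(BACKEND, float(logsumexp([0.0, 0.0])))  # second value is log(2) ≈ 0.693
```

Catching `Exception` broadly is a deliberate trade-off. It keeps a broken JAX install from breaking the whole package, at the cost of silently using the slower NumPy path, so logging which backend loaded is worth considering.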