LLNL / MuyGPyS

A fast, pure python implementation of the MuyGPs Gaussian process realization and training algorithm.

numpy version has error checking jax #244

Open. esheldon opened this issue 1 month ago.

esheldon commented 1 month ago
Installing with
pip install --upgrade muygpys[hnswlib]
and then importing the package fails:
In [1]: import MuyGPyS
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[1], line 1
----> 1 import MuyGPyS

File ~/miniforge3/lib/python3.10/site-packages/MuyGPyS/__init__.py:12
      8 import importlib.metadata
     10 __version__ = importlib.metadata.version(__package__)
---> 12 from MuyGPyS._src.config import (
     13     config as config,
     14     jax_config as jax_config,
     15     MPI as MPI,
     16 )

File ~/miniforge3/lib/python3.10/site-packages/MuyGPyS/_src/config.py:82
     77     config.state.jax_enabled = val
     80 # JAX and GPU states
---> 82 enable_jax = config.define_bool_state(
     83     name="muygpys_jax_enabled",
     84     default=False,
     85     help="Enable use of jax implementations of math functions.",
     86     update_global_hook=_update_jax_global,
     87     update_thread_local_hook=_update_jax_thread_local,
     88 )
     91 def _update_gpu_global(val):
     92     config.state.gpu_enabled = val

AttributeError: 'MuyGPySConfig' object has no attribute 'define_bool_state'
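You can check whether an environment is affected without importing MuyGPyS, which fails at import time, by inspecting JAX directly. This is a minimal sketch. It assumes only that JAX, if installed, exposes jax.config and jax.__version__. The function name diagnose_jax_config is hypothetical and is not part of MuyGPyS or JAX.

```python
import importlib.util


def diagnose_jax_config():
    """Report whether JAX is installed, its version, and whether
    jax.config has define_bool_state (the attribute the traceback is missing).

    MuyGPyS is never imported, so this works even when `import MuyGPyS` fails.
    """
    # find_spec returns None when no importable "jax" package exists.
    if importlib.util.find_spec("jax") is None:
        return {"jax_installed": False, "version": None, "has_define_bool_state": None}

    import jax  # imported lazily so the check also runs without JAX

    return {
        "jax_installed": True,
        "version": jax.__version__,
        "has_define_bool_state": hasattr(jax.config, "define_bool_state"),
    }


if __name__ == "__main__":
    print(diagnose_jax_config())
```

If has_define_bool_state comes back False, the installed JAX matches the situation described below.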
esheldon commented 1 month ago

This appears to be caused by this line: https://github.com/LLNL/MuyGPyS/blob/0dad6a882048bcf885c59a2a23ce09181b7e67f4/src/MuyGPyS/_src/config.py#L8

I do have jax installed, and the JaxConfig object does not have a define_bool_state attribute.

Is this a version compatibility issue?
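The traceback shows config.py calling define_bool_state unconditionally. One possible workaround is feature detection: call define_bool_state only when the config object provides it, and otherwise record the flag in a local registry. This is a sketch under stated assumptions, not MuyGPyS's actual fix. The names define_bool_flag and LocalFlagRegistry are hypothetical. The fallback path also stores the flag without running any update hooks that the real call would pass through.

```python
class LocalFlagRegistry:
    """Hypothetical fallback: stores boolean flags in plain dicts."""

    def __init__(self):
        self.values = {}
        self.help = {}

    def define(self, name, default, help_text):
        self.values[name] = default
        self.help[name] = help_text
        return name


def define_bool_flag(config_obj, registry, name, default, help_text, **hooks):
    """Define a boolean flag via config_obj.define_bool_state if it exists.

    Falls back to `registry` when the attribute is missing, avoiding the
    AttributeError from the traceback. Hooks are forwarded only to
    define_bool_state; the fallback ignores them.
    """
    define = getattr(config_obj, "define_bool_state", None)
    if callable(define):
        return define(name=name, default=default, help=help_text, **hooks)
    return registry.define(name, default, help_text)
```

For example, with a stub config class that lacks define_bool_state, define_bool_flag(stub, registry, "muygpys_jax_enabled", False, "...") records the flag in registry.values instead of raising.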

esheldon commented 1 month ago

Downgrading to jax 0.4.24 fixed this.
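Because this report says jax 0.4.24 works, a startup check could compare the installed JAX version against that release. This is a minimal sketch. It treats 0.4.24 as the newest known-good version based only on this thread, which is an assumption and not an official compatibility bound. The helper names are hypothetical.

```python
import re

# Assumption from this thread: 0.4.24 is the newest JAX release reported to
# work; later releases may lack jax.config.define_bool_state.
KNOWN_GOOD_JAX = (0, 4, 24)


def parse_release(version):
    """Extract leading (major, minor, patch) integers from a version string.

    Suffixes such as ".dev2024..." or "rc1" are ignored; missing parts are 0.
    """
    match = re.match(r"(\d+)(?:\.(\d+))?(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) if part is not None else 0 for part in match.groups())


def jax_is_known_good(version):
    """Return True if `version` is not newer than KNOWN_GOOD_JAX."""
    return parse_release(version) <= KNOWN_GOOD_JAX
```

To match the downgrade described here, pin JAX with pip install "jax==0.4.24". After that, jax_is_known_good(jax.__version__) returns True.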

bwpriest commented 1 month ago

Thanks, @esheldon, for investigating. There is a known incompatibility with recent versions of JAX in Python >= 3.9 arising from JAX's config object. We can fix this in a future release, but in the meantime, thank you for identifying a compatible version of JAX.