esheldon opened 1 month ago
This appears to be caused by this line: https://github.com/LLNL/MuyGPyS/blob/0dad6a882048bcf885c59a2a23ce09181b7e67f4/src/MuyGPyS/_src/config.py#L8
I do have jax installed, and `JaxConfig` does not have a `define_bool_state` attribute.
Is this a version compatibility issue?
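One way to make this kind of failure easier to diagnose is to check for the method before calling it. The sketch below is illustrative only, not MuyGPyS code: `supports_define_bool_state` is a hypothetical helper, and the two `SimpleNamespace` objects are stand-ins for an older config object that provides `define_bool_state` and a newer one that does not.

```python
from types import SimpleNamespace


def supports_define_bool_state(config_obj) -> bool:
    """Return True if config_obj exposes a callable `define_bool_state`.

    Hypothetical helper for illustration; works on either a config class
    or a config instance, since both expose methods via getattr.
    """
    return callable(getattr(config_obj, "define_bool_state", None))


# Stand-ins (assumptions) for an older and a newer config object.
old_style = SimpleNamespace(define_bool_state=lambda *args, **kwargs: None)
new_style = SimpleNamespace()

print(supports_define_bool_state(old_style))  # True
print(supports_define_bool_state(new_style))  # False
```

With JAX installed, the same check could be pointed at the config object referenced in `config.py` to see whether the installed version still provides the method, and to raise a clearer error if it does not.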
Downgrading to jax 0.4.24 fixed this.
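Until a fix is released, a project could warn about a potentially incompatible JAX at startup. The sketch below reads the installed version with `importlib.metadata` without importing JAX. The 0.4.24 ceiling comes only from the report above (it is not an official compatibility bound), and `parse_release`, `jax_version_is_compatible`, and `check_installed_jax` are hypothetical helpers; a real project would more likely pin `jax==0.4.24` in its requirements or compare versions with `packaging.version`.

```python
import re
import warnings
from importlib.metadata import PackageNotFoundError, version

# Last JAX release reported in this thread to work (assumption, not official).
LAST_KNOWN_GOOD = (0, 4, 24)


def parse_release(ver: str) -> tuple:
    """Parse the leading numeric release, e.g. '0.4.25.dev2024' -> (0, 4, 25).

    Simplified: any pre-release or dev suffix after the numbers is ignored.
    """
    match = re.match(r"\d+(?:\.\d+)*", ver)
    if match is None:
        raise ValueError(f"unrecognized version string: {ver!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def jax_version_is_compatible(ver: str) -> bool:
    """Treat anything newer than LAST_KNOWN_GOOD as potentially incompatible."""
    return parse_release(ver) <= LAST_KNOWN_GOOD


def check_installed_jax() -> None:
    """Warn if the installed JAX is newer than the last known-good release."""
    try:
        installed = version("jax")
    except PackageNotFoundError:
        return  # JAX is not installed, so there is nothing to check.
    if not jax_version_is_compatible(installed):
        warnings.warn(
            f"jax {installed} is newer than the last version reported to work "
            f"({'.'.join(map(str, LAST_KNOWN_GOOD))}); consider pinning jax."
        )
```

Calling `check_installed_jax()` early (for example, before the config module is imported) would surface the mismatch as a readable warning instead of an `AttributeError`.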
Thanks @esheldon for investigating. There is a known incompatibility with recent versions of JAX in Python >= 3.9 arising from JAX's config objects. We can fix this in a future release, but in the meantime, thank you for identifying a compatible version of JAX.