Richardjmorton opened this issue 1 year ago
FWIW, I cannot reproduce the problem; it works fine for me with the same JAX and jaxopt versions. Are you somehow calling scipy directly, differently from the tutorial?
No. As I mentioned, I also just opened the notebook on Colab and got the same result.
It works fine if I change the solver, e.g., to jaxopt.BFGS.
I can reproduce it. Running in Colab:
jax.__version__
0.3.25
Also,
print(jax.scipy)
<module 'jax.scipy' from '/usr/local/lib/python3.8/dist-packages/jax/scipy/__init__.py'>
On my machine,
print(jax.scipy)
<module 'jax.scipy' from '/home/vincent/.cache/pypoetry/virtualenvs/tinygpenv-7oompMSj-py3.10/lib/python3.10/site-packages/jax/scipy/__init__.py'>
(Yes, I set up a new env for this).
What is your scipy stack?
Hi, it turns out my Jupyter kernel was picking up the conda base environment rather than the virtual env. This meant that it was using Python 3.8 and SciPy 1.7 by default.
Not sure why... anyway, it looks like the issue is probably due to the scipy version.
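A quick way to catch this kind of kernel mix-up is to print which interpreter the kernel is actually running and which package versions it sees. This is a minimal sketch using only the standard library; the function name environment_report is hypothetical, and the package list is just the ones discussed in this thread. sys.executable would have pointed at the conda base environment instead of the virtualenv.

```python
import sys
from importlib.metadata import PackageNotFoundError, version


def environment_report(packages=("jax", "jaxlib", "jaxopt", "scipy")) -> dict:
    """Hypothetical helper: interpreter path/version plus installed package versions."""
    report = {
        "executable": sys.executable,  # which Python the kernel is really using
        "python": sys.version.split()[0],
    }
    for name in packages:
        try:
            report[name] = version(name)
        except PackageNotFoundError:
            report[name] = None  # not installed in this environment
    return report


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key}: {value}")
```

Running it in the notebook and in a terminal inside the intended virtualenv should show whether both point at the same interpreter.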
Thanks for the help!
Sorry to close and reopen; I thought it might be worth adding something to the tutorial notebooks to install the correct Python/scipy versions when running on Colab.
Changed the issue title to reflect this... feel free to finesse it if it's not descriptive enough.
I can also fix and open a pull request if desired.
Thanks! Yeah, this looks to be an incompatibility between jaxopt and older versions of scipy. It would be ideal if jaxopt set a stricter lower bound on its required version of scipy (1.0.0 is not sufficient :D), but it's probably easier for us to just update scipy in all the notebooks that use jaxopt. If you'd be willing to open a PR that does something like replacing
try:
import jaxopt
except ImportError:
%pip install -q jaxopt
with
try:
import jaxopt
except ImportError:
%pip install -q -U jaxopt scipy
in the notebooks. I'm not too fussed about documenting this here since it's really a jaxopt issue rather than a tinygp one, but if you feel strongly that it should be documented, feel free to include some comments on that too!
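If you'd rather only upgrade when the installed scipy is actually too old, a small version check could decide whether to run that %pip line. This is a sketch only: the function names are hypothetical, and the "1.9" threshold is an assumption (my understanding is that newer SciPy releases load submodules such as scipy.optimize lazily on attribute access, while older ones like 1.7 do not), not a bound that jaxopt documents.

```python
import re
from importlib.metadata import PackageNotFoundError, version


def parse_release(v: str) -> tuple:
    """Leading numeric release components, e.g. '1.10.0rc1' -> (1, 10, 0)."""
    parts = []
    for piece in v.split("."):
        m = re.match(r"\d+", piece)
        if not m:
            break
        parts.append(int(m.group()))
        if m.end() != len(piece):  # stop at a pre-release suffix like 'rc1'
            break
    return tuple(parts)


def scipy_too_old(minimum: str = "1.9") -> bool:
    """True if scipy is missing or older than `minimum` (ASSUMED threshold)."""
    try:
        installed = version("scipy")
    except PackageNotFoundError:
        return True
    # Tuple comparison: (1, 7, 3) < (1, 9) but (1, 10, 0) is not.
    return parse_release(installed) < parse_release(minimum)
```

The boolean could then decide whether the notebook runs the `%pip install -q -U jaxopt scipy` line. Comparing integer tuples avoids the string-comparison trap where "1.10" sorts before "1.9".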
Thanks!
Hi, I am trying to run the code from the "Fitting a mean function" tutorial.
However, when I get to the minimisation part:
soln = solver.run(jax.tree_util.tree_map(jnp.asarray, params))
it gives me the following error:
AttributeError: module 'scipy' has no attribute 'optimize'
Any chance you've come across this before?
I am using JAX v0.4.4 and jaxopt v0.6.
Thanks.
Edit: I have also tried launching the notebook on Colab (from the tinygp webpage), and the same error occurs.
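A possible stopgap, assuming the AttributeError comes from scipy.optimize being accessed as an attribute after a bare `import scipy` (which older SciPy versions such as 1.7 don't support until the submodule is imported explicitly), is to import the submodule yourself before calling solver.run. The helper below is hypothetical and just makes that explicit; in the notebook a plain `import scipy.optimize` should have the same effect. Upgrading scipy, as discussed above, is still the real fix.

```python
import importlib
from types import ModuleType


def ensure_submodule(package: str, submodule: str) -> ModuleType:
    """Hypothetical helper: same effect as `import package.submodule`."""
    parent = importlib.import_module(package)
    # Importing the dotted name registers the submodule and binds it as an
    # attribute of the parent package, so `parent.submodule` then resolves.
    importlib.import_module(f"{package}.{submodule}")
    return getattr(parent, submodule)


# Before running the tutorial's solver (assumes scipy is installed):
# ensure_submodule("scipy", "optimize")
# soln = solver.run(jax.tree_util.tree_map(jnp.asarray, params))
```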