dfm / tinygp

The tiniest of Gaussian Process libraries
https://tinygp.readthedocs.io
MIT License

Add python and scipy installation checks to tutorial notebooks #155

Status: Open. Richardjmorton opened this issue 1 year ago.

Richardjmorton commented 1 year ago

Hi, I am trying to run the code from the "Fitting a mean function" tutorial.

However, when I get to the minimisation step,

soln = solver.run(jax.tree_util.tree_map(jnp.asarray, params))

it gives me the following error:

AttributeError: module 'scipy' has no attribute 'optimize'

Any chance you've come across this before?

I am using JAX v0.4.4 and jaxopt v0.6.

Thanks.

Edit: I have also tried launching the notebook (from the tinygp webpage) on Colab, and the same error occurs.

jvlatzko commented 1 year ago

FWIW, I cannot reproduce the problem; it works fine for me with the same JAX and jaxopt versions. Are you somehow calling scipy directly, differently from the tutorial?

Richardjmorton commented 1 year ago

No. As I mentioned, I also just opened the notebook on Colab and got the same result.

It works fine if I change the solver, e.g., to jaxopt.BFGS.

jvlatzko commented 1 year ago

I can reproduce it when running in Colab:

jax.__version__
0.3.25

Also,

print(jax.scipy)
<module 'jax.scipy' from '/usr/local/lib/python3.8/dist-packages/jax/scipy/__init__.py'>

On my machine,

print(jax.scipy)
<module 'jax.scipy' from '/home/vincent/.cache/pypoetry/virtualenvs/tinygpenv-7oompMSj-py3.10/lib/python3.10/site-packages/jax/scipy/__init__.py'>

(Yes, I set up a new env for this).

What is your scipy stack?
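You can answer this from inside a notebook by printing which interpreter the kernel is running and which package versions it can see. The following is a minimal sketch that uses only the standard library (Python 3.8+ for importlib.metadata). The function name describe_environment and the default package list are illustrative choices. They are not part of tinygp or jaxopt.

```python
import importlib.metadata
import platform
import sys


def describe_environment(packages=("jax", "jaxlib", "jaxopt", "scipy")):
    """Report the running interpreter and the versions of selected packages.

    Hypothetical helper for illustration. Packages that are not installed
    in the environment this interpreter uses are reported as None.
    """
    report = {
        # The executable path shows which environment the kernel is attached to.
        "python_executable": sys.executable,
        "python_version": platform.python_version(),
    }
    for name in packages:
        try:
            report[name] = importlib.metadata.version(name)
        except importlib.metadata.PackageNotFoundError:
            report[name] = None
    return report


if __name__ == "__main__":
    for key, value in describe_environment().items():
        print(f"{key}: {value}")
```

Run this at the top of a notebook. The interpreter path makes it obvious when a kernel is attached to an unexpected environment, such as a conda base environment instead of the intended virtual env.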

Richardjmorton commented 1 year ago

Hi, it turns out my Jupyter kernel was picking up the conda base environment rather than the virtual env. This meant it was using Python 3.8 and SciPy 1.7 by default.

Not sure why. Anyway, it looks like the issue is probably due to the SciPy version.

Thanks for the help!
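For context, this error message matches a general Python behaviour. A subpackage such as scipy.optimize is only guaranteed to be an attribute of scipy once something has imported it. Newer SciPy releases load submodules on first attribute access. Older ones, such as the 1.7 found here, did not, which is consistent with the SciPy version being the culprit. This thread does not confirm exactly where jaxopt performs the lookup.

Below is a minimal sketch of an explicit, defensive submodule import. The helper name require_submodule is hypothetical.

```python
import importlib
import types


def require_submodule(dotted_name: str) -> types.ModuleType:
    """Import and return a submodule such as "scipy.optimize" (hypothetical helper).

    importlib.import_module loads the submodule and its parents explicitly,
    so the result does not depend on whether the parent package loads
    submodules lazily on attribute access.
    """
    try:
        return importlib.import_module(dotted_name)
    except ImportError as exc:
        # Re-raise with a hint, since a wrong kernel environment is a common cause.
        raise ImportError(
            f"could not import {dotted_name!r}; check which environment "
            "this kernel is using and which versions are installed"
        ) from exc
```

In a notebook, require_submodule("scipy.optimize") either returns the module or fails with a message that points at the environment. Importing a submodule also binds it as an attribute of its parent package for the rest of the process. That may sidestep the error if the cause is as described, but upgrading SciPy, as suggested below, is the more robust fix.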

Richardjmorton commented 1 year ago

Sorry to close and reopen; I thought it might be worth adding something to the tutorial notebooks to install the correct Python/SciPy versions when running on Colab.

I changed the issue title to reflect this. Feel free to finesse it if it's not descriptive enough.

I can also make the fix and open a pull request if desired.

dfm commented 1 year ago

Thanks! Yeah, this looks like an incompatibility between jaxopt and older versions of SciPy. Ideally jaxopt would bound its required SciPy version (1.0.0 is not sufficient :D), but it's probably easier for us to just update SciPy in all the notebooks that use jaxopt. If you'd be willing to open a PR that does something like replacing

try:
    import jaxopt
except ImportError:
    %pip install -q jaxopt

with

try:
    import jaxopt
except ImportError:
    %pip install -q -U jaxopt scipy

in the notebooks, that would be great. I'm not too fussed about documenting this here since it's really a jaxopt issue rather than a tinygp one, but if you feel strongly that it should be documented, feel free to include some comments on that too!

Thanks!
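The pip line above upgrades SciPy unconditionally when jaxopt is missing. The installation check named in the issue title could also report an outdated SciPy explicitly. Below is a minimal sketch for a notebook cell that uses only the standard library. The names check_scipy, is_at_least, version_tuple and MIN_SCIPY are hypothetical. The thread only shows SciPy 1.7 failing, so the threshold is a placeholder, not a known jaxopt requirement.

```python
import importlib.metadata
import re

# Placeholder threshold (assumption): the thread shows SciPy 1.7 failing but
# does not establish the minimum version jaxopt actually needs.
MIN_SCIPY = "1.8"


def version_tuple(version: str) -> tuple:
    """Turn "1.10.1" into (1, 10, 1) and "1.11.0rc1" into (1, 11, 0).

    Only leading numeric release components are kept; pre-release and local
    suffixes are ignored, which is a simplification for a minimum check.
    """
    parts = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
        if match.end() != len(piece):
            break  # stop at the first component with a suffix, e.g. "0rc1"
    return tuple(parts)


def is_at_least(installed: str, minimum: str) -> bool:
    """Compare release numbers, padding with zeros so "1.8" equals "1.8.0"."""
    a, b = version_tuple(installed), version_tuple(minimum)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


def check_scipy(minimum: str = MIN_SCIPY) -> bool:
    """Return True if the installed SciPy meets `minimum`; otherwise print a hint."""
    try:
        installed = importlib.metadata.version("scipy")
    except importlib.metadata.PackageNotFoundError:
        print("SciPy is not installed in this environment.")
        return False
    if is_at_least(installed, minimum):
        return True
    print(f"SciPy {installed} found; {minimum} or newer is assumed here. "
          "Try: %pip install -q -U jaxopt scipy")
    return False
```

The hand-rolled comparison keeps the cell dependency-free. If the packaging library is available, its Version class implements the full PEP 440 ordering, including pre-releases.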