lindermanlab / ssm

Bayesian learning and inference for state space models
MIT License

Fix incompatibility with scipy>1.2.0 #109

Closed (mmyros closed 4 years ago)

mmyros commented 4 years ago

logsumexp was moved from scipy.misc to scipy.special.
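For readers hitting the same problem, here is a minimal sketch of a version-tolerant import. It is not the code from this PR; `import_first` is a hypothetical helper name, and the module order shown in the usage comment is an assumption about which location to prefer.

```python
import importlib


def import_first(name, module_paths):
    """Return attribute `name` from the first importable module that defines it.

    `import_first` is a hypothetical helper for illustration, not part of ssm.
    """
    for path in module_paths:
        try:
            module = importlib.import_module(path)
        except ImportError:
            continue  # module is missing in this environment; try the next one
        if hasattr(module, name):
            return getattr(module, name)
    raise ImportError(f"cannot import name {name!r} from any of {list(module_paths)}")


# Intended use (requires SciPy, so it is left commented out here):
# logsumexp = import_first("logsumexp", ["scipy.special", "scipy.misc"])
```

Listing `scipy.special` first means newer SciPy releases resolve immediately, while older ones fall back to `scipy.misc`.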

bantin commented 4 years ago

@mmyros Thanks for making this PR! Personally, I'm okay with this approach for now. Long term, though, the proper fix is probably to standardize on scipy > 1.2.0 and require it at install time. The slight issue is that changing everything to scipy.special.logsumexp would break existing users' installs, though that might not be a big deal.
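A rough sketch of what "requiring scipy > 1.2.0" could look like. The requirement string is the kind of specifier one would list under `install_requires` in setup.py; the runtime check below it is only an illustrative analogue. `SCIPY_REQUIREMENT`, `release_tuple`, and `is_newer_than` are hypothetical names, not part of ssm.

```python
import re

# Hypothetical specifier, e.g. for install_requires in setup.py (an assumption,
# not taken from the ssm repository).
SCIPY_REQUIREMENT = "scipy>1.2.0"


def release_tuple(version):
    """Parse the leading numeric release of a version string: '1.4.1' -> (1, 4, 1).

    Suffixes such as 'rc1' or '.dev0' are ignored, so pre-releases compare
    equal to their final release. A full parser would be more careful.
    """
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def is_newer_than(installed, floor="1.2.0"):
    """Return True if `installed` is strictly newer than `floor`."""
    return release_tuple(installed) > release_tuple(floor)
```

With the scipy version reported later in this thread, `is_newer_than("1.4.1")` is true, so that environment would satisfy the proposed requirement.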

@slinderman Thoughts on this?

slinderman commented 4 years ago

Sorry, I didn't realize this was still open. I'm hesitant to use scipy instead of autograd.scipy, given that some code paths expect autodiff to be available. I believe autograd now wraps the appropriate version of logsumexp anyway, so make sure you have the most up-to-date version of autograd.

jmarkow commented 4 years ago

FYI, this appears to be broken in the pip version of ssm.

---------------------------------------------------------------------------
ImportError                               Traceback (most recent call last)
<ipython-input-10-9b34101e2221> in <module>
      9 
     10 
---> 11 import ssm
     12 import warnings
     13 import pandas as pd

~/miniconda3/envs/decoding/lib/python3.6/site-packages/ssm/__init__.py in <module>
      1 # Default imports for SSM
      2 
----> 3 from .hmm import *
      4 from .lds import *

~/miniconda3/envs/decoding/lib/python3.6/site-packages/ssm/hmm.py in <module>
      7 
      8 from ssm.optimizers import adam_step, rmsprop_step, sgd_step, convex_combination
----> 9 from ssm.primitives import hmm_normalizer, hmm_expected_states, hmm_filter, hmm_sample, viterbi
     10 from ssm.util import ensure_args_are_lists, ensure_args_not_none, \
     11     ensure_slds_args_not_none, ensure_variational_args_are_lists, \

~/miniconda3/envs/decoding/lib/python3.6/site-packages/ssm/primitives.py in <module>
      2 import autograd.numpy as np
      3 import autograd.numpy.random as npr
----> 4 from autograd.scipy.misc import logsumexp
      5 from autograd.scipy.linalg import cholesky_banded, solve_banded, solveh_banded
      6 from autograd.extend import primitive, defvjp

ImportError: cannot import name 'logsumexp'
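To narrow down an error like the one above, it can help to check which modules in the current environment actually expose `logsumexp`. The following is a small diagnostic sketch. `find_providers` is a hypothetical helper, and the candidate module list is an assumption based on the import paths mentioned in this thread.

```python
import importlib


def find_providers(name, module_paths):
    """Return the entries of `module_paths` that import cleanly and define `name`."""
    providers = []
    for path in module_paths:
        try:
            module = importlib.import_module(path)
        except ImportError:
            continue  # module not installed or not present in this version
        if hasattr(module, name):
            providers.append(path)
    return providers


if __name__ == "__main__":
    # Candidate locations are assumptions; the result depends on the installed
    # autograd and scipy versions.
    candidates = [
        "autograd.scipy.special",
        "autograd.scipy.misc",
        "scipy.special",
        "scipy.misc",
    ]
    print(find_providers("logsumexp", candidates))
```

If `autograd.scipy.misc` is missing from the printed list, the import on line 4 of primitives.py in the traceback cannot succeed in that environment.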

Output of pip freeze

appdirs==1.4.4
attrs==19.3.0
autograd==1.3
backcall @ file:///home/conda/feedstock_root/build_artifacts/backcall_1592338393461/work
black==19.10b0
bleach @ file:///home/conda/feedstock_root/build_artifacts/bleach_1588608214987/work
bokeh==2.1.1
certifi==2020.6.20
click==7.1.2
cloudpickle==1.4.1
contextvars==2.4
cycler==0.10.0
Cython==0.29.21
dask==2.19.0
dask-jobqueue==0.7.1
decorator==4.4.2
defusedxml==0.6.0
distributed==2.19.0
entrypoints==0.3
flake8==3.8.3
future==0.18.2
h5py==2.10.0
HeapDict==1.0.1
-e git+https://github.com/dattalab/hyphyber.git@c4c345e5373f1aae3d9afac36e69368390b19872#egg=hyphyber
immutables==0.14
importlib-metadata @ file:///home/conda/feedstock_root/build_artifacts/importlib-metadata_1591451746006/work
ipykernel @ file:///home/conda/feedstock_root/build_artifacts/ipykernel_1590020193992/work/dist/ipykernel-5.3.0-py3-none-any.whl
ipython @ file:///home/conda/feedstock_root/build_artifacts/ipython_1590796900990/work
ipython-genutils==0.2.0
ipywidgets==7.5.1
jedi==0.17.0
Jinja2==2.11.2
joblib==0.15.1
jsonschema==3.2.0
jupyter-client==6.1.3
jupyter-core==4.6.3
kiwisolver==1.2.0
MarkupSafe==1.1.1
matplotlib==3.2.2
mccabe==0.6.1
memory-profiler==0.57.0
mistune==0.8.4
msgpack==1.0.0
natsort==7.0.1
nb-black==1.0.7
nbconvert==5.6.1
nbformat==5.0.6
notebook @ file:///home/conda/feedstock_root/build_artifacts/notebook_1588887228799/work
numpy==1.18.5
packaging @ file:///home/conda/feedstock_root/build_artifacts/packaging_1589925210001/work
pandas==1.0.5
pandocfilters==1.4.2
parso==0.7.0
pathspec==0.8.0
pexpect==4.8.0
pickleshare==0.7.5
Pillow==7.1.2
prometheus-client @ file:///home/conda/feedstock_root/build_artifacts/prometheus_client_1590412252446/work
prompt-toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1592500439797/work
psutil==5.7.0
ptyprocess==0.6.0
pyaml==20.4.0
pyarrow==0.17.1
pycodestyle==2.6.0
pyflakes==2.2.0
Pygments==2.6.1
pyparsing==2.4.7
pyrsistent==0.16.0
python-dateutil==2.8.1
pytz==2020.1
PyYAML==5.3.1
pyzmq==19.0.1
regex==2020.6.8
-e git+https://github.com/dattalab/reinforcement-analysis.git@73789b2088616433eb3daa3bdf7d10a0640c44ec#egg=rl_analysis
ruamel.yaml==0.16.10
ruamel.yaml.clib==0.2.0
scikit-learn==0.23.1
scikit-optimize==0.7.4
scipy==1.4.1
seaborn==0.10.1
Send2Trash==1.5.0
six @ file:///home/conda/feedstock_root/build_artifacts/six_1590081179328/work
sortedcontainers==2.2.2
ssm==0.0.1
tblib==1.6.0
terminado==0.8.3
testpath==0.4.4
threadpoolctl==2.1.0
toml==0.10.1
toolz==0.10.0
tornado==6.0.4
tqdm==4.46.1
traitlets==4.3.3
typed-ast==1.4.1
typing-extensions==3.7.4.2
wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1591600393557/work
webencodings==0.5.1
widgetsnbextension==3.5.1
zict==2.0.0
zipp==3.1.0

jmarkow commented 4 years ago

Everything works fine when installing from master, btw.