juliuskunze / jaxnet

Concise deep learning for JAX
Apache License 2.0
184 stars, 14 forks

ImportError: cannot import name 'unzip2' from 'jax' #25

Closed: asmith26 closed this issue 4 years ago

asmith26 commented 4 years ago

Hi there,

When I run import jaxnet, I get the following error:

~/miniconda3/envs/jaxnet/lib/python3.7/site-packages/jaxnet/core.py in <module>
      6 import dill
      7 import jax
----> 8 from jax import lax, random, unzip2, safe_zip, safe_map, partial, raise_to_shaped, tree_flatten, \
      9     tree_unflatten, flatten_fun_nokwargs, jit, curry
     10 from jax.abstract_arrays import ShapedArray

ImportError: cannot import name 'unzip2' from 'jax'

Library versions:

Python 3.7.6 (conda)
jax                0.1.67
jaxlib             0.1.47
jaxnet             0.2.5

Thanks for this library!

juliuskunze commented 4 years ago

Hi, thanks for raising this. It's fixed now; simply run pip install --upgrade jaxnet.

asmith26 commented 4 years ago

Thanks very much for the quick fix :)
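For code that cannot upgrade jaxnet right away, one option is a local fallback for the missing name. The sketch below is a hypothetical workaround, not necessarily what the jaxnet fix does: it first tries the top-level import that fails in the traceback above and, on ImportError, defines a pure-Python unzip2. The local definition assumes unzip2 splits a sequence of pairs into a tuple of first elements and a tuple of second elements; the other names jaxnet imports (safe_zip, safe_map, etc.) would need similar handling.

```python
# Hypothetical compatibility shim: use jax's own unzip2 if the installed
# version still exposes it at the top level, otherwise fall back to a local copy.
try:
    from jax import unzip2
except ImportError:  # also covers jax not being installed (ModuleNotFoundError)
    def unzip2(pairs):
        """Split an iterable of 2-tuples into (tuple of firsts, tuple of seconds)."""
        firsts = []
        seconds = []
        for first, second in pairs:
            firsts.append(first)
            seconds.append(second)
        return tuple(firsts), tuple(seconds)


xs, ys = unzip2([(1, "a"), (2, "b"), (3, "c")])
print(xs, ys)  # (1, 2, 3) ('a', 'b', 'c')
```

Falling back only on ImportError keeps the original import path when it still works, so the shim changes nothing for older JAX versions.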