patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

Error upon importing equinox #168

Closed. SamKG closed this issue 2 years ago.

SamKG commented 2 years ago

Hello,

I get an error when importing equinox:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
/data/samyakg/Documents/gflownet-llvm-instr/cartpole.ipynb Cell 1 in <cell line: 8>()
      6 import optax
      7 import rlax
----> 8 import equinox as eqx
      9 import timeit
     10 get_ipython().run_line_magic('env', 'PATH=/usr/local/cuda-11.7/bin')

File ~/miniconda3/envs/rl-jax/lib/python3.9/site-packages/equinox/__init__.py:1, in <module>
----> 1 from . import nn
      2 from .filters import (
      3     combine,
      4     filter,
   (...)
     11     split,
     12 )
     13 from .gradf import filter_grad, filter_value_and_grad, gradf, value_and_grad_f

File ~/miniconda3/envs/rl-jax/lib/python3.9/site-packages/equinox/nn/__init__.py:1, in <module>
----> 1 from .composed import MLP, Sequential
      2 from .conv import Conv, Conv1d, Conv2d, Conv3d
      3 from .dropout import Dropout

File ~/miniconda3/envs/rl-jax/lib/python3.9/site-packages/equinox/nn/composed.py:6, in <module>
      3 import jax.nn as jnn
      4 import jax.random as jrandom
----> 6 from ..module import Module, static_field
      7 from .linear import Linear
     10 class MLP(Module):

File ~/miniconda3/envs/rl-jax/lib/python3.9/site-packages/equinox/module.py:8, in <module>
      4 from dataclasses import dataclass, field, fields
      6 import jax
----> 8 from .tree import tree_equal
     11 def static_field(**kwargs):
     12     try:

File ~/miniconda3/envs/rl-jax/lib/python3.9/site-packages/equinox/tree.py:6, in <module>
      2 from typing_extensions import get_args
      4 import jax
----> 6 from .custom_types import MoreArrays, PyTree
      9 _sentinel = object()
     12 def tree_at(
     13     where: Callable[[PyTree], Union[Any, Tuple[Any, ...]]],
     14     pytree: PyTree,
     15     replace: Optional[Union[Any, Tuple[Any, ...]]] = _sentinel,
     16     replace_fn: Optional[Callable[[Any], Any]] = _sentinel,
     17 ) -> PyTree:

File ~/miniconda3/envs/rl-jax/lib/python3.9/site-packages/equinox/custom_types.py:17, in <module>
     13 MoreArrays = Union[Array, np.ndarray, jnp.ndarray]
     15 PyTree = Any
---> 17 TreeDef = jaxlib.xla_extension.PyTreeDef

AttributeError: module 'jaxlib.xla_extension' has no attribute 'PyTreeDef'

It seems that, as of the latest version of JAX (0.3.15; see https://github.com/google/jax/commit/82e81a0c0a2195545de516d35866ccd84dfac97c), jaxlib no longer exports PyTreeDef directly from jaxlib.xla_extension. Instead, it is found at jaxlib.xla_extension.pytree.PyTreeDef.
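Code that must run against several jaxlib versions, where an attribute such as PyTreeDef has moved between modules, can look it up defensively by trying each known location in turn. The following is a minimal sketch of that idea, not equinox's actual fix; the helper name resolve_first is hypothetical, and it assumes every path has the form "module.attribute[.attribute...]".

```python
import importlib
from typing import Any


def resolve_first(*dotted_paths: str) -> Any:
    """Return the object at the first dotted path that resolves.

    Hypothetical helper: for each path, import the longest importable
    module prefix, then walk the remaining parts with getattr.
    """
    for path in dotted_paths:
        parts = path.split(".")
        # Try the longest module prefix first, e.g. "a.b.c" before "a.b".
        for i in range(len(parts) - 1, 0, -1):
            module_name = ".".join(parts[:i])
            try:
                obj = importlib.import_module(module_name)
            except ImportError:
                continue  # not an importable module; try a shorter prefix
            try:
                for attr in parts[i:]:
                    obj = getattr(obj, attr)
            except AttributeError:
                break  # module exists but attribute is missing; next path
            return obj
    raise AttributeError(f"none of {dotted_paths!r} could be resolved")
```

With JAX installed, one could call resolve_first("jaxlib.xla_extension.PyTreeDef", "jaxlib.xla_extension.pytree.PyTreeDef") to cover both layouts mentioned above. An alternative that avoids depending on jaxlib's internal module layout altogether is type(jax.tree_util.tree_structure(0)), which yields the treedef class from the public jax.tree_util API.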

SamKG commented 2 years ago

Ah, never mind: it seems that pip had somehow installed an extremely old version of equinox instead of the latest one. Fixed by running pip install --upgrade equinox to upgrade the package to the latest release.
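Since the root cause turned out to be a stale equinox install, it can help to confirm which versions the notebook kernel actually sees before and after upgrading. A minimal sketch using the standard-library importlib.metadata (available from Python 3.8; the traceback shows Python 3.9); the helper name installed_version is hypothetical.

```python
from importlib import metadata
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    # Run this in the same kernel/interpreter that raised the error, so it
    # reports what that environment actually has installed.
    for name in ("equinox", "jax", "jaxlib"):
        print(f"{name}: {installed_version(name) or 'not installed'}")
```

Comparing the printed equinox version against the latest release on PyPI shows whether an outdated copy is being imported.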