SamKG closed this issue 2 years ago.
Hello,
I get an error when importing equinox:
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
/data/samyakg/Documents/gflownet-llvm-instr/cartpole.ipynb Cell 1 in <cell line: 8>()
      6 import optax
      7 import rlax
----> 8 import equinox as eqx
      9 import timeit
     10 get_ipython().run_line_magic('env', 'PATH=/usr/local/cuda-11.7/bin')

File ~/miniconda3/envs/rl-jax/lib/python3.9/site-packages/equinox/__init__.py:1, in <module>
----> 1 from . import nn
      2 from .filters import (
      3     combine,
      4     filter,
    (...)
     11     split,
     12 )
     13 from .gradf import filter_grad, filter_value_and_grad, gradf, value_and_grad_f

File ~/miniconda3/envs/rl-jax/lib/python3.9/site-packages/equinox/nn/__init__.py:1, in <module>
----> 1 from .composed import MLP, Sequential
      2 from .conv import Conv, Conv1d, Conv2d, Conv3d
      3 from .dropout import Dropout

File ~/miniconda3/envs/rl-jax/lib/python3.9/site-packages/equinox/nn/composed.py:6, in <module>
      3 import jax.nn as jnn
      4 import jax.random as jrandom
----> 6 from ..module import Module, static_field
      7 from .linear import Linear
     10 class MLP(Module):

File ~/miniconda3/envs/rl-jax/lib/python3.9/site-packages/equinox/module.py:8, in <module>
      4 from dataclasses import dataclass, field, fields
      6 import jax
----> 8 from .tree import tree_equal
     11 def static_field(**kwargs):
     12     try:

File ~/miniconda3/envs/rl-jax/lib/python3.9/site-packages/equinox/tree.py:6, in <module>
      2 from typing_extensions import get_args
      4 import jax
----> 6 from .custom_types import MoreArrays, PyTree
      9 _sentinel = object()
     12 def tree_at(
     13     where: Callable[[PyTree], Union[Any, Tuple[Any, ...]]],
     14     pytree: PyTree,
     15     replace: Optional[Union[Any, Tuple[Any, ...]]] = _sentinel,
     16     replace_fn: Optional[Callable[[Any], Any]] = _sentinel,
     17 ) -> PyTree:

File ~/miniconda3/envs/rl-jax/lib/python3.9/site-packages/equinox/custom_types.py:17, in <module>
     13 MoreArrays = Union[Array, np.ndarray, jnp.ndarray]
     15 PyTree = Any
---> 17 TreeDef = jaxlib.xla_extension.PyTreeDef

AttributeError: module 'jaxlib.xla_extension' has no attribute 'PyTreeDef'
It seems that, as of the latest JAX release (0.3.15; see https://github.com/google/jax/commit/82e81a0c0a2195545de516d35866ccd84dfac97c), jaxlib no longer exposes PyTreeDef as jaxlib.xla_extension.PyTreeDef. Instead, it is now found at:
jaxlib.xla_extension.pytree.PyTreeDef
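For reference, code that needs the PyTreeDef class can avoid depending on where jaxlib happens to place it internally. Below is a minimal sketch, assuming only the public jax.tree_util API; the helper name pytreedef_type is hypothetical, and this is not how equinox itself addressed the problem (in this thread the cause turned out to be an outdated equinox install).

```python
def pytreedef_type() -> type:
    """Return the PyTreeDef class without referencing private jaxlib paths.

    Hypothetical helper; jax is imported lazily so that merely defining
    this function does not require jax to be installed.
    """
    import jax

    # Newer JAX releases re-export the class as jax.tree_util.PyTreeDef;
    # older ones may not, so fall back to None if the attribute is missing.
    public = getattr(jax.tree_util, "PyTreeDef", None)
    if public is not None:
        return public
    # Fallback: derive the class from a treedef built via the public API.
    return type(jax.tree_util.tree_structure(0))
```

The fallback works because tree_structure always returns an instance of the treedef class, wherever jaxlib defines it.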
Ah, never mind: it seems that pip somehow installed an extremely old version of equinox instead of the latest one. I fixed it by forcing an upgrade with:
pip install --upgrade equinox