poets-ai / elegy

A High Level API for Deep Learning in JAX
https://poets-ai.github.io/elegy/
MIT License

[Bug] Both API example notebooks have stopped working on Colab #234

Open nalzok opened 2 years ago

nalzok commented 2 years ago

Describe the bug

The Colab runtime freezes during model.fit. It has been running for minutes without any progress. The progress bar always shows

Epoch 1/100
196/200 [============================>.] - ETA: 0s - accuracy: 0.7653 - crossentropy_loss: 0.8751 - l2_loss: 0.0364 - loss: 0.9114

When I tried to interrupt the cell execution, Colab prompted: "The executing code is not responding to interrupts. Would you like to try restarting the runtime? Runtime state including all local variables will be lost."

I then noticed this commented-out command in the High Level API notebook:

# For GPU install proper version of your CUDA, following will work in colab:
! pip install --upgrade jax jaxlib -f https://storage.googleapis.com/jax-releases/jax_releases.html

The runtime still froze after I uncommented it.

Curiously, the Low Level API notebook contains a different command:

# For GPU install proper version of your CUDA, following will work in COLAB:
! pip install --upgrade jax jaxlib==0.1.59+cuda101 -f https://storage.googleapis.com/jax-releases/jax_releases.html

After uncommenting it, I got the following error in model.fit

Click to expand ``` --------------------------------------------------------------------------- AttributeError Traceback (most recent call last) [](https://localhost:8080/#) in () ----> 1 from datasets.load import load_dataset 2 import numpy as np 3 4 dataset = load_dataset("mnist") 5 dataset.set_format("np") 18 frames [/usr/local/lib/python3.7/dist-packages/datasets/__init__.py](https://localhost:8080/#) in () 35 del version 36 ---> 37 from .arrow_dataset import Dataset, concatenate_datasets 38 from .arrow_reader import ReadInstruction 39 from .builder import ArrowBasedBuilder, BeamBasedBuilder, BuilderConfig, DatasetBuilder, GeneratorBasedBuilder [/usr/local/lib/python3.7/dist-packages/datasets/arrow_dataset.py](https://localhost:8080/#) in () 52 import pyarrow as pa 53 import pyarrow.compute as pc ---> 54 from huggingface_hub import HfApi, HfFolder 55 from multiprocess import Pool, RLock 56 from requests import HTTPError [/usr/local/lib/python3.7/dist-packages/huggingface_hub/__init__.py](https://localhost:8080/#) in () 68 from .hub_mixin import ModelHubMixin, PyTorchModelHubMixin 69 from .inference_api import InferenceApi ---> 70 from .keras_mixin import ( 71 KerasModelHubMixin, 72 from_pretrained_keras, [/usr/local/lib/python3.7/dist-packages/huggingface_hub/keras_mixin.py](https://localhost:8080/#) in () 25 26 if is_tf_available(): ---> 27 import tensorflow as tf 28 29 [/usr/local/lib/python3.7/dist-packages/tensorflow/__init__.py](https://localhost:8080/#) in () 49 from ._api.v2 import autograph 50 from ._api.v2 import bitwise ---> 51 from ._api.v2 import compat 52 from ._api.v2 import config 53 from ._api.v2 import data [/usr/local/lib/python3.7/dist-packages/tensorflow/_api/v2/compat/__init__.py](https://localhost:8080/#) in () 35 import sys as _sys 36 ---> 37 from . import v1 38 from . 
import v2 39 from tensorflow.python.compat.compat import forward_compatibility_horizon [/usr/local/lib/python3.7/dist-packages/tensorflow/_api/v2/compat/v1/__init__.py](https://localhost:8080/#) in () 28 from . import autograph 29 from . import bitwise ---> 30 from . import compat 31 from . import config 32 from . import data [/usr/local/lib/python3.7/dist-packages/tensorflow/_api/v2/compat/v1/compat/__init__.py](https://localhost:8080/#) in () 35 import sys as _sys 36 ---> 37 from . import v1 38 from . import v2 39 from tensorflow.python.compat.compat import forward_compatibility_horizon [/usr/local/lib/python3.7/dist-packages/tensorflow/_api/v2/compat/v1/compat/v1/__init__.py](https://localhost:8080/#) in () 45 from tensorflow._api.v2.compat.v1 import layers 46 from tensorflow._api.v2.compat.v1 import linalg ---> 47 from tensorflow._api.v2.compat.v1 import lite 48 from tensorflow._api.v2.compat.v1 import logging 49 from tensorflow._api.v2.compat.v1 import lookup [/usr/local/lib/python3.7/dist-packages/tensorflow/_api/v2/compat/v1/lite/__init__.py](https://localhost:8080/#) in () 7 8 from . import constants ----> 9 from . import experimental 10 from tensorflow.lite.python.lite import Interpreter 11 from tensorflow.lite.python.lite import OpHint [/usr/local/lib/python3.7/dist-packages/tensorflow/_api/v2/compat/v1/lite/experimental/__init__.py](https://localhost:8080/#) in () 6 import sys as _sys 7 ----> 8 from . 
import authoring 9 from tensorflow.lite.python.analyzer import ModelAnalyzer as Analyzer 10 from tensorflow.lite.python.lite import OpResolverType [/usr/local/lib/python3.7/dist-packages/tensorflow/_api/v2/compat/v1/lite/experimental/authoring/__init__.py](https://localhost:8080/#) in () 6 import sys as _sys 7 ----> 8 from tensorflow.lite.python.authoring.authoring import compatible [/usr/local/lib/python3.7/dist-packages/tensorflow/lite/python/authoring/authoring.py](https://localhost:8080/#) in () 41 42 # pylint: disable=g-import-not-at-top ---> 43 from tensorflow.lite.python import convert 44 from tensorflow.lite.python import lite 45 from tensorflow.lite.python.metrics import converter_error_data_pb2 [/usr/local/lib/python3.7/dist-packages/tensorflow/lite/python/convert.py](https://localhost:8080/#) in () 27 28 from tensorflow.lite.python import lite_constants ---> 29 from tensorflow.lite.python import util 30 from tensorflow.lite.python import wrap_toco 31 from tensorflow.lite.python.convert_phase import Component [/usr/local/lib/python3.7/dist-packages/tensorflow/lite/python/util.py](https://localhost:8080/#) in () 49 # pylint: disable=unused-import 50 try: ---> 51 from jax import xla_computation as _xla_computation 52 except ImportError: 53 _xla_computation = None [/usr/local/lib/python3.7/dist-packages/jax/__init__.py](https://localhost:8080/#) in () 33 # We want the exported object to be the class, so we first import the module 34 # to make sure a later import doesn't overwrite the class. ---> 35 from jax import config as _config_module 36 del _config_module 37 [/usr/local/lib/python3.7/dist-packages/jax/config.py](https://localhost:8080/#) in () 15 # TODO(phawkins): fix users of this alias and delete this file. 
16 ---> 17 from jax._src.config import config [/usr/local/lib/python3.7/dist-packages/jax/_src/config.py](https://localhost:8080/#) in () 25 import warnings 26 ---> 27 from jax._src import lib 28 from jax._src.lib import jax_jit 29 from jax._src.lib import transfer_guard_lib [/usr/local/lib/python3.7/dist-packages/jax/_src/lib/__init__.py](https://localhost:8080/#) in () 101 version_str = jaxlib.version.__version__ 102 version = check_jaxlib_version( --> 103 jax_version=jax.version.__version__, 104 jaxlib_version=jaxlib.version.__version__, 105 minimum_jaxlib_version=jax.version._minimum_jaxlib_version) AttributeError: module 'jax' has no attribute 'version' ```
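The frame that fails above is the check comparing the installed jax and jaxlib versions, so it may help to see what pip actually installed without importing jax at all. A minimal sketch, assuming Python 3.8+ for `importlib.metadata` (the Colab runtime in this report is 3.7, where the `importlib_metadata` backport would be needed instead); `installed_version` is a hypothetical helper name, not part of any library:

```python
from importlib import metadata


def installed_version(package):
    """Return the installed version string of `package`, or None if it is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


# Report the packages involved in the traceback without importing any of them.
for name in ("jax", "jaxlib", "tensorflow", "elegy"):
    print(name, installed_version(name))
```

Because nothing is imported, this still works when `import jax` itself raises.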

I have also tried using Elegy in the notebook I have been working on, and got another error

Click to expand ``` Epoch 1/10 --------------------------------------------------------------------------- UnfilteredStackTrace Traceback (most recent call last) [](https://localhost:8080/#) in () 7 validation_data=(test_ds['image'], test_ds['label']), ----> 8 shuffle=True 9 ) 17 frames [/usr/local/lib/python3.7/dist-packages/elegy/model/model_base.py](https://localhost:8080/#) in fit(self, inputs, labels, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, drop_remaining) 418 inputs=inputs, --> 419 labels=labels, 420 ) [/usr/local/lib/python3.7/dist-packages/elegy/model/model_core.py](https://localhost:8080/#) in train_on_batch(self, inputs, labels) 616 train_step_fn = self.train_step_fn[self._distributed_strategy] --> 617 logs, model = train_step_fn(self, inputs, labels) 618 [/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py](https://localhost:8080/#) in reraise_with_filtered_traceback(*args, **kwargs) 161 try: --> 162 return fun(*args, **kwargs) 163 except Exception as e: [/usr/local/lib/python3.7/dist-packages/jax/_src/api.py](https://localhost:8080/#) in cache_miss(*args, **kwargs) 475 device=device, backend=backend, name=flat_fun.__name__, --> 476 donated_invars=donated_invars, inline=inline, keep_unused=keep_unused) 477 out_pytree_def = out_tree() [/usr/local/lib/python3.7/dist-packages/jax/core.py](https://localhost:8080/#) in bind(self, fun, *args, **params) 1764 def bind(self, fun, *args, **params): -> 1765 return call_bind(self, fun, *args, **params) 1766 [/usr/local/lib/python3.7/dist-packages/jax/core.py](https://localhost:8080/#) in call_bind(primitive, fun, *args, **params) 1780 fun_ = lu.annotate(fun_, fun.in_type) -> 1781 outs = top_trace.process_call(primitive, fun_, tracers, params) 1782 return map(full_lower, apply_todos(env_trace_todo(), outs)) 
[/usr/local/lib/python3.7/dist-packages/jax/core.py](https://localhost:8080/#) in process_call(self, primitive, f, tracers, params) 677 def process_call(self, primitive, f, tracers, params): --> 678 return primitive.impl(f, *tracers, **params) 679 process_map = process_call [/usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py](https://localhost:8080/#) in _xla_call_impl(***failed resolving arguments***) 182 compiled_fun = _xla_callable(fun, device, backend, name, donated_invars, --> 183 keep_unused, *arg_specs) 184 try: [/usr/local/lib/python3.7/dist-packages/jax/linear_util.py](https://localhost:8080/#) in memoized_fun(fun, *args) 284 else: --> 285 ans = call(fun, *args) 286 cache[key] = (ans, fun.stores) [/usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py](https://localhost:8080/#) in _xla_callable_uncached(fun, device, backend, name, donated_invars, keep_unused, *arg_specs) 230 return lower_xla_callable(fun, device, backend, name, donated_invars, False, --> 231 keep_unused, *arg_specs).compile().unsafe_call 232 [/usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py](https://localhost:8080/#) in compile(self) 704 self._executable = XlaCompiledComputation.from_xla_computation( --> 705 self.name, self._hlo, self._explicit_args, **self.compile_args) 706 [/usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py](https://localhost:8080/#) in from_xla_computation(name, xla_computation, explicit_args, nreps, device, backend, tuple_args, in_avals, out_avals, effects, kept_var_idx, keepalive) 805 "in {elapsed_time} sec"): --> 806 compiled = compile_or_get_cached(backend, xla_computation, options) 807 buffer_counts = (None if len(out_avals) == 1 and not config.jax_dynamic_shapes [/usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py](https://localhost:8080/#) in compile_or_get_cached(backend, computation, compile_options) 767 _dump_ir_to_file(module_name, ir_str) --> 768 return backend_compile(backend, computation, compile_options) 769 
[/usr/local/lib/python3.7/dist-packages/jax/_src/profiler.py](https://localhost:8080/#) in wrapper(*args, **kwargs) 205 with TraceAnnotation(name, **decorator_kwargs): --> 206 return func(*args, **kwargs) 207 return wrapper [/usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py](https://localhost:8080/#) in backend_compile(backend, built_c, options) 712 # separately in Python profiling results --> 713 return backend.compile(built_c, compile_options=options) 714 UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: Failed to determine best cudnn convolution algorithm for: %cudnn-conv-bw-filter = (f32[5,5,6,16]{1,0,2,3}, u8[0]{0}) custom-call(f32[256,14,14,6]{2,1,3,0} %multiply.9, f32[256,10,10,16]{2,1,3,0} %multiply.6), window={size=5x5}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convBackwardFilter", metadata={op_name="jit(_static_train_step)/jit(main)/conv_general_dilated[window_strides=(1, 1) padding=((0, 0), (0, 0)) lhs_dilation=(1, 1) rhs_dilation=(1, 1) dimension_numbers=ConvDimensionNumbers(lhs_spec=(3, 0, 1, 2), rhs_spec=(3, 0, 1, 2), out_spec=(2, 3, 0, 1)) feature_group_count=1 batch_group_count=1 lhs_shape=(256, 14, 14, 6) rhs_shape=(256, 10, 10, 16) precision=None preferred_element_type=None]" source_file="/usr/local/lib/python3.7/dist-packages/flax/linen/linear.py" source_line=398}, backend_config="{\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}" Original error: UNIMPLEMENTED: DNN library is not found. To ignore this failure and try to use a fallback algorithm (which may have suboptimal performance), use XLA_FLAGS=--xla_gpu_strict_conv_algorithm_picker=false. Please also file a bug for the root cause of failing autotuning. The stack trace below excludes JAX-internal frames. The preceding is the original exception that occurred, unmodified. 
-------------------- The above exception was the direct cause of the following exception: XlaRuntimeError Traceback (most recent call last) [](https://localhost:8080/#) in () 6 batch_size=batch_size, 7 validation_data=(test_ds['image'], test_ds['label']), ----> 8 shuffle=True 9 ) [/usr/local/lib/python3.7/dist-packages/elegy/model/model_base.py](https://localhost:8080/#) in fit(self, inputs, labels, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, drop_remaining) 417 tmp_logs = self.train_on_batch( 418 inputs=inputs, --> 419 labels=labels, 420 ) 421 tmp_logs.update({"size": data_handler.batch_size}) [/usr/local/lib/python3.7/dist-packages/elegy/model/model_core.py](https://localhost:8080/#) in train_on_batch(self, inputs, labels) 615 616 train_step_fn = self.train_step_fn[self._distributed_strategy] --> 617 logs, model = train_step_fn(self, inputs, labels) 618 619 if not isinstance(model, type(self)): XlaRuntimeError: UNKNOWN: Failed to determine best cudnn convolution algorithm for: %cudnn-conv-bw-filter = (f32[5,5,6,16]{1,0,2,3}, u8[0]{0}) custom-call(f32[256,14,14,6]{2,1,3,0} %multiply.9, f32[256,10,10,16]{2,1,3,0} %multiply.6), window={size=5x5}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convBackwardFilter", metadata={op_name="jit(_static_train_step)/jit(main)/conv_general_dilated[window_strides=(1, 1) padding=((0, 0), (0, 0)) lhs_dilation=(1, 1) rhs_dilation=(1, 1) dimension_numbers=ConvDimensionNumbers(lhs_spec=(3, 0, 1, 2), rhs_spec=(3, 0, 1, 2), out_spec=(2, 3, 0, 1)) feature_group_count=1 batch_group_count=1 lhs_shape=(256, 14, 14, 6) rhs_shape=(256, 10, 10, 16) precision=None preferred_element_type=None]" source_file="/usr/local/lib/python3.7/dist-packages/flax/linen/linear.py" source_line=398}, backend_config="{\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}" Original error: UNIMPLEMENTED: 
DNN library is not found. To ignore this failure and try to use a fallback algorithm (which may have suboptimal performance), use XLA_FLAGS=--xla_gpu_strict_conv_algorithm_picker=false. Please also file a bug for the root cause of failing autotuning. ```
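The error message itself suggests setting `XLA_FLAGS=--xla_gpu_strict_conv_algorithm_picker=false`. A minimal sketch of doing that from a notebook cell, assuming the cell runs before the first `import jax` in the session (as far as I know the variable is read when XLA initializes); `add_xla_flag` is a hypothetical helper, and this is only the fallback the message describes, so it may not help if cuDNN is genuinely missing:

```python
import os


def add_xla_flag(flag, env=os.environ):
    """Append `flag` to XLA_FLAGS in `env` unless it is already present."""
    current = env.get("XLA_FLAGS", "").split()
    if flag not in current:
        current.append(flag)
    env["XLA_FLAGS"] = " ".join(current)
    return env["XLA_FLAGS"]


# Run this before importing jax; restart the runtime first if jax is already loaded.
add_xla_flag("--xla_gpu_strict_conv_algorithm_picker=false")
```

Appending rather than overwriting keeps any flags that were already set in the environment.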

Minimal code to reproduce

https://colab.research.google.com/github/poets-ai/elegy/blob/master/docs/getting-started/high-level-api.ipynb
https://colab.research.google.com/github/poets-ai/elegy/blob/master/docs/getting-started/low-level-api.ipynb
https://colab.research.google.com/drive/1ZGlTknvwMC8nrrPC_rsSBEGpgcFmVicG?usp=sharing

Expected behavior

Training completes successfully.

Library Info

>>> import elegy
>>> print(elegy.__version__)
0.8.6

Screenshots

Screen Shot 2022-05-16 at 13 36 44

Additional context

I am using a GPU runtime, i.e. Python 3 Google Compute Engine backend (GPU).

murphyk commented 2 years ago

IIUC, it should not be necessary to install jax or jaxlib on Colab, since it is built in. See e.g. this lenet_jax notebook.
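One way to confirm what the preinstalled JAX sees is to print its version and devices before running any install command. A minimal sketch; `describe_jax_backend` is a hypothetical helper, and the broad `except` is deliberate so that a broken install is reported as a string instead of raising:

```python
def describe_jax_backend():
    """Return a one-line summary of the importable jax and its devices."""
    try:
        import jax
        platforms = sorted({d.platform for d in jax.devices()})
        return f"jax {jax.__version__} on {platforms}"
    except Exception as exc:  # import or backend initialization failed
        return f"jax is not usable: {exc!r}"


print(describe_jax_backend())
```

On a GPU runtime the platform list would be expected to include a GPU entry; if it shows only `cpu`, the installed jaxlib is probably a CPU-only build.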

nalzok commented 2 years ago

That's true. I was using pip install --upgrade to upgrade them to the latest version, since the default JAX version on Colab (v0.3.8 as of now) doesn't work well with Elegy:

---------------------------------------------------------------------------

AttributeError                            Traceback (most recent call last)

[<ipython-input-2-cbb187e7d76c>](https://localhost:8080/#) in <module>()
     11 import treeo as to
     12 import treex as tx
---> 13 import elegy as eg
     14 
     15 from bokeh.resources import INLINE

6 frames

[/usr/local/lib/python3.7/dist-packages/elegy/__init__.py](https://localhost:8080/#) in <module>()
     16 )
     17 
---> 18 from .model.model import Model
     19 from .model.model_base import ModelBase, load
     20 from .model.model_core import (

[/usr/local/lib/python3.7/dist-packages/elegy/model/model.py](https://localhost:8080/#) in <module>()
      9 
     10 from elegy import types, utils
---> 11 from elegy.model.model_base import ModelBase
     12 from elegy.model.model_core import (
     13     GradStepOutput,

[/usr/local/lib/python3.7/dist-packages/elegy/model/model_base.py](https://localhost:8080/#) in <module>()
     19 from elegy.callbacks.sigint import SigIntMode
     20 from elegy.data import utils as data_utils
---> 21 from elegy.model.model_core import ModelCore, PredStepOutput, TestStepOutput
     22 
     23 __all__ = ["ModelBase", "load"]

[/usr/local/lib/python3.7/dist-packages/elegy/model/model_core.py](https://localhost:8080/#) in <module>()
     14 from elegy import types, utils
     15 
---> 16 from . import utils as model_utils
     17 
     18 try:

[/usr/local/lib/python3.7/dist-packages/elegy/model/utils.py](https://localhost:8080/#) in <module>()
      3 try:
      4     import tensorflow as tf  # type: ignore[import]
----> 5     from jax.experimental import jax2tf  # type: ignore[import]
      6 
      7     def convert_and_save_model(

[/usr/local/lib/python3.7/dist-packages/jax/experimental/jax2tf/__init__.py](https://localhost:8080/#) in <module>()
     13 # limitations under the License.
     14 
---> 15 from jax.experimental.jax2tf.jax2tf import (convert, dtype_of_val,
     16                                             split_to_logical_devices, PolyShape)
     17 from jax.experimental.jax2tf.call_tf import call_tf

[/usr/local/lib/python3.7/dist-packages/jax/experimental/jax2tf/jax2tf.py](https://localhost:8080/#) in <module>()
   2388                     extra_name_stack="checkpoint")
   2389 
-> 2390 tf_impl[lax_control_flow.optimization_barrier_p] = tfxla.optimization_barrier
   2391 
   2392 def _top_k(operand: TfVal, k: int) -> Tuple[TfVal, TfVal]:

AttributeError: module 'tensorflow.compiler.tf2xla.python.xla' has no attribute 'optimization_barrier'
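The last frames show elegy/model/utils.py importing tensorflow and jax2tf inside a `try` block, yet the AttributeError still propagates. Assuming the matching `except` clause (not visible in the traceback) catches only ImportError, a broader guard for optional dependencies could look like the sketch below. `optional_import` is a hypothetical helper and this is not Elegy's actual code:

```python
import importlib


def optional_import(module_name):
    """Import `module_name`; return None if it is missing or if it raises
    ImportError or AttributeError while loading (e.g. a version mismatch)."""
    try:
        return importlib.import_module(module_name)
    except (ImportError, AttributeError):
        return None


# Stand-in demo; the real use would be optional_import("jax.experimental.jax2tf").
print(optional_import("json") is not None)
print(optional_import("no_such_module_xyz") is None)
```

With a guard like this, a TensorFlow/JAX mismatch would disable the optional export feature rather than breaking `import elegy`.
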

cgarciae commented 2 years ago

@nalzok thanks for reporting this! These notebooks are tested on CI but sadly testing for colab is a manual process. Will try to give it a go but if you find the fix it would be amazing if you can contribute it back :)

nalzok commented 2 years ago

Yeah, I am willing to help but I cannot figure out how to install a package from GitHub. I just created a fork at https://github.com/nalzok/elegy and tried to install it on Colab with

! pip install --upgrade pip
! pip install git+https://github.com/nalzok/elegy

Then I got the error: datascience 0.10.6 requires folium==0.2.1, but you have folium 0.8.3 which is incompatible.

Full error message (click to expand) ``` Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/ Requirement already satisfied: pip in /usr/local/lib/python3.7/dist-packages (21.1.3) Collecting pip Downloading pip-22.1.1-py3-none-any.whl (2.1 MB) |████████████████████████████████| 2.1 MB 7.9 MB/s Installing collected packages: pip Attempting uninstall: pip Found existing installation: pip 21.1.3 Uninstalling pip-21.1.3: Successfully uninstalled pip-21.1.3 Successfully installed pip-22.1.1 Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/ Collecting git+https://github.com/nalzok/elegy Cloning https://github.com/nalzok/elegy to /tmp/pip-req-build-fbzdwabs Running command git clone --filter=blob:none --quiet https://github.com/nalzok/elegy /tmp/pip-req-build-fbzdwabs Resolved https://github.com/nalzok/elegy to commit 4709ce8dc9dde3925ce717e2358ce49112e36398 Installing build dependencies ... done Getting requirements to build wheel ... done Preparing metadata (pyproject.toml) ... 
done Collecting treex<0.7.0,>=0.6.5 Downloading treex-0.6.10-py3-none-any.whl (111 kB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 111.7/111.7 kB 5.8 MB/s eta 0:00:00 Collecting tensorboardx<3.0,>=2.1 Downloading tensorboardX-2.5-py2.py3-none-any.whl (125 kB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 125.3/125.3 kB 10.5 MB/s eta 0:00:00 Collecting wandb<0.13.0,>=0.12.10 Downloading wandb-0.12.16-py2.py3-none-any.whl (1.8 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.8/1.8 MB 36.9 MB/s eta 0:00:00 Collecting cloudpickle<2.0.0,>=1.5.0 Downloading cloudpickle-1.6.0-py3-none-any.whl (23 kB) Requirement already satisfied: protobuf>=3.8.0 in /usr/local/lib/python3.7/dist-packages (from tensorboardx<3.0,>=2.1->elegy==0.8.6) (3.17.3) Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from tensorboardx<3.0,>=2.1->elegy==0.8.6) (1.15.0) Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from tensorboardx<3.0,>=2.1->elegy==0.8.6) (1.21.6) Collecting rich<12.0.0,>=11.2.0 Downloading rich-11.2.0-py3-none-any.whl (217 kB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 217.3/217.3 kB 26.2 MB/s eta 0:00:00 Collecting PyYAML<7.0,>=6.0 Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 596.3/596.3 kB 47.9 MB/s eta 0:00:00 Collecting flax<0.5.0,>=0.4.0 Downloading flax-0.4.2-py3-none-any.whl (186 kB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 186.4/186.4 kB 21.9 MB/s eta 0:00:00 Collecting treeo<0.0.11,>=0.0.10 Downloading treeo-0.0.10-py3-none-any.whl (17 kB) Collecting einops<0.5.0,>=0.4.0 Downloading einops-0.4.1-py3-none-any.whl (28 kB) Collecting certifi<2022.0.0,>=2021.10.8 Downloading certifi-2021.10.8-py2.py3-none-any.whl (149 kB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 149.2/149.2 kB 20.7 MB/s eta 0:00:00 Collecting optax<0.2.0,>=0.1.1 Downloading optax-0.1.2-py3-none-any.whl (140 kB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 
140.7/140.7 kB 21.2 MB/s eta 0:00:00 Collecting pathtools Downloading pathtools-0.1.2.tar.gz (11 kB) Preparing metadata (setup.py) ... done Collecting sentry-sdk>=1.0.0 Downloading sentry_sdk-1.5.12-py2.py3-none-any.whl (145 kB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 145.3/145.3 kB 17.7 MB/s eta 0:00:00 Requirement already satisfied: setuptools in /usr/local/lib/python3.7/dist-packages (from wandb<0.13.0,>=0.12.10->elegy==0.8.6) (57.4.0) Collecting setproctitle Downloading setproctitle-1.2.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (29 kB) Requirement already satisfied: Click!=8.0.0,>=7.0 in /usr/local/lib/python3.7/dist-packages (from wandb<0.13.0,>=0.12.10->elegy==0.8.6) (7.1.2) Requirement already satisfied: requests<3,>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from wandb<0.13.0,>=0.12.10->elegy==0.8.6) (2.23.0) Collecting shortuuid>=0.5.0 Downloading shortuuid-1.0.9-py3-none-any.whl (9.4 kB) Collecting docker-pycreds>=0.4.0 Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB) Requirement already satisfied: python-dateutil>=2.6.1 in /usr/local/lib/python3.7/dist-packages (from wandb<0.13.0,>=0.12.10->elegy==0.8.6) (2.8.2) Requirement already satisfied: psutil>=5.0.0 in /usr/local/lib/python3.7/dist-packages (from wandb<0.13.0,>=0.12.10->elegy==0.8.6) (5.4.8) Requirement already satisfied: promise<3,>=2.0 in /usr/local/lib/python3.7/dist-packages (from wandb<0.13.0,>=0.12.10->elegy==0.8.6) (2.3) Collecting GitPython>=1.0.0 Downloading GitPython-3.1.27-py3-none-any.whl (181 kB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 181.2/181.2 kB 22.3 MB/s eta 0:00:00 Requirement already satisfied: matplotlib in /usr/local/lib/python3.7/dist-packages (from flax<0.5.0,>=0.4.0->treex<0.7.0,>=0.6.5->elegy==0.8.6) (3.2.2) Requirement already satisfied: msgpack in /usr/local/lib/python3.7/dist-packages (from flax<0.5.0,>=0.4.0->treex<0.7.0,>=0.6.5->elegy==0.8.6) (1.0.3) Requirement already satisfied: jax>=0.3 in 
/usr/local/lib/python3.7/dist-packages (from flax<0.5.0,>=0.4.0->treex<0.7.0,>=0.6.5->elegy==0.8.6) (0.3.8) Requirement already satisfied: typing-extensions>=4.1.1 in /usr/local/lib/python3.7/dist-packages (from flax<0.5.0,>=0.4.0->treex<0.7.0,>=0.6.5->elegy==0.8.6) (4.2.0) Collecting gitdb<5,>=4.0.1 Downloading gitdb-4.0.9-py3-none-any.whl (63 kB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 63.1/63.1 kB 7.2 MB/s eta 0:00:00 Collecting chex>=0.0.4 Downloading chex-0.1.3-py3-none-any.whl (72 kB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 72.2/72.2 kB 10.8 MB/s eta 0:00:00 Requirement already satisfied: absl-py>=0.7.1 in /usr/local/lib/python3.7/dist-packages (from optax<0.2.0,>=0.1.1->treex<0.7.0,>=0.6.5->elegy==0.8.6) (1.0.0) Requirement already satisfied: jaxlib>=0.1.37 in /usr/local/lib/python3.7/dist-packages (from optax<0.2.0,>=0.1.1->treex<0.7.0,>=0.6.5->elegy==0.8.6) (0.3.7+cuda11.cudnn805) Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.0.0->wandb<0.13.0,>=0.12.10->elegy==0.8.6) (2.10) Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.0.0->wandb<0.13.0,>=0.12.10->elegy==0.8.6) (3.0.4) Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.0.0->wandb<0.13.0,>=0.12.10->elegy==0.8.6) (1.24.3) Requirement already satisfied: pygments<3.0.0,>=2.6.0 in /usr/local/lib/python3.7/dist-packages (from rich<12.0.0,>=11.2.0->treex<0.7.0,>=0.6.5->elegy==0.8.6) (2.6.1) Collecting commonmark<0.10.0,>=0.9.0 Downloading commonmark-0.9.1-py2.py3-none-any.whl (51 kB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 51.1/51.1 kB 7.5 MB/s eta 0:00:00 Collecting colorama<0.5.0,>=0.4.0 Downloading colorama-0.4.4-py2.py3-none-any.whl (16 kB) Requirement already satisfied: dm-tree>=0.1.5 in /usr/local/lib/python3.7/dist-packages (from 
chex>=0.0.4->optax<0.2.0,>=0.1.1->treex<0.7.0,>=0.6.5->elegy==0.8.6) (0.1.7) Requirement already satisfied: toolz>=0.9.0 in /usr/local/lib/python3.7/dist-packages (from chex>=0.0.4->optax<0.2.0,>=0.1.1->treex<0.7.0,>=0.6.5->elegy==0.8.6) (0.11.2) Collecting smmap<6,>=3.0.1 Downloading smmap-5.0.0-py3-none-any.whl (24 kB) Requirement already satisfied: scipy>=1.2.1 in /usr/local/lib/python3.7/dist-packages (from jax>=0.3->flax<0.5.0,>=0.4.0->treex<0.7.0,>=0.6.5->elegy==0.8.6) (1.4.1) Requirement already satisfied: opt-einsum in /usr/local/lib/python3.7/dist-packages (from jax>=0.3->flax<0.5.0,>=0.4.0->treex<0.7.0,>=0.6.5->elegy==0.8.6) (3.3.0) Requirement already satisfied: flatbuffers<3.0,>=1.12 in /usr/local/lib/python3.7/dist-packages (from jaxlib>=0.1.37->optax<0.2.0,>=0.1.1->treex<0.7.0,>=0.6.5->elegy==0.8.6) (2.0) Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->flax<0.5.0,>=0.4.0->treex<0.7.0,>=0.6.5->elegy==0.8.6) (1.4.2) Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.7/dist-packages (from matplotlib->flax<0.5.0,>=0.4.0->treex<0.7.0,>=0.6.5->elegy==0.8.6) (0.11.0) Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->flax<0.5.0,>=0.4.0->treex<0.7.0,>=0.6.5->elegy==0.8.6) (3.0.9) Building wheels for collected packages: elegy, pathtools Building wheel for elegy (pyproject.toml) ... done Created wheel for elegy: filename=elegy-0.8.6-py3-none-any.whl size=72228 sha256=cbaac711df1e4b92557b49daf5b8f819f6505a86f8e130784b7d66ea1636e41d Stored in directory: /tmp/pip-ephem-wheel-cache-rq_lvte2/wheels/71/e7/f6/574c5a5046b672581176a5d22b710ded1fd1db6715b187d363 Building wheel for pathtools (setup.py) ... 
done Created wheel for pathtools: filename=pathtools-0.1.2-py3-none-any.whl size=8806 sha256=b6b282cad3d6d596fec3fdb9350ef8774617a58d5f05a795781d3df8a9b1d850 Stored in directory: /root/.cache/pip/wheels/3e/31/09/fa59cef12cdcfecc627b3d24273699f390e71828921b2cbba2 Successfully built elegy pathtools Installing collected packages: pathtools, einops, commonmark, certifi, treeo, smmap, shortuuid, setproctitle, sentry-sdk, PyYAML, docker-pycreds, colorama, cloudpickle, tensorboardx, rich, gitdb, GitPython, chex, wandb, optax, flax, treex, elegy Attempting uninstall: certifi Found existing installation: certifi 2022.5.18.1 Uninstalling certifi-2022.5.18.1: Successfully uninstalled certifi-2022.5.18.1 Attempting uninstall: PyYAML Found existing installation: PyYAML 3.13 Uninstalling PyYAML-3.13: Successfully uninstalled PyYAML-3.13 Attempting uninstall: cloudpickle Found existing installation: cloudpickle 1.3.0 Uninstalling cloudpickle-1.3.0: Successfully uninstalled cloudpickle-1.3.0 ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts. datascience 0.10.6 requires folium==0.2.1, but you have folium 0.8.3 which is incompatible. Successfully installed GitPython-3.1.27 PyYAML-6.0 certifi-2021.10.8 chex-0.1.3 cloudpickle-1.6.0 colorama-0.4.4 commonmark-0.9.1 docker-pycreds-0.4.0 einops-0.4.1 elegy-0.8.6 flax-0.4.2 gitdb-4.0.9 optax-0.1.2 pathtools-0.1.2 rich-11.2.0 sentry-sdk-1.5.12 setproctitle-1.2.3 shortuuid-1.0.9 smmap-5.0.0 tensorboardx-2.5 treeo-0.0.10 treex-0.6.10 wandb-0.12.16 WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/ Collecting datasets Downloading datasets-2.2.2-py3-none-any.whl (346 kB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 346.8/346.8 kB 12.6 MB/s eta 0:00:00 Requirement already satisfied: matplotlib in /usr/local/lib/python3.7/dist-packages (3.2.2) Requirement already satisfied: packaging in /usr/local/lib/python3.7/dist-packages (from datasets) (21.3) Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.7/dist-packages (from datasets) (2.23.0) Collecting aiohttp Downloading aiohttp-3.8.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (1.1 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.1/1.1 MB 49.3 MB/s eta 0:00:00 Collecting huggingface-hub<1.0.0,>=0.1.0 Downloading huggingface_hub-0.6.0-py3-none-any.whl (84 kB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 84.4/84.4 kB 11.7 MB/s eta 0:00:00 Collecting dill<0.3.5 Downloading dill-0.3.4-py2.py3-none-any.whl (86 kB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 86.9/86.9 kB 8.7 MB/s eta 0:00:00 Requirement already satisfied: pyarrow>=6.0.0 in /usr/local/lib/python3.7/dist-packages (from datasets) (6.0.1) Requirement already satisfied: multiprocess in /usr/local/lib/python3.7/dist-packages (from datasets) (0.70.12.2) Requirement already satisfied: pandas in /usr/local/lib/python3.7/dist-packages (from datasets) (1.3.5) Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.7/dist-packages (from datasets) (4.64.0) Collecting xxhash Downloading xxhash-3.0.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (212 kB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 212.2/212.2 kB 25.7 MB/s eta 0:00:00 Collecting responses<0.19 Downloading responses-0.18.0-py3-none-any.whl (38 kB) Collecting fsspec[http]>=2021.05.0 Downloading fsspec-2022.5.0-py3-none-any.whl (140 kB) 
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 140.6/140.6 kB 19.7 MB/s eta 0:00:00 Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.7/dist-packages (from datasets) (1.21.6) Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.7/dist-packages (from datasets) (4.11.3) Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib) (2.8.2) Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib) (1.4.2) Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib) (3.0.9) Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.7/dist-packages (from matplotlib) (0.11.0) Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from huggingface-hub<1.0.0,>=0.1.0->datasets) (3.7.0) Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.7/dist-packages (from huggingface-hub<1.0.0,>=0.1.0->datasets) (4.2.0) Requirement already satisfied: pyyaml in /usr/local/lib/python3.7/dist-packages (from huggingface-hub<1.0.0,>=0.1.0->datasets) (6.0) Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/dist-packages (from python-dateutil>=2.1->matplotlib) (1.15.0) Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests>=2.19.0->datasets) (2021.10.8) Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests>=2.19.0->datasets) (3.0.4) Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests>=2.19.0->datasets) (1.24.3) Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests>=2.19.0->datasets) (2.10) Collecting urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 Downloading 
urllib3-1.25.11-py2.py3-none-any.whl (127 kB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 128.0/128.0 kB 17.2 MB/s eta 0:00:00 Collecting async-timeout<5.0,>=4.0.0a3 Downloading async_timeout-4.0.2-py3-none-any.whl (5.8 kB) Collecting frozenlist>=1.1.1 Downloading frozenlist-1.3.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (144 kB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 144.8/144.8 kB 1.9 MB/s eta 0:00:00 Collecting aiosignal>=1.1.2 Downloading aiosignal-1.2.0-py3-none-any.whl (8.2 kB) Collecting multidict<7.0,>=4.5 Downloading multidict-6.0.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (94 kB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 94.8/94.8 kB 13.9 MB/s eta 0:00:00 Collecting asynctest==0.13.0 Downloading asynctest-0.13.0-py3-none-any.whl (26 kB) Requirement already satisfied: charset-normalizer<3.0,>=2.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->datasets) (2.0.12) Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->datasets) (21.4.0) Collecting yarl<2.0,>=1.0 Downloading yarl-1.7.2-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (271 kB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 271.8/271.8 kB 31.2 MB/s eta 0:00:00 Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata->datasets) (3.8.0) Requirement already satisfied: pytz>=2017.3 in /usr/local/lib/python3.7/dist-packages (from pandas->datasets) (2022.1) Installing collected packages: xxhash, urllib3, multidict, fsspec, frozenlist, dill, asynctest, async-timeout, yarl, aiosignal, responses, huggingface-hub, aiohttp, datasets Attempting uninstall: urllib3 Found existing installation: urllib3 1.24.3 Uninstalling urllib3-1.24.3: Successfully uninstalled urllib3-1.24.3 Attempting uninstall: dill Found existing installation: dill 0.3.5.1 Uninstalling dill-0.3.5.1: Successfully uninstalled 
dill-0.3.5.1 ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts. datascience 0.10.6 requires folium==0.2.1, but you have folium 0.8.3 which is incompatible. Successfully installed aiohttp-3.8.1 aiosignal-1.2.0 async-timeout-4.0.2 asynctest-0.13.0 datasets-2.2.2 dill-0.3.4 frozenlist-1.3.0 fsspec-2022.5.0 huggingface-hub-0.6.0 multidict-6.0.2 responses-0.18.0 urllib3-1.25.11 xxhash-3.0.0 yarl-1.7.2 WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv ```
jiyuuchc commented 2 years ago

I looked into this a bit, since I was doing some testing on Colab.

It seems that calling reset_metrics() causes every subsequent call to any of the JIT-compiled model functions to hang. You can demonstrate this by overriding reset_metrics() with a no-op:

# No-op replacement for reset_metrics()
def do_nothing():
  pass

# Override the method on this model instance only
model.reset_metrics = do_nothing

With this override in place, training runs to completion.
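
For anyone who wants to try the override and then undo it, here is a minimal sketch of the same patching pattern. It runs against a hypothetical stand-in class (`DummyModel` is not the elegy API, and its `fit` is only a placeholder loop), so it shows the mechanics of the override rather than reproducing the hang itself.

```python
class DummyModel:
    """Hypothetical stand-in for a model object; NOT the real elegy class."""

    def __init__(self):
        self.reset_calls = 0

    def reset_metrics(self):
        # In the real library this is the call suspected of causing the hang.
        self.reset_calls += 1

    def fit(self, epochs):
        # Placeholder training loop that resets metrics once per epoch.
        for _ in range(epochs):
            self.reset_metrics()
        return "done"


def do_nothing(*args, **kwargs):
    """No-op that tolerates any arguments the caller might pass."""
    pass


model = DummyModel()

# Keep a reference to the bound method so the patch can be reverted.
original_reset = model.reset_metrics

# Instance-level override: only this object is affected, not the class.
model.reset_metrics = do_nothing
result = model.fit(epochs=3)
calls_while_patched = model.reset_calls

# Restore the original behaviour once the experiment is over.
model.reset_metrics = original_reset
model.fit(epochs=1)
calls_after_restore = model.reset_calls
```

Because the assignment sets an attribute on the instance, the function is not bound and receives no `self`, which is why a zero-argument `do_nothing` also works. Note that with the override in place metrics are never reset, so values reported across epochs may accumulate; this is a diagnostic workaround, not a fix.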