nalzok opened this issue 2 years ago
IIUC, it should not be necessary to install jax or jaxlib on Colab, since they are built in. See e.g. this lenet_jax notebook.
That's true. I was using `pip install --upgrade` to upgrade them to the latest version, since the default JAX version on Colab (v0.3.8 as of now) doesn't work well with Elegy:
```
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-2-cbb187e7d76c> in <module>()
11 import treeo as to
12 import treex as tx
---> 13 import elegy as eg
14
15 from bokeh.resources import INLINE
/usr/local/lib/python3.7/dist-packages/elegy/__init__.py in <module>()
16 )
17
---> 18 from .model.model import Model
19 from .model.model_base import ModelBase, load
20 from .model.model_core import (
/usr/local/lib/python3.7/dist-packages/elegy/model/model.py in <module>()
9
10 from elegy import types, utils
---> 11 from elegy.model.model_base import ModelBase
12 from elegy.model.model_core import (
13 GradStepOutput,
/usr/local/lib/python3.7/dist-packages/elegy/model/model_base.py in <module>()
19 from elegy.callbacks.sigint import SigIntMode
20 from elegy.data import utils as data_utils
---> 21 from elegy.model.model_core import ModelCore, PredStepOutput, TestStepOutput
22
23 __all__ = ["ModelBase", "load"]
/usr/local/lib/python3.7/dist-packages/elegy/model/model_core.py in <module>()
14 from elegy import types, utils
15
---> 16 from . import utils as model_utils
17
18 try:
/usr/local/lib/python3.7/dist-packages/elegy/model/utils.py in <module>()
3 try:
4 import tensorflow as tf # type: ignore[import]
----> 5 from jax.experimental import jax2tf # type: ignore[import]
6
7 def convert_and_save_model(
/usr/local/lib/python3.7/dist-packages/jax/experimental/jax2tf/__init__.py in <module>()
13 # limitations under the License.
14
---> 15 from jax.experimental.jax2tf.jax2tf import (convert, dtype_of_val,
16 split_to_logical_devices, PolyShape)
17 from jax.experimental.jax2tf.call_tf import call_tf
/usr/local/lib/python3.7/dist-packages/jax/experimental/jax2tf/jax2tf.py in <module>()
2388 extra_name_stack="checkpoint")
2389
-> 2390 tf_impl[lax_control_flow.optimization_barrier_p] = tfxla.optimization_barrier
2391
2392 def _top_k(operand: TfVal, k: int) -> Tuple[TfVal, TfVal]:
AttributeError: module 'tensorflow.compiler.tf2xla.python.xla' has no attribute 'optimization_barrier'
```
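The failing line in `jax2tf.py` looks up `optimization_barrier` in TensorFlow's `tf2xla` `xla` module, so the traceback points at the upgraded JAX expecting a newer TensorFlow than the one preinstalled on Colab. A minimal sketch for checking that mismatch before importing Elegy, assuming Python 3.8+ for `importlib.metadata` (on Colab's Python 3.7 the `importlib_metadata` backport would be needed instead); the helper names `installed_versions` and `tf_has_optimization_barrier` are hypothetical:

```python
import importlib
import importlib.util
from importlib import metadata
from typing import Dict, Iterable, Optional


def installed_versions(names: Iterable[str]) -> Dict[str, Optional[str]]:
    """Return the installed distribution version for each name, or None if absent."""
    versions: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def tf_has_optimization_barrier() -> Optional[bool]:
    """Probe the attribute that jax2tf looked up in the traceback.

    Returns None if TensorFlow is not installed; otherwise whether
    tensorflow.compiler.tf2xla.python.xla exposes `optimization_barrier`.
    """
    if importlib.util.find_spec("tensorflow") is None:
        return None
    xla = importlib.import_module("tensorflow.compiler.tf2xla.python.xla")
    return hasattr(xla, "optimization_barrier")


if __name__ == "__main__":
    print(installed_versions(["jax", "jaxlib", "tensorflow", "elegy"]))
    print("optimization_barrier available:", tf_has_optimization_barrier())
```

With the JAX build from this traceback, a `False` result means `import elegy` should fail in the same way, because `elegy/model/utils.py` imports `jax2tf` whenever TensorFlow is importable.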
@nalzok thanks for reporting this! These notebooks are tested on CI, but sadly testing on Colab is a manual process. I will try to give it a go, but if you find the fix it would be amazing if you could contribute it back :)
Yeah, I am willing to help, but I cannot figure out how to install a package from GitHub. I just created a fork at https://github.com/nalzok/elegy and tried to install it on Colab with:
```
! pip install --upgrade pip
! pip install git+https://github.com/nalzok/elegy
```
Then I got the following message: `datascience 0.10.6 requires folium==0.2.1, but you have folium 0.8.3 which is incompatible.`
I looked into this a bit, since I was doing some testing on Colab.
It seems that calling `reset_metrics()` causes every subsequent call to any of the JIT-compiled model functions to hang. This can be demonstrated by overriding `reset_metrics()` with a no-op:
```python
# Replace reset_metrics on this model instance with a no-op.
def do_nothing():
    pass

model.reset_metrics = do_nothing
```
With this override in place, training runs to completion.
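The override above replaces `reset_metrics` permanently on that `model` instance. A minimal sketch of scoping the same workaround with the standard library's `unittest.mock.patch.object`, so the original method is restored when the block exits; `DummyModel` is a hypothetical stand-in for an Elegy model, and this only sidesteps the hang rather than fixing it:

```python
from unittest import mock


class DummyModel:
    """Hypothetical stand-in for an Elegy model; only the method names matter here."""

    def __init__(self) -> None:
        self.resets = 0

    def reset_metrics(self) -> None:
        self.resets += 1

    def fit(self) -> str:
        # Assumption (as the thread suggests): fit() calls reset_metrics() internally.
        self.reset_metrics()
        return "done"


model = DummyModel()

# Inside the `with` block, reset_metrics is a no-op on this instance only.
with mock.patch.object(model, "reset_metrics", lambda: None):
    result = model.fit()

# On exit, the instance attribute is removed, so the class method is used again.
model.reset_metrics()
```

Because `reset_metrics` lives on the class rather than in the instance `__dict__`, `patch.object` deletes its temporary instance attribute on exit, and later calls resolve to the class method again.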
Describe the bug
The Colab runtime freezes during `model.fit`. It has been running for minutes without any progress. The progress bar always shows
When I tried to interrupt the cell execution, Colab prompted: "The executing code is not responding to interrupts. Would you like to try restarting the runtime? Runtime state including all local variables will be lost."
I then noticed this comment in the High Level API notebook
The runtime still freezes after I uncommented it.
Curiously, the Low Level API notebook contains a different command.
After uncommenting it, I got the following error in `model.fit`:
```
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
...
```
I have also tried using Elegy in the notebook I have been working on, and got another error:
```
Epoch 1/10
---------------------------------------------------------------------------
UnfilteredStackTrace Traceback (most recent call last)
...
```
Minimal code to reproduce
- https://colab.research.google.com/github/poets-ai/elegy/blob/master/docs/getting-started/high-level-api.ipynb
- https://colab.research.google.com/github/poets-ai/elegy/blob/master/docs/getting-started/low-level-api.ipynb
- https://colab.research.google.com/drive/1ZGlTknvwMC8nrrPC_rsSBEGpgcFmVicG?usp=sharing
Expected behavior
Training completes successfully.
Library Info
Additional context
I am using a GPU runtime, i.e. "Python 3 Google Compute Engine backend (GPU)".
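Since the frozen `model.fit` cell ignores interrupts, one way to see where it is stuck is Python's standard `faulthandler` module, whose `dump_traceback_later` writes the stack of every thread to a file once a timeout expires. A minimal sketch, assuming a hypothetical `hang_watchdog` helper and a writable temporary directory; it reports Python frames only, so a hang inside compiled XLA code would show up as the Python call that entered it:

```python
import faulthandler
import os
import tempfile
import time
from contextlib import contextmanager
from typing import Iterator


@contextmanager
def hang_watchdog(log_path: str, timeout_s: float) -> Iterator[None]:
    """Dump all thread tracebacks to log_path if the body runs longer than timeout_s."""
    with open(log_path, "w") as log_file:
        # faulthandler writes through the file descriptor from a watchdog thread,
        # so the dump can still appear while the main thread is blocked.
        faulthandler.dump_traceback_later(timeout_s, repeat=False, file=log_file)
        try:
            yield
        finally:
            # Disarm the timer so a body that finishes in time leaves the log empty.
            faulthandler.cancel_dump_traceback_later()


# Hypothetical notebook usage: wrap the call that freezes.
# with hang_watchdog("/tmp/fit_hang.log", timeout_s=120):
#     model.fit(...)

# Demo: a body that overruns the timeout produces a dump in the log.
DEMO_LOG = os.path.join(tempfile.gettempdir(), "hang_watchdog_demo.log")
with hang_watchdog(DEMO_LOG, timeout_s=0.2):
    time.sleep(0.6)  # deliberately slower than the timeout
```

After the hang, reading the log from a fresh cell (or after restarting the runtime, since the file persists on disk) shows which Python call was active when the timeout fired.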