Closed nalzok closed 2 years ago
Could you please refactor this so it does not rely on a TPU? I would use either GPU or CPU. I was not able to get access to an "accelerator" using my vanilla Colab account, so I ran it in CPU mode; it completed 4 training epochs in a couple of minutes, which is acceptable :)
Also please fix the shape error shown below.
Also, I see you have the line %pip install -qq flax tensorflow-cpu. Why do we need to install TF? I think it is built into Colab. And if running locally in a Jupyter notebook, I think we can assume the user has already installed it.
This may explain the weird warnings I see in the first cell:
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tensorflow 2.8.2+zzzcolab20220527125636 requires keras<2.9,>=2.8.0rc0, but you have keras 2.9.0 which is incompatible.
tensorflow 2.8.2+zzzcolab20220527125636 requires tensorboard<2.9,>=2.8, but you have tensorboard 2.9.1 which is incompatible.
tensorflow 2.8.2+zzzcolab20220527125636 requires tensorflow-estimator<2.9,>=2.8, but you have tensorflow-estimator 2.9.0 which is incompatible.
Sure, I have changed the runtime type to GPU (and it should also work on a CPU runtime). I don't quite understand why there is a shape mismatch; it works perfectly on my Colab. Anyway, I have added a reshape to ensure correctness.
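The exact shape error is not reproduced in this thread, so the following is only a guess at the kind of defensive reshape meant here. It assumes MNIST batches that arrive as (N, 28, 28) or (N, 784) while a convolutional encoder expects (N, 28, 28, 1); the helper name `to_nhwc` is hypothetical and does not come from the notebook.

```python
try:
    import jax.numpy as xp  # the notebook's framework
except ImportError:
    import numpy as xp  # same reshape API; fallback so the sketch runs without JAX


def to_nhwc(images, height=28, width=28, channels=1):
    """Reshape a batch of flat or (N, H, W) images to (N, H, W, C).

    The default sizes are an assumption based on MNIST (28x28 grayscale).
    """
    return xp.reshape(images, (-1, height, width, channels))


batch = xp.zeros((32, 28, 28))
print(to_nhwc(batch).shape)  # (32, 28, 28, 1)
```

Reshaping explicitly at the model boundary makes the expected layout visible in the code, so a mismatch shows up as an error at the reshape rather than deep inside a convolution.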
Regarding the tensorflow-cpu dependency, it is there to fix a workflow error on GA runners. Basically, the module flax.training.checkpoints implicitly depends on TensorFlow:
E ---------------------------------------------------------------------------
E ModuleNotFoundError Traceback (most recent call last)
E /tmp/ipykernel_4232/994773952.py in <module>
E 13 import flax
E 14 import flax.linen as nn
E ---> 15 from flax.training import train_state, checkpoints
E 16
E 17 try:
E
E ~/miniconda3/envs/py37/lib/python3.7/site-packages/flax/training/checkpoints.py in <module>
E 29 from flax import errors
E 30 from flax import serialization
E ---> 31 from tensorflow.io import gfile # pytype: disable=import-error
E 32
E 33
E
E ModuleNotFoundError: No module named 'tensorflow'
To avoid installing TensorFlow twice on Colab, I used the following pattern:
try:
    import tensorflow as tf
except ModuleNotFoundError:
    # flax.training.checkpoints depends on tensorflow,
    # which is only installed by default on Colab
    %pip install -qq tensorflow-cpu
    import tensorflow as tf

tf.config.set_visible_devices([], 'GPU')  # prevent TensorFlow from using the GPU
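Outside a notebook the %pip magic is not available, so the same install-only-if-missing idea could be sketched in plain Python as below. This is an illustration under stated assumptions, not code from the PR: `ensure_module` and its `install` parameter are hypothetical names, and the default installer simply shells out to pip.

```python
import importlib
import importlib.util
import subprocess
import sys


def _pip_install(package):
    """Default installer (an assumption): run pip in the current interpreter."""
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-qq", package])


def ensure_module(module_name, pip_name=None, install=_pip_install):
    """Import module_name, installing pip_name first only if it is missing.

    module_name is expected to be a top-level module name.
    """
    if importlib.util.find_spec(module_name) is None:
        install(pip_name or module_name)
        # Make the import system notice packages installed after startup.
        importlib.invalidate_caches()
    return importlib.import_module(module_name)


# A module that is already present is imported without calling the installer.
json_mod = ensure_module("json")

# Hypothetical usage mirroring the notebook cell (not executed here):
# tf = ensure_module("tensorflow", pip_name="tensorflow-cpu")
```

Passing the installer as a parameter keeps the check testable without network access and leaves an existing TensorFlow install, such as the one preinstalled on Colab, untouched.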
Description
Add ae_mnist_conv_jax.ipynb as a demo for dimension reduction with an autoencoder using a CNN for MNIST.
Issue
Private communication with @murphyk.
Checklist
Potential problems/Important remarks
https://colab.research.google.com/drive/1YuRksQbwubK1J8Xhy4gbHwQPQ9d2_vpQ?usp=sharing