google-research / nasbench

NASBench: A Neural Architecture Search Dataset and Benchmark
Apache License 2.0
682 stars 129 forks

Not compatible with tensorflow 2.0 #15

Closed: CreeperLin closed this 4 years ago

CreeperLin commented 4 years ago

The code works with TensorFlow 1.14 but throws the following error under TF 2.0:

  File ".../nasbench/lib/training_time.py", line 130, in <module>
    class _TimingRunHook(tf.train.SessionRunHook):
AttributeError: module 'tensorflow_core._api.v2.train' has no attribute 'SessionRunHook'

ultmaster commented 4 years ago

I think I've implemented a TF 2.0 version: https://github.com/ultmaster/nasbench/tree/tf2.

It's not a full version (channel compute is not included), but it should be a good start. I've trained a few architectures and sometimes got even better results than those reported.

CreeperLin commented 4 years ago

That's very helpful, thanks a lot! I noticed that the results reported in the paper are evaluated across only 500 individual trials for each search algorithm.

psyhtest commented 3 years ago

FWIW, I've managed to run example.py under TensorFlow 2.3.1 with the following tf -> tf.compat.v1 modifications:

diff --git a/nasbench/api.py b/nasbench/api.py
index 236897f..97173f2 100644
--- a/nasbench/api.py
+++ b/nasbench/api.py
@@ -143,7 +143,7 @@ class NASBench(object):
     # {108} for the smaller dataset with only the 108 epochs.
     self.valid_epochs = set()

-    for serialized_row in tf.python_io.tf_record_iterator(dataset_file):
+    for serialized_row in tf.compat.v1.python_io.tf_record_iterator(dataset_file):
       # Parse the data from the data file.
       module_hash, epochs, raw_adjacency, raw_operations, raw_metrics = (
           json.loads(serialized_row.decode('utf-8')))
diff --git a/nasbench/lib/evaluate.py b/nasbench/lib/evaluate.py
index b8cbf2c..3c38e82 100644
--- a/nasbench/lib/evaluate.py
+++ b/nasbench/lib/evaluate.py
@@ -27,7 +27,7 @@ import numpy as np
 import tensorflow as tf

 VALID_EXCEPTIONS = (
-    tf.train.NanLossDuringTrainingError,  # NaN loss
+    tf.compat.v1.train.NanLossDuringTrainingError,  # NaN loss
     tf.errors.ResourceExhaustedError,     # OOM
     tf.errors.InvalidArgumentError,       # NaN gradient
     tf.errors.DeadlineExceededError,      # Timed out
diff --git a/nasbench/lib/training_time.py b/nasbench/lib/training_time.py
index 691d4ec..56dd1da 100644
--- a/nasbench/lib/training_time.py
+++ b/nasbench/lib/training_time.py
@@ -127,7 +127,7 @@ _TimingVars = collections.namedtuple(  # pylint: disable=g-bad-name
    ])


-class _TimingRunHook(tf.train.SessionRunHook):
+class _TimingRunHook(tf.compat.v1.train.SessionRunHook):
   """Hook to stop the training after a certain amount of time."""

   def __init__(self, max_train_secs=None):
@@ -171,7 +171,7 @@ class _TimingRunHook(tf.train.SessionRunHook):
      run_context.request_stop()


-class _TimingSaverListener(tf.train.CheckpointSaverListener):
+class _TimingSaverListener(tf.compat.v1.train.CheckpointSaverListener):
   """Saving listener to store the train time up to the last checkpoint save."""

   def begin(self):
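The diff above rewrites exactly four TF 1.x symbols to their `tf.compat.v1` equivalents. If you need to reapply the same change (for example on a fresh checkout), it can be scripted. The sketch below is a minimal regex rewrite limited to those four symbols; the names `to_compat_v1` and `_MOVED_SYMBOLS` are hypothetical, and it is not a general TF 1-to-2 converter (TensorFlow 2.x itself ships a broader `tf_upgrade_v2` script for that).

```python
import re

# Hypothetical list: the four symbols changed by the patch in this thread.
_MOVED_SYMBOLS = (
    "python_io.tf_record_iterator",
    "train.NanLossDuringTrainingError",
    "train.SessionRunHook",
    "train.CheckpointSaverListener",
)

# Match "tf.<symbol>" as whole words. Text already written as
# "tf.compat.v1.train..." does not match, so the rewrite is idempotent.
_PATTERN = re.compile(
    r"\btf\.(" + "|".join(map(re.escape, _MOVED_SYMBOLS)) + r")\b")


def to_compat_v1(source: str) -> str:
    """Rewrite tf.<symbol> to tf.compat.v1.<symbol> for the listed symbols."""
    return _PATTERN.sub(r"tf.compat.v1.\1", source)
```

Applied to a line such as `class _TimingRunHook(tf.train.SessionRunHook):` it produces the `+` line of the patch; other `tf.` references (e.g. `tf.errors.ResourceExhaustedError`) are left alone, matching the diff.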
$ git clone https://github.com/google-research/nasbench
$ cd nasbench
$ virtualenv venv
created virtual environment CPython3.8.5.final.0-64 in 179ms
  creator CPython3Posix(dest=/home/anton/projects/nasbench/venv, clear=False, global=False)
  seeder FromAppData(download=False, ipaddr=latest, progress=latest, urllib3=latest, wheel=latest, distro=latest, pkg_resources=latest, retrying=latest, setuptools=latest, chardet=latest, lockfile=latest, pytoml=latest, colorama=latest, pep517=latest, contextlib2=latest, six=latest, packaging=latest, certifi=latest, webencodings=latest, requests=latest, appdirs=latest, pip=latest, pyparsing=latest, msgpack=latest, idna=latest, html5lib=latest, distlib=latest, CacheControl=latest, via=copy, app_data_dir=/home/anton/.local/share/virtualenv/seed-app-data/v1.0.1.debian)
  activators BashActivator,CShellActivator,FishActivator,PowerShellActivator,PythonActivator,XonshActivator
$ source venv/bin/activate
(venv) $ pip install -e .
  Running setup.py develop for nasbench
Successfully installed nasbench tensorboard-2.4.0 tensorflow-2.3.1 tensorflow-estimator-2.3.0
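The patched `api.py` keeps the TF 1.x record iterator via `tf.compat.v1.python_io.tf_record_iterator`. As an alternative, the same loop can be written with the TF 2.x `tf.data.TFRecordDataset` API. The sketch below assumes eager execution (the TF 2.x default) and reuses the row decoding shown in the `api.py` hunk; the function names `parse_row` and `iter_rows` are hypothetical, and everything nasbench does after decoding (building the adjacency matrix, metrics, etc.) is left out.

```python
import json


def parse_row(serialized_row: bytes):
    """Decode one dataset record into its five fields, as in the api.py hunk."""
    module_hash, epochs, raw_adjacency, raw_operations, raw_metrics = (
        json.loads(serialized_row.decode("utf-8")))
    return module_hash, epochs, raw_adjacency, raw_operations, raw_metrics


def iter_rows(dataset_file: str):
    """Yield parsed rows from a TFRecord file using tf.data (eager mode assumed)."""
    import tensorflow as tf  # Lazy import so parse_row is usable without TF.

    for record in tf.data.TFRecordDataset(dataset_file):
        # Each element is a scalar string tensor; .numpy() returns the raw bytes.
        yield parse_row(record.numpy())
```

Splitting decoding from I/O keeps `parse_row` testable without TensorFlow installed, while `iter_rows` only needs TF when it is actually called.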