google / compare_gan

Compare GAN code.
Apache License 2.0
1.82k stars 319 forks source link

Not Python 3 compatible #10

Closed jsirott closed 5 years ago

jsirott commented 6 years ago

It might be worth mentioning this in the README at some point.

rhaps0dy commented 6 years ago

Hi, I've made this Python 3 compatible for my own use. It didn't take very long.

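Here is a diff. It's not fully tested yet. Note that the diff is reversed: the `-` lines are the Python 3 versions and the `+` lines are the original Python 2 code (e.g. `iteritems()` and `__metaclass__` appear on the `+` side). Apply it with `git apply -R` (or `patch -R`) to get the Python 3 version.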

```diff diff --git a/compare_gan/src/eval_gan_lib.py b/compare_gan/src/eval_gan_lib.py index becba4d..8d2594b 100644 --- a/compare_gan/src/eval_gan_lib.py +++ b/compare_gan/src/eval_gan_lib.py @@ -74,7 +74,7 @@ def GetAllTrainingParams(): for gan_type in SUPPORTED_GANS: for dataset in ["mnist", "fashion-mnist", "cifar10", "celeba"]: p = params.GetParameters(gan_type, "wide") - all_params.update(list(p.keys())) + all_params.update(p.keys()) logging.info("All training parameter exported: %s", sorted(all_params)) return sorted(all_params) diff --git a/compare_gan/src/gan_lib.py b/compare_gan/src/gan_lib.py index 316cc52..ee41e7c 100644 --- a/compare_gan/src/gan_lib.py +++ b/compare_gan/src/gan_lib.py @@ -142,7 +142,7 @@ def create_gan(gan_type, dataset, dataset_content, options, def profile_context(tfprofile_dir): if "enable_tf_profile" in FLAGS and FLAGS.enable_tf_profile: with tf.contrib.tfprof.ProfileContext( - tfprofile_dir, trace_steps=list(range(100, 200, 1)), dump_steps=[200]): + tfprofile_dir, trace_steps=range(100, 200, 1), dump_steps=[200]): yield else: yield diff --git a/compare_gan/src/gans/resnet_architecture_test.py b/compare_gan/src/gans/resnet_architecture_test.py index c08e7b7..dc62261 100644 --- a/compare_gan/src/gans/resnet_architecture_test.py +++ b/compare_gan/src/gans/resnet_architecture_test.py @@ -14,11 +14,11 @@ # limitations under the License. 
"""Tests for Resnet architectures.""" + from __future__ import absolute_import from __future__ import division from __future__ import print_function - from compare_gan.src.gans import resnet_architecture as resnet_arch import tensorflow as tf @@ -43,8 +43,8 @@ class ResnetArchitectureTest(tf.test.TestCase): def testResnet5GeneratorRuns(self): generator_128 = TestResnet5GeneratorShape(128) generator_64 = TestResnet5GeneratorShape(64) - self.assertEqual(generator_128[0], generator_128[1]) - self.assertEqual(generator_64[0], generator_64[1]) + self.assertEquals(generator_128[0], generator_128[1]) + self.assertEquals(generator_64[0], generator_64[1]) def testResnet5DiscriminatorRuns(self): config = tf.ConfigProto(allow_soft_placement=True) @@ -57,7 +57,7 @@ class ResnetArchitectureTest(tf.test.TestCase): reuse=False) tf.global_variables_initializer().run() output = sess.run([out]) - self.assertEqual(output[0].shape, (batch_size, 1)) + self.assertEquals(output[0].shape, (batch_size, 1)) def testResnet107GeneratorRuns(self): config = tf.ConfigProto(allow_soft_placement=True) @@ -70,7 +70,7 @@ class ResnetArchitectureTest(tf.test.TestCase): noise=z, is_training=True, reuse=False, colors=3) tf.global_variables_initializer().run() output = sess.run([g]) - self.assertEqual(output[0].shape, (batch_size, 128, 128, 3)) + self.assertEquals(output[0].shape, (batch_size, 128, 128, 3)) def testResnet107DiscriminatorRuns(self): config = tf.ConfigProto(allow_soft_placement=True) @@ -83,7 +83,7 @@ class ResnetArchitectureTest(tf.test.TestCase): discriminator_normalization="spectral_norm", reuse=False) tf.global_variables_initializer().run() output = sess.run([out]) - self.assertEqual(output[0].shape, (batch_size, 1)) + self.assertEquals(output[0].shape, (batch_size, 1)) if __name__ == "__main__": tf.test.main() diff --git a/compare_gan/src/generate_tasks_lib.py b/compare_gan/src/generate_tasks_lib.py index 39d0850..7c21b9f 100644 --- a/compare_gan/src/generate_tasks_lib.py +++ 
b/compare_gan/src/generate_tasks_lib.py @@ -14,6 +14,7 @@ # limitations under the License. """Generate tasks for comparing GANs.""" + from __future__ import absolute_import from __future__ import division @@ -143,7 +144,7 @@ def TestGansWithPenaltyNewDatasets(architecture): def GetDefaultParams(gan_params): """Return the default params for a GAN (=the ones used in the paper).""" ret = {} - for param_name, param_info in gan_params.items(): + for param_name, param_info in gan_params.iteritems(): ret[param_name] = param_info.default return ret diff --git a/compare_gan/src/params.py b/compare_gan/src/params.py index e637226..7ba2ee6 100644 --- a/compare_gan/src/params.py +++ b/compare_gan/src/params.py @@ -19,6 +19,7 @@ We define the default GAN parameters with respect to the datasets and the training hyperparameters. The hyperparameters used by the respective authors are also added to the set. """ + from __future__ import absolute_import from __future__ import division from __future__ import print_function diff --git a/compare_gan/src/params_test.py b/compare_gan/src/params_test.py index b32a0ee..34c1ce7 100644 --- a/compare_gan/src/params_test.py +++ b/compare_gan/src/params_test.py @@ -14,11 +14,11 @@ # limitations under the License. 
"""Tests for compare_gan.params.""" + from __future__ import absolute_import from __future__ import division from __future__ import print_function - from compare_gan.src import params import tensorflow as tf @@ -27,10 +27,10 @@ class ParamsTest(tf.test.TestCase): def testParameterRanges(self): training_parameters = params.GetParameters("WGAN", "wide") - self.assertEqual(len(list(training_parameters.keys())), 5) + self.assertEqual(len(training_parameters.keys()), 5) training_parameters = params.GetParameters("BEGAN", "wide") - self.assertEqual(len(list(training_parameters.keys())), 6) + self.assertEqual(len(training_parameters.keys()), 6) if __name__ == "__main__": diff --git a/compare_gan/src/simple_task_pb2.py b/compare_gan/src/simple_task_pb2.py index 90a4699..fa58ed5 100644 --- a/compare_gan/src/simple_task_pb2.py +++ b/compare_gan/src/simple_task_pb2.py @@ -28,7 +28,7 @@ from google.protobuf import descriptor_pb2 DESCRIPTOR = _descriptor.FileDescriptor( name='src/simple_task.proto', package='compare_gan', - serialized_pb=b'\n\x15src/simple_task.proto\x12\x0b\x63ompare_gan\"t\n\rTaskDimension\x12\x11\n\tparameter\x18\x01 \x01(\t\x12\x14\n\x0cstring_value\x18\x02 \x01(\t\x12\x11\n\tint_value\x18\x03 \x01(\x05\x12\x13\n\x0b\x66loat_value\x18\x04 \x01(\x02\x12\x12\n\nbool_value\x18\x05 \x01(\x08\"C\n\x04Task\x12\x0b\n\x03num\x18\x01 \x01(\x05\x12.\n\ndimensions\x18\x07 \x03(\x0b\x32\x1a.compare_gan.TaskDimension') + serialized_pb='\n\x15src/simple_task.proto\x12\x0b\x63ompare_gan\"t\n\rTaskDimension\x12\x11\n\tparameter\x18\x01 \x01(\t\x12\x14\n\x0cstring_value\x18\x02 \x01(\t\x12\x11\n\tint_value\x18\x03 \x01(\x05\x12\x13\n\x0b\x66loat_value\x18\x04 \x01(\x02\x12\x12\n\nbool_value\x18\x05 \x01(\x08\"C\n\x04Task\x12\x0b\n\x03num\x18\x01 \x01(\x05\x12.\n\ndimensions\x18\x07 \x03(\x0b\x32\x1a.compare_gan.TaskDimension') @@ -43,14 +43,14 @@ _TASKDIMENSION = _descriptor.Descriptor( _descriptor.FieldDescriptor( name='parameter', 
full_name='compare_gan.TaskDimension.parameter', index=0, number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value="", + has_default_value=False, default_value=u"", message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), _descriptor.FieldDescriptor( name='string_value', full_name='compare_gan.TaskDimension.string_value', index=1, number=2, type=9, cpp_type=9, label=1, - has_default_value=False, default_value="", + has_default_value=False, default_value=u"", message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), @@ -127,12 +127,14 @@ _TASK.fields_by_name['dimensions'].message_type = _TASKDIMENSION DESCRIPTOR.message_types_by_name['TaskDimension'] = _TASKDIMENSION DESCRIPTOR.message_types_by_name['Task'] = _TASK -class TaskDimension(_message.Message, metaclass=_reflection.GeneratedProtocolMessageType): +class TaskDimension(_message.Message): + __metaclass__ = _reflection.GeneratedProtocolMessageType DESCRIPTOR = _TASKDIMENSION # @@protoc_insertion_point(class_scope:compare_gan.TaskDimension) -class Task(_message.Message, metaclass=_reflection.GeneratedProtocolMessageType): +class Task(_message.Message): + __metaclass__ = _reflection.GeneratedProtocolMessageType DESCRIPTOR = _TASK # @@protoc_insertion_point(class_scope:compare_gan.Task) diff --git a/compare_gan/src/task_utils.py b/compare_gan/src/task_utils.py index 4c45646..b88c155 100644 --- a/compare_gan/src/task_utils.py +++ b/compare_gan/src/task_utils.py @@ -64,7 +64,7 @@ def UnrollCalls(function, kwargs): b x c in [2,4] x [5, 6]. 
""" res = [] - for key, value in sorted(kwargs.items()): + for key, value in sorted(kwargs.iteritems()): assert not isinstance(key, tuple) if isinstance(value, list): for v in value: @@ -104,7 +104,7 @@ def MakeDimensions(dim_dict, extra_dims=None, base_task=None): dim_dict = copy.copy(dim_dict) dim_dict.update(extra_dims) dim_dict = collections.OrderedDict(sorted(dim_dict.items())) - for key, value in dim_dict.items(): + for key, value in dim_dict.iteritems(): if key in ("_proto", "_prefix"): # We skip the special keys. continue ```

eyaler commented 5 years ago

Also, `cPickle` needs to be changed to `pickle` in gilbo.py, since the `cPickle` module no longer exists in Python 3.

eyaler commented 5 years ago

Also, in gilbo.py, change:

```python
uninitialized = plist(tf.report_uninitialized_variables().eval())
```

to

```python
uninitialized = plist(tf.report_uninitialized_variables().eval().astype(str))
```

Marvin182 commented 5 years ago

Thank you for your feedback. We released an update, and all code should now be Python 3 compatible. Please open a new ticket if there are still issues.