bstriner / keras-adversarial

Keras Generative Adversarial Networks
MIT License

example_gan.py does not work with tensorflow backend #2

Closed: slackner closed this issue 7 years ago

slackner commented 7 years ago

Running example_gan.py from the latest commit f87ace5125efa360a4fa4fc4b13a297ce81e390c leads to the following error when using the tensorflow backend:

Traceback (most recent call last):
  File "example_gan.py", line 142, in <module>
    main()
  File "example_gan.py", line 138, in main
    latent_dim=latent_dim)
  File "example_gan.py", line 114, in example_gan
    batch_size=32)
  File "[...]/.local/lib/python2.7/site-packages/Keras-1.2.0-py2.7.egg/keras/engine/training.py", line 1115, in fit
    self._make_train_function()
  File "[...]/projects/keras_adversarial/keras_adversarial/adversarial_model.py", line 148, in _make_train_function
    self.updates,
  File "[...]/projects/keras_adversarial/keras_adversarial/adversarial_model.py", line 122, in updates
    return merge_updates(list(itertools.chain.from_iterable(model.updates for model in self.layers)))
  File "[...]/projects/keras_adversarial/keras_adversarial/adversarial_utils.py", line 133, in merge_updates
    for k, v in updates:
  File "[...]/.local/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 510, in __iter__
    raise TypeError("'Tensor' object is not iterable.")
TypeError: 'Tensor' object is not iterable.
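The last frame explains the crash: merge_updates unpacks every update with `for k, v in updates`, and a TensorFlow Tensor refuses to be iterated. A minimal pure-Python reproduction of the same failure, assuming a hypothetical `FakeTensor` class that mimics only that `__iter__` behaviour (no TensorFlow needed; `naive_merge` and `try_merge` are illustration-only names):

```python
class FakeTensor:
    """Hypothetical stand-in for tf.Tensor: iteration raises TypeError,
    with the same message as the last frame of the traceback."""

    def __iter__(self):
        raise TypeError("'Tensor' object is not iterable.")


def naive_merge(updates):
    # Same loop shape as merge_updates: assumes each update is a (variable, value) pair.
    merged = {}
    for k, v in updates:
        merged.setdefault(k, []).append(v)
    return merged


def try_merge(updates):
    # Return the merged dict, or the error text if unpacking fails.
    try:
        return naive_merge(updates)
    except TypeError as err:
        return "TypeError: %s" % err


print(try_merge([("w", 1.0), ("w", 3.0)]))  # {'w': [1.0, 3.0]}
print(try_merge([FakeTensor()]))            # TypeError: 'Tensor' object is not iterable.
```

Theano-style tuples unpack fine; a single tensor-like object fails at the `k, v` unpacking step, which is exactly where the traceback points.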

The problem has been visible since 67e2c07c528cbadc68a5c9c07a62d82c6fb2fe21, where batch_norm_mode was changed, but the code causing it was added earlier, in 1a4cd7015381e12357539a865dcf27cc3a7e1ba0. The newly added function merge_updates assumes that the updates are (variable, value) tuples, but with TensorFlow they are tensors produced by assignment operations. Something like the following change should fix it (similar to how it's done in unrolled_optimizer.py):

@@ -126,11 +126,19 @@ def uniform_latent_sampling(latent_shape, low=0.0, high=1.0):
 def n_choice(x, n):
     return x[np.random.choice(x.shape[0], size=n, replace=False)]

+if K.backend() == "tensorflow":
+    def unpack_assignment(a):
+        return a.op.inputs[0], a.op.inputs[1]
+
+else:
+    def unpack_assignment(a):
+        return a

 def merge_updates(updates):
     """Average repeated updates of the same variable"""
     upd = {}
-    for k, v in updates:
+    for kv in updates:
+        k, v = unpack_assignment(kv)
         if k not in upd:
             upd[k] = []
         upd[k].append(v)

Tested with TensorFlow 0.12.1 and Keras 1.2.0. If you need any additional information, please let me know.
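The proposed change can be exercised without TensorFlow. This is a hypothetical reconstruction, not the repository's code: the diff stops at `upd[k].append(v)`, so the final averaging step is an assumption based on the docstring, and `fake_assign` is a stand-in exposing only the `a.op.inputs` attribute the patch reads. The unpacker is passed in as a parameter (instead of being chosen once from `K.backend()` at import time) so both paths can be run side by side.

```python
from types import SimpleNamespace


def make_unpack_assignment(backend_name):
    """Pick an unpacking function once, like the patch does with K.backend()."""
    if backend_name == "tensorflow":
        def unpack_assignment(a):
            # Assign-op tensor: first op input is the variable, second the new value.
            return a.op.inputs[0], a.op.inputs[1]
    else:
        def unpack_assignment(a):
            # Theano updates are already (variable, new_value) tuples.
            return a
    return unpack_assignment


def merge_updates(updates, unpack_assignment):
    """Average repeated updates of the same variable (averaging step assumed)."""
    upd = {}
    for kv in updates:
        k, v = unpack_assignment(kv)
        upd.setdefault(k, []).append(v)
    return [(k, sum(vs) / len(vs)) for k, vs in upd.items()]


def fake_assign(variable, value):
    """Hypothetical stand-in for the tensor returned by tf.assign: only .op.inputs."""
    return SimpleNamespace(op=SimpleNamespace(inputs=[variable, value]))


tf_updates = [fake_assign("w", 1.0), fake_assign("w", 3.0), fake_assign("b", 0.5)]
theano_updates = [("w", 1.0), ("w", 3.0), ("b", 0.5)]
print(merge_updates(tf_updates, make_unpack_assignment("tensorflow")))  # [('w', 2.0), ('b', 0.5)]
print(merge_updates(theano_updates, make_unpack_assignment("theano")))  # [('w', 2.0), ('b', 0.5)]
```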

bstriner commented 7 years ago

Thanks, @slackner! The current version fixes the bug. Please let me know if you have any other issues. I haven't had a chance to test everything on TensorFlow, so thanks for finding the bug.

Unpacking updates is much trickier on TensorFlow than on Theano. I updated the original unpack_assignment to support assignment ops, assign_add, tuples, etc. If you compare these two files, you'll see what I mean.

https://github.com/bstriner/keras_adversarial/blob/master/keras_adversarial/backend/tensorflow_backend.py
https://github.com/bstriner/keras_adversarial/blob/master/keras_adversarial/backend/theano_backend.py
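One reason it is trickier: for add/subtract-style assignment ops, the second op input is a delta rather than the new value, so it has to be combined with the variable. A minimal sketch, assuming stand-in objects (`Var` and `fake_op` are hypothetical, and this is not the code in tensorflow_backend.py); the op-type strings follow TensorFlow's names for ref-variable assignment ops (`Assign`, `AssignAdd`, `AssignSub`):

```python
from dataclasses import dataclass
from types import SimpleNamespace


@dataclass(eq=False)  # identity-based equality and hashing, like real variables
class Var:
    """Hypothetical variable stand-in holding a concrete value."""
    name: str
    value: float


def fake_op(op_type, ref, operand):
    """Hypothetical stand-in for an assignment tensor: exposes .op.type and .op.inputs."""
    return SimpleNamespace(op=SimpleNamespace(type=op_type, inputs=[ref, operand]))


def unpack_assignment(a):
    """Return (variable, new_value) for tuples and Assign/AssignAdd/AssignSub ops."""
    if isinstance(a, tuple):
        return a  # Theano-style update, already a pair
    ref, operand = a.op.inputs[0], a.op.inputs[1]
    if a.op.type == "Assign":
        return ref, operand
    if a.op.type == "AssignAdd":
        return ref, ref.value + operand  # in TensorFlow this sum would be symbolic
    if a.op.type == "AssignSub":
        return ref, ref.value - operand
    raise ValueError("unsupported update op type: %r" % a.op.type)


w = Var("w", 10.0)
for update in [(w, 1.0), fake_op("Assign", w, 2.0),
               fake_op("AssignAdd", w, 0.5), fake_op("AssignSub", w, 4.0)]:
    var, new_value = unpack_assignment(update)
    print(var.name, new_value)  # w 1.0, then w 2.0, w 10.5, w 6.0
```

Reading `inputs[1]` blindly, as the first patch does, would report `0.5` and `4.0` for the last two updates instead of the resulting values.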

As a side note, K.function accepts (variable, new_value) tuples as updates and wraps them with tf.assign on the fly. If you still have issues, try importing this monkeypatch, which makes K.update on TensorFlow behave the same as on Theano. It may have slight performance implications, but everything should work correctly.

https://github.com/bstriner/keras_adversarial/blob/master/keras_adversarial/backend/tensorflow_monkeypatch.py
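The monkeypatch technique can be sketched with a stand-in backend object. Everything here is hypothetical (`backend`, `AssignOp`, `build_function_updates` and `apply_monkeypatch` are illustration-only names, not the contents of tensorflow_monkeypatch.py). It rebinds the `update` attribute so it returns `(variable, new_value)` tuples, and lets the function-building step wrap tuples into assignment ops, as described for K.function:

```python
from dataclasses import dataclass
from types import SimpleNamespace


@dataclass(frozen=True)
class AssignOp:
    """Hypothetical placeholder for a graph assignment op (e.g. what tf.assign builds)."""
    variable: str
    new_value: float


def build_function_updates(updates):
    # Mirrors the described K.function behaviour: tuples are wrapped into
    # assignment ops on the fly, existing ops are passed through unchanged.
    return [AssignOp(*u) if isinstance(u, tuple) else u for u in updates]


# Hypothetical stand-in for the backend module; update() returns an op, TensorFlow style.
backend = SimpleNamespace(update=lambda x, new_x: AssignOp(x, new_x))


def apply_monkeypatch(b):
    """Rebind b.update so it returns (variable, new_value) tuples, as on Theano."""
    b.update = lambda x, new_x: (x, new_x)


print(backend.update("w", 1.0))     # AssignOp(variable='w', new_value=1.0)
apply_monkeypatch(backend)
u = backend.update("w", 1.0)
print(u)                            # ('w', 1.0)
print(build_function_updates([u]))  # [AssignOp(variable='w', new_value=1.0)]
```

Because the patch rebinds an attribute on the module object, callers that look up `K.update` at call time see the new behaviour, while code that imported the old function by name beforehand keeps the original one.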

Cheers, Ben

slackner commented 7 years ago

Thanks a lot for the quick update. Everything seems to work fine now with the TensorFlow backend (tested with the latest commit e661d6fc57ae98a2f31feb7048d2b1b699314a84), so I'm going to close this bug report.

The monkeypatch approach is also a good idea, and I don't think there would be any significant performance difference. A small disadvantage of your current implementation of assign_moving_average is that the keyword option zero_debias is not supported, but that shouldn't really be a problem in most cases. Also, the current implementation of map_params would not be compatible with the monkeypatch, but luckily it's not used anywhere. ;)
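For context on what zero_debias changes, here is a plain-Python sketch of the arithmetic, assuming a scalar state initialized to zero and the usual update `v <- v - (1 - decay) * (v - value)`. With zero-debiasing, the biased estimate is divided by `1 - decay**t` after `t` updates, which removes the pull toward the zero initial value. TensorFlow's implementation keeps extra state to do this; the function below (`moving_average`, a hypothetical name) only reproduces the resulting numbers.

```python
def moving_average(values, decay, zero_debias=False):
    """Exponential moving average of `values`, starting from a zero-initialized state.

    Returns the estimate after each update. With zero_debias=True the biased
    estimate is divided by (1 - decay**t), i.e. Adam-style bias correction.
    """
    biased = 0.0
    out = []
    for t, value in enumerate(values, start=1):
        biased -= (1.0 - decay) * (biased - value)  # == decay*biased + (1-decay)*value
        out.append(biased / (1.0 - decay ** t) if zero_debias else biased)
    return out


data = [5.0, 5.0, 5.0]
print(moving_average(data, decay=0.9))                    # approx [0.5, 0.95, 1.355]
print(moving_average(data, decay=0.9, zero_debias=True))  # approx [5.0, 5.0, 5.0]
```

Without debiasing, a constant input of 5.0 is badly underestimated for the first few steps; with it, the estimate is exact from the first update. With the high momentum typical for batch normalization, that start-up bias fades after enough steps, which is why the missing option rarely matters.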

Best regards, Sebastian