slackner closed this issue 7 years ago
Thanks @slackner! The current version fixes the bug. Please let me know if you have any other issues. I haven't had a chance to test everything on tensorflow, so thanks for finding the bug.
Unpacking updates is much trickier on tensorflow than theano. I updated the original unpack_assignment to support assignment ops, assign_add, tuples, etc. If you compare these two files you'll see what I mean.
https://github.com/bstriner/keras_adversarial/blob/master/keras_adversarial/backend/tensorflow_backend.py
https://github.com/bstriner/keras_adversarial/blob/master/keras_adversarial/backend/theano_backend.py
As a side note, K.function supports tuples and will wrap with tf.assign on the fly. If you still have issues, try importing this monkeypatch which will make K.update for tensorflow work the same as theano. May have slight performance implications but everything should work correctly.
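A rough sketch of what such a monkeypatch could look like, not the actual file from the repository. To keep it runnable without Keras installed, `backend` is a hypothetical stand-in namespace rather than the real `keras.backend`, and the names `tuple_update`, `tuple_update_add` and `apply_monkeypatch` are made up for illustration. The idea is only that the update helpers return theano-style (variable, new_value) tuples instead of assignment ops:

```python
import types

# Hypothetical stand-in for a backend module such as keras.backend.
# The real module is deliberately not imported here.
backend = types.SimpleNamespace(
    update=lambda x, new_x: "assign_op",          # pretend op, as on tensorflow
    update_add=lambda x, increment: "assign_add_op",
)


def tuple_update(x, new_x):
    """Theano-style update: return a (variable, new_value) pair."""
    return (x, new_x)


def tuple_update_add(x, increment):
    """Theano-style additive update: (variable, variable + increment)."""
    return (x, x + increment)


def apply_monkeypatch(module):
    """Replace the module's update helpers with the tuple-returning versions."""
    module.update = tuple_update
    module.update_add = tuple_update_add
    return module


apply_monkeypatch(backend)
print(backend.update(1.0, 3.0))      # now a tuple instead of an op
print(backend.update_add(1.0, 0.5))
```

Since K.function accepts tuples, code that consumes the updates would then see the same shape on both backends, assuming every caller goes through the patched helpers.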
Cheers, Ben
Thanks a lot for the quick update. Everything seems to work fine now when using the tensorflow backend (tested with latest commit e661d6fc57ae98a2f31feb7048d2b1b699314a84) and I'm going to close this bug report.
The monkeypatch method is also a good idea and I don't think there would be any significant performance difference. A small disadvantage of your current implementation of assign_moving_average is that the keyword option zero_debias is not supported, but that shouldn't really be a problem in most cases. Also, the current implementation of map_params would not be compatible with the monkeypatch, but luckily it's not used anywhere. ;)
Best regards, Sebastian
Running example_gan.py from the latest commit f87ace5125efa360a4fa4fc4b13a297ce81e390c leads to the following error when using the tensorflow backend:
The problem has been visible since 67e2c07c528cbadc68a5c9c07a62d82c6fb2fe21, where batch_norm_mode was changed. The code causing the problem was already added earlier, in 1a4cd7015381e12357539a865dcf27cc3a7e1ba0. The newly added function merge_updates assumes that the updates are tuples, but for Tensorflow they are operations. Something like the following change should fix it (similar to how it's done in unrolled_optimizer.py):
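A minimal sketch of that idea, not the original patch: before merging, each update is normalized to a (variable, new_value) pair whether it arrives as a tuple or as an operation. `FakeAssignOp`, `unpack_update` and `normalize_updates` are hypothetical names, and `FakeAssignOp` is a plain-Python stand-in for a Tensorflow assignment op so that the example runs without Tensorflow installed:

```python
class FakeAssignOp:
    """Hypothetical stand-in for a backend assignment operation."""

    def __init__(self, op_type, variable, value):
        self.op_type = op_type    # "Assign", "AssignAdd" or "AssignSub"
        self.variable = variable  # name of the variable being updated
        self.value = value        # assigned value or increment


def unpack_update(update, current_values):
    """Normalize one update to a (variable, new_value) pair."""
    if isinstance(update, (tuple, list)):
        if len(update) != 2:
            raise ValueError("expected a (variable, new_value) pair")
        return tuple(update)
    if isinstance(update, FakeAssignOp):
        if update.op_type == "Assign":
            return update.variable, update.value
        old = current_values[update.variable]
        if update.op_type == "AssignAdd":
            return update.variable, old + update.value
        if update.op_type == "AssignSub":
            return update.variable, old - update.value
    raise TypeError("unsupported update: %r" % (update,))


def normalize_updates(updates, current_values):
    """Convert a mixed list of tuples and ops into a list of pairs."""
    return [unpack_update(u, current_values) for u in updates]


values = {"w": 1.0, "b": 2.0}
mixed = [("w", 5.0), FakeAssignOp("AssignAdd", "b", 0.5)]
print(normalize_updates(mixed, values))
```

Code that expects tuples can then work on the normalized list regardless of which backend produced the updates.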
Tested with tensorflow 0.12.1 and keras 1.2.0. If you need any additional information, please let me know.