JianGoForIt / YellowFin

auto-tuning momentum SGD optimizer
Apache License 2.0

Added basic Keras support; fixes "None" gradient issue #6

Closed: jmhessel closed this 7 years ago

jmhessel commented 7 years ago

This PR...

  1. Adds support for compute_gradients, which Keras requires. To use YFOptimizer in Keras, you can do something like: model.compile(loss='mse', optimizer=TFOptimizer(YFOptimizer()))
  2. Adds some checks for None gradients, which occur in some TensorFlow models.
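
As a rough illustration of the second point (not the PR's actual diff): TensorFlow's compute_gradients returns (gradient, variable) pairs, and the gradient can be None for a variable the loss does not depend on. The sketch below uses plain Python stand-ins (floats for gradients, strings for variables), and the helper name drop_none_gradients is hypothetical.

```python
from typing import List, Optional, Tuple


# Hypothetical helper, not part of YellowFin: keep only the pairs whose
# gradient is defined, mirroring the shape of compute_gradients' output.
def drop_none_gradients(
    grads_and_vars: List[Tuple[Optional[float], str]],
) -> List[Tuple[float, str]]:
    return [(g, v) for g, v in grads_and_vars if g is not None]


# Stand-in data: "frozen_bias" does not influence the loss, so its
# gradient is None and must be skipped before any arithmetic on gradients.
pairs = [(0.5, "dense/kernel"), (None, "frozen_bias"), (-0.1, "dense/bias")]
kept = drop_none_gradients(pairs)
print(kept)
```

In a real optimizer the same filter would run before any norm or moving-average computation over the gradients, since those operations fail on None.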

Some concerns I still have...

  1. The global_step parameter in apply_gradients. Keras passes this named argument to optimizers so they can keep track of the global step count. However, YFOptimizer seems to track the step count itself, and I am not sure how the two should interact for Keras models.
  2. I am not 100% sure that this integration is correct. It runs, and for all of the Keras demos I tried (conv nets, basic fully connected nets, etc.) it decreases the loss reasonably well compared to Adam. However, some warning messages are printed that I do not understand, and there may be a theoretical issue with global_step.

jmhessel commented 7 years ago

FYI, this is the error/warning I get from TensorFlow:

ERROR:tensorflow:==================================
Object was never used (type <class 'tensorflow.python.framework.ops.Operation'>):
<tf.Operation 'update_hyper/cond/assert_equal/Assert/Assert' type=Assert>
If you want to mark it as used call its "mark_used()" method.
It was originally created here:
['File "demo.py", line 76, in <module>\n    validation_data=(x_test, y_test))', 'File "/Users/jmhessel/miniconda3/lib/python3.6/site-packages/keras/models.py", line 870, in fit\n    initial_epoch=initial_epoch)', 'File "/Users/jmhessel/miniconda3/lib/python3.6/site-packages/keras/engine/training.py", line 1490, in fit\n    self._make_train_function()', 'File "/Users/jmhessel/miniconda3/lib/python3.6/site-packages/keras/engine/training.py", line 1014, in _make_train_function\n    self.total_loss)', 'File "/Users/jmhessel/miniconda3/lib/python3.6/site-packages/keras/optimizers.py", line 618, in get_updates\n    grads, global_step=self.iterations)', 'File "../YellowFin/tuner_utils/yellowfin.py", line 222, in apply_gradients\n    update_hyper_op = self.update_hyper_param()', 'File "../YellowFin/tuner_utils/yellowfin.py", line 190, in update_hyper_param\n    lambda: self._mu_var) )', 'File "/Users/jmhessel/miniconda3/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py", line 289, in new_func\n    return func(*args, **kwargs)', 'File "/Users/jmhessel/miniconda3/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 1814, in cond\n    orig_res_t, res_t = context_t.BuildCondBranch(true_fn)', 'File "/Users/jmhessel/miniconda3/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 1689, in BuildCondBranch\n    original_result = fn()', 'File "../YellowFin/tuner_utils/yellowfin.py", line 189, in <lambda>\n    self._mu = tf.identity(tf.cond(self._do_tune, lambda: self.get_mu_tensor(),', 'File "../YellowFin/tuner_utils/yellowfin.py", line 180, in get_mu_tensor\n    tf.assert_equal(tf.size(root), tf.constant(1) )', 'File "/Users/jmhessel/miniconda3/lib/python3.6/site-packages/tensorflow/python/ops/check_ops.py", line 318, in assert_equal\n    return control_flow_ops.Assert(condition, data, summarize=summarize)', 'File "/Users/jmhessel/miniconda3/lib/python3.6/site-packages/tensorflow/python/util/tf_should_use.py", line 170, in wrapped\n    return _add_should_use_warning(fn(*args, **kwargs))', 'File "/Users/jmhessel/miniconda3/lib/python3.6/site-packages/tensorflow/python/util/tf_should_use.py", line 139, in _add_should_use_warning\n    wrapped = TFShouldUseWarningWrapper(x)', 'File "/Users/jmhessel/miniconda3/lib/python3.6/site-packages/tensorflow/python/util/tf_should_use.py", line 96, in __init__\n    stack = [s.strip() for s in traceback.format_stack()]']
==================================
JDvorak commented 7 years ago

In the case of a GAN, where two YFOptimizers are dueling and the discriminator's loss increases linearly from 0 over time, I consistently get an error along these lines within a few epochs:

Traceback (most recent call last):
  File "C:\Miniconda3\lib\site-packages\tensorflow\python\ops\script_ops.py", line 82, in __call__
    ret = func(*args)
  File "C:\Miniconda3\lib\site-packages\numpy\lib\polynomial.py", line 229, in roots
    roots = eigvals(A)
  File "C:\Miniconda3\lib\site-packages\numpy\linalg\linalg.py", line 903, in eigvals
    _assertFinite(a)
  File "C:\Miniconda3\lib\site-packages\numpy\linalg\linalg.py", line 217, in _assertFinite
    raise LinAlgError("Array must not contain infs or NaNs")
numpy.linalg.linalg.LinAlgError: Array must not contain infs or NaNs
2017-07-03 12:43:43.093655: W c:\tf_jenkins\home\workspace\release-win\device\gpu\os\windows\tensorflow\core\framework\op_kernel.cc:1152] Internal: Failed to run py callback pyfunc_1: see error log.
Traceback (most recent call last):
  File "C:\Miniconda3\lib\site-packages\tensorflow\python\client\session.py", line 1039, in _do_call
    return fn(*args)
  File "C:\Miniconda3\lib\site-packages\tensorflow\python\client\session.py", line 1021, in _run_fn
    status, run_metadata)
  File "C:\Miniconda3\lib\contextlib.py", line 66, in __exit__
    next(self.gen)
  File "C:\Miniconda3\lib\site-packages\tensorflow\python\framework\errors_impl.py", line 466, in raise_exception_on_not_ok_status
    pywrap_tensorflow.TF_GetCode(status))
tensorflow.python.framework.errors_impl.InternalError: Failed to run py callback pyfunc_1: see error log.
         [[Node: update_hyper_1/cond/PyFuncStateless = PyFuncStateless[Tin=[DT_FLOAT], Tout=[DT_COMPLEX64], token="pyfunc_1", _device="/job:localhost/replica:0/task:0/cpu:0"](update_hyper_1/cond/ScatterUpdate/_875)]]
         [[Node: update_hyper_1/cond/Gather_1/_883 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/gpu:0", send_device="/job:localhost/replica:0/task:0/cpu:0", send_device_incarnation=1, tensor_name="edge_1626_update_hyper_1/cond/Gather_1", tensor_type=DT_COMPLEX64, _device="/job:localhost/replica:0/task:0/gpu:0"]()]]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "main.py", line 57, in <module>
    main()
  File "main.py", line 52, in main
    model.train(args.flag)
  File "xxx\src\operator\op_EGAN.py", line 286, in train
    _, loss_q, summary = self.sess.run(disc_class_q_opts, feed_dict=feed_dict)
  File "C:\Miniconda3\lib\site-packages\tensorflow\python\client\session.py", line 778, in run
    run_metadata_ptr)
  File "C:\Miniconda3\lib\site-packages\tensorflow\python\client\session.py", line 982, in _run
    feed_dict_string, options, run_metadata)
  File "C:\Miniconda3\lib\site-packages\tensorflow\python\client\session.py", line 1032, in _do_run
    target_list, options, run_metadata)
  File "C:\Miniconda3\lib\site-packages\tensorflow\python\client\session.py", line 1052, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InternalError: Failed to run py callback pyfunc_1: see error log.
         [[Node: update_hyper_1/cond/PyFuncStateless = PyFuncStateless[Tin=[DT_FLOAT], Tout=[DT_COMPLEX64], token="pyfunc_1", _device="/job:localhost/replica:0/task:0/cpu:0"](update_hyper_1/cond/ScatterUpdate/_875)]]
         [[Node: update_hyper_1/cond/Gather_1/_883 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/gpu:0", send_device="/job:localhost/replica:0/task:0/cpu:0", send_device_incarnation=1, tensor_name="edge_1626_update_hyper_1/cond/Gather_1", tensor_type=DT_COMPLEX64, _device="/job:localhost/replica:0/task:0/gpu:0"]()]]

Caused by op 'update_hyper_1/cond/PyFuncStateless', defined at:
  File "main.py", line 57, in <module>
    main()
  File "main.py", line 48, in main
    model = egan.EGAN(args, sess)
  File "xxx\src\models\EGAN.py", line 9, in __init__
    Operator.__init__(self, args, sess)
  File "xxx\src\operator\op_EGAN.py", line 16, in __init__
    self.build_model()
  File "xxx\src\operator\op_EGAN.py", line 167, in build_model
    self.opt_q =  YFOptimizer().minimize(self.vae_discriminator_loss, var_list=q_vars)
  File "xxx\src\yellowfin.py", line 267, in minimize
    return self.apply_gradients(grads_and_vars)
  File "xxx\src\yellowfin.py", line 222, in apply_gradients
    update_hyper_op = self.update_hyper_param()
  File "xxx\src\yellowfin.py", line 190, in update_hyper_param
    lambda: self._mu_var) )
  File "C:\Miniconda3\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 1741, in cond
    orig_res, res_t = context_t.BuildCondBranch(fn1)
  File "C:\Miniconda3\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 1642, in BuildCondBranch
    r = fn()
  File "xxx\src\yellowfin.py", line 189, in <lambda>
    self._mu = tf.identity(tf.cond(self._do_tune, lambda: self.get_mu_tensor(),
  File "xxx\src\yellowfin.py", line 173, in get_mu_tensor
    roots = tf.py_func(np.roots, [coef], Tout=tf.complex64, stateful=False)
  File "C:\Miniconda3\lib\site-packages\tensorflow\python\ops\script_ops.py", line 193, in py_func
    input=inp, token=token, Tout=Tout, name=name)
  File "C:\Miniconda3\lib\site-packages\tensorflow\python\ops\gen_script_ops.py", line 60, in _py_func_stateless
    Tout=Tout, name=name)
  File "C:\Miniconda3\lib\site-packages\tensorflow\python\framework\op_def_library.py", line 768, in apply_op
    op_def=op_def)
  File "C:\Miniconda3\lib\site-packages\tensorflow\python\framework\ops.py", line 2336, in create_op
    original_op=self._default_original_op, op_def=op_def)
  File "C:\Miniconda3\lib\site-packages\tensorflow\python\framework\ops.py", line 1228, in __init__
    self._traceback = _extract_stack()

InternalError (see above for traceback): Failed to run py callback pyfunc_1: see error log.
         [[Node: update_hyper_1/cond/PyFuncStateless = PyFuncStateless[Tin=[DT_FLOAT], Tout=[DT_COMPLEX64], token="pyfunc_1", _device="/job:localhost/replica:0/task:0/cpu:0"](update_hyper_1/cond/ScatterUpdate/_875)]]
         [[Node: update_hyper_1/cond/Gather_1/_883 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/gpu:0", send_device="/job:localhost/replica:0/task:0/cpu:0", send_device_incarnation=1, tensor_name="edge_1626_update_hyper_1/cond/Gather_1", tensor_type=DT_COMPLEX64, _device="/job:localhost/replica:0/task:0/gpu:0"]()]]

PS xxx>

I am not sure if it is this PR or another, but I can only run YFOptimizer with this PR's code.

jmhessel commented 7 years ago

@JDvorak, quick clarification question: does this happen with both this PR and the master branch, or just with this PR?

JianGoForIt commented 7 years ago

@jmhessel Thanks for the PR.

@JDvorak, from my side, that looks most likely to be an exploding-gradient or zero-gradient issue.

Could you do the following:

  1. Observe the typical magnitude of the gradient, and the magnitude that triggers the error.

  2. Use the clip_thresh argument to set up gradient clipping for YFOptimizer.
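
Both steps can be sketched outside TensorFlow. The snippet below is a plain-Python illustration, assuming gradients are nested lists of floats and that clipping follows the common clip-by-global-norm rule. I have not confirmed how YFOptimizer applies clip_thresh internally, and the function names here are hypothetical.

```python
import math
from typing import List, Sequence, Tuple


def global_norm(grads: Sequence[Sequence[float]]) -> float:
    """L2 norm over every gradient entry, treated as one long vector."""
    return math.sqrt(sum(g * g for grad in grads for g in grad))


def clip_by_global_norm(
    grads: Sequence[Sequence[float]], clip_thresh: float
) -> Tuple[List[List[float]], float]:
    """Rescale all gradients so their global norm is at most clip_thresh.

    Returns the (possibly rescaled) gradients and the norm before clipping,
    so the norm can be logged at every step (step 1 above).
    """
    norm = global_norm(grads)
    if norm <= clip_thresh:
        return [list(grad) for grad in grads], norm
    # Note: if norm is inf or nan, the scaled values are not finite either;
    # clipping bounds large gradients but cannot repair non-finite ones.
    scale = clip_thresh / norm
    return [[g * scale for g in grad] for grad in grads], norm


grads = [[3.0, 4.0], [0.0]]
clipped, norm_before = clip_by_global_norm(grads, clip_thresh=1.0)
print("norm before clipping:", norm_before)
print("norm after clipping:", global_norm(clipped))
```

Logging norm_before on every step shows both the typical magnitude and the value reached just before the failure.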

If you can catch the exception and redo that specific iteration, that should also help solve the problem. I am not sure what the proper way to handle exceptions in TensorFlow is :).
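
One possible reading of the catch-and-redo suggestion, sketched in plain Python. In a TensorFlow 1.x session the failing call would be sess.run and the exception would be the InternalError shown in the traceback above; here run_step and the exception type are stand-ins, and run_with_retry is a hypothetical helper. A retry only helps if the failure is transient (for example, the redone iteration sees a different minibatch), which is an assumption. If the optimizer's internal state already contains NaN, retrying will not recover it.

```python
from typing import Callable, Optional, Tuple, Type, TypeVar

T = TypeVar("T")


def run_with_retry(
    run_step: Callable[[], T],
    retry_on: Tuple[Type[BaseException], ...],
    max_attempts: int = 3,
) -> T:
    """Call run_step, retrying up to max_attempts times on the given errors."""
    last_error: Optional[BaseException] = None
    for attempt in range(1, max_attempts + 1):
        try:
            return run_step()
        except retry_on as err:
            last_error = err
            print(f"step failed on attempt {attempt}: {err}")
    raise RuntimeError(f"step failed {max_attempts} times") from last_error


# Stand-in training step that fails twice, then succeeds.
calls = {"n": 0}


def flaky_step() -> float:
    calls["n"] += 1
    if calls["n"] < 3:
        raise ArithmeticError("Array must not contain infs or NaNs")
    return 0.25


loss = run_with_retry(flaky_step, retry_on=(ArithmeticError,))
print("loss:", loss)
```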

jinxin0924 commented 7 years ago

Hi @jmhessel, I tried ResNet-18 on CIFAR-10 in Keras using TFOptimizer(YFOptimizer()), but got much worse results than with Adam. Did you compare YellowFin with other optimizers in Keras?

JianGoForIt commented 7 years ago

Hi @jinxin0924,

In the README, we actually recommend YFOptimizer(learning_rate=1.0, momentum=0.0), but I am not sure how you should use it in Keras.

Cheers,

jinxin0924 commented 7 years ago

@JianGoForIt I used YellowFin just as jmhessel described, model.compile(loss='mse', optimizer=TFOptimizer(YFOptimizer())), and added the compute_gradients function to yellowfin.py.

JDvorak commented 7 years ago

@JianGoForIt When I am back from vacation, I'll get you those answers. Otherwise, great work! I'm sure I am not alone in wishing to be rid of worrying about the learning rate hyperparameter.

jmhessel commented 7 years ago

@jinxin0924 @JianGoForIt I don't know for sure whether running with the recommended parameters will solve all of the performance issues, but it's quite easy to do in Keras.

model.compile(loss='mse', optimizer=TFOptimizer(YFOptimizer(learning_rate=1.0, momentum=0.0)))

should do the trick. However, why not just make those the default parameters?