JianGoForIt / YellowFin

auto-tuning momentum SGD optimizer
Apache License 2.0

Monitor gradient global norm #20

Open mfernezir opened 7 years ago

mfernezir commented 7 years ago

EDIT:

@JianGoForIt
I've just updated this pull request to sync with the current master.

This adds functionality to monitor the global gradient norm, both clipped and original. It works for both adaptive and manual clipping. I find it useful to observe gradient norm charts in TensorBoard, along with other metrics.

To record it with TF summaries, you can just use

tf.summary.scalar("gradient_norm", optimizer.grad_norm_monitor)
tf.summary.scalar("clipped_gradient_norm", optimizer.clipped_grad_norm_monitor)

where optimizer = YFOptimizer(...) is the initialized optimizer.
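For context, here is roughly how this fits into a full training setup. This is a minimal TF 1.x sketch, not code from the PR: the `loss` tensor, the log directory, and the assumption that YFOptimizer exposes a minimize() method like the stock tf.train optimizers are all mine.

```python
import tensorflow as tf
# from tuner_utils.yellowfin import YFOptimizer  # import path depends on your checkout

# Assumes `loss` is your model's scalar loss tensor.
optimizer = YFOptimizer(clip_thresh=10.0)
train_op = optimizer.minimize(loss)

# The two monitors added by this PR, attached to the summary stream.
tf.summary.scalar("gradient_norm", optimizer.grad_norm_monitor)
tf.summary.scalar("clipped_gradient_norm", optimizer.clipped_grad_norm_monitor)
merged_summaries = tf.summary.merge_all()

with tf.Session() as sess:
    writer = tf.summary.FileWriter("./logs", sess.graph)
    sess.run(tf.global_variables_initializer())
    for step in range(1000):
        _, summary = sess.run([train_op, merged_summaries])
        writer.add_summary(summary, step)  # charts then show up in TensorBoard
```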

OLD COMMENTS BELOW, NVM

Context

Sometimes it is useful to apply gradient clipping regulated by clip_thresh, which is passed to tf.clip_by_global_norm as the clip_norm parameter. If tf.global_norm(gradients) > clip_norm, each gradient is scaled by the factor clip_norm / tf.global_norm(gradients). https://www.tensorflow.org/api_docs/python/tf/clip_by_global_norm
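As a quick illustration of that scaling (a standalone TF 1.x snippet, not part of this PR):

```python
import tensorflow as tf

# Two toy "gradients" with global norm sqrt(3^2 + 4^2) = 5.0.
grads = [tf.constant([3.0]), tf.constant([4.0])]
clipped, global_norm = tf.clip_by_global_norm(grads, clip_norm=2.5)

with tf.Session() as sess:
    print(sess.run(global_norm))  # 5.0 -- the norm *before* clipping
    print(sess.run(clipped))      # [1.5] and [2.0] -- scaled by 2.5 / 5.0
```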

Problem and proposed changes

It may not be obvious what value of clip_thresh is appropriate to prevent NaN values. To help with that, I have added code to monitor the current global gradient norm. Specifically, if clip_thresh is not None, gradient clipping is possible and the computed self._grads_norm is exposed as self.grad_global_norm_monitor.
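The mechanism is roughly the following. This is an illustrative sketch of how such a monitor can be wired up, not the PR's exact code, and clip_and_monitor is a made-up helper name:

```python
import tensorflow as tf

def clip_and_monitor(grads, clip_thresh):
    # Compute the global norm once, reuse it for clipping via use_norm,
    # and return it so it can be exposed as grad_global_norm_monitor.
    grads_norm = tf.global_norm(grads)
    if clip_thresh is not None:
        grads, _ = tf.clip_by_global_norm(grads, clip_thresh, use_norm=grads_norm)
    return grads, grads_norm
```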

I'll post some screenshots to show that everything works okay. I have an ongoing training run comparing this branch with the current master. I am tracking internal optimizer values with this setup:

optimizer = YFOptimizer(learning_rate=0.05, momentum=0.0, clip_thresh=10)
tf.summary.scalar("learning_rate", optimizer._lr_var)
tf.summary.scalar("momentum", optimizer._mu_var)
tf.summary.scalar("global_gradient_norm", optimizer.grad_global_norm_monitor)

mfernezir commented 7 years ago

As promised, here are some screenshots showing that everything works okay. This commit just adds global norm monitoring. Like before, I am using Python 3 and TF 1.2r, with two GPUs running the current master and two running this branch.

[Screenshots: accuracy, gradient_norm_lr_mu, total_loss]

mfernezir commented 6 years ago

I've updated this PR to include recent changes to the master branch.