Open chrisrn opened 6 years ago
This is a good idea! I think the right approach would be to have one TensorFlow local variable per GPU per variable. Each step, each gradient on a GPU would be added to the corresponding local variable. Every 10 steps, the local variables would be aggregated and reset back to 0.
It'd be easiest to start by only supporting non-distributed mode. One way to implement this is to only run this chunk of code once every 10 steps. For the other 9 steps, you would just accumulate device_grads into the local variables, with one local variable per gradient. When writing to local variables, use `tf.colocate_with` or `tf.device` to ensure the local variable is on the same device as the gradient. On the 10th step, you would read from the local variable instead of using device_grads. I'm pretty sure this approach would work, but I may have missed some details.
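The accumulate-then-aggregate scheme described above can be sketched without any framework. This is only an illustration under stated assumptions: plain floats stand in for gradient tensors and local variables, and `GradAccumulator` and `aggregate_across_devices` are hypothetical names, not part of the benchmark code or of TensorFlow.

```python
# Framework-free sketch: plain floats stand in for gradient tensors and
# local variables. All names here are hypothetical.

ACCUM_STEPS = 10


class GradAccumulator:
    """One accumulator slot per (device, variable) pair."""

    def __init__(self, num_devices, num_vars):
        self.slots = [[0.0] * num_vars for _ in range(num_devices)]
        self.count = 0

    def accumulate(self, device_grads):
        # device_grads[d][v] is the gradient of variable v computed on device d.
        for d, grads in enumerate(device_grads):
            for v, g in enumerate(grads):
                self.slots[d][v] += g
        self.count += 1

    def ready(self):
        # True once ACCUM_STEPS gradients have been accumulated.
        return self.count == ACCUM_STEPS

    def read_and_reset(self):
        # Hand back the accumulated sums, then zero the slots.
        out = [list(row) for row in self.slots]
        for row in self.slots:
            for v in range(len(row)):
                row[v] = 0.0
        self.count = 0
        return out


def aggregate_across_devices(per_device):
    # Sum each variable's accumulated gradient over all devices.
    return [sum(col) for col in zip(*per_device)]


acc = GradAccumulator(num_devices=2, num_vars=3)
applied = []
for step in range(1, 21):
    device_grads = [[1.0, 2.0, 3.0], [0.5, 0.5, 0.5]]
    acc.accumulate(device_grads)
    if acc.ready():
        applied.append(aggregate_across_devices(acc.read_and_reset()))
```

Here the accumulated values are summed; dividing by `ACCUM_STEPS` before applying them would give an average instead, which is a design choice the thread does not settle.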
An alternative approach would be to implement the functionality in VariableMgrLocalReplicated. In preprocess_device_grads, only return a non-empty list of devices every 10 steps, and in the other 9 steps, just accumulate the gradients into local variables.
Thanks for the fast response! I guess you need `tf.cond` to do that, because you need to accumulate gradients inside the tf graph, right?
A `tf.cond` with `global_step % 10 == 0` as the condition would work. Alternatively, you could have two fetch ops, X and Y. Op X would apply the gradients from the local variables, and you would run the op once every 10 steps. Op Y would accumulate the gradients in the local variables, and you would run it the other 9 steps (or perhaps every step, so you would run X and Y every 10 steps).
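To make the two-op schedule concrete, here is a rough framework-free sketch of the variant where Y runs every step and X runs on every 10th step. It assumes a single scalar weight and a constant gradient; `make_ops`, `train`, and the `state` dictionary are hypothetical stand-ins for real fetch ops and variables, and op X also resets the accumulator, as suggested earlier in the thread.

```python
# Hypothetical stand-ins for the two fetch ops; a real graph would use
# TensorFlow ops instead of Python functions.

ACCUM_STEPS = 10


def make_ops(state):
    def op_y(grad):
        # Op Y: add the current gradient to the local accumulator.
        state["accum"] += grad

    def op_x():
        # Op X: apply the accumulated gradient, then reset the accumulator.
        state["weight"] -= state["lr"] * state["accum"]
        state["accum"] = 0.0
        state["applies"] += 1

    return op_x, op_y


def train(num_steps, grad=1.0):
    state = {"weight": 0.0, "accum": 0.0, "lr": 0.5, "applies": 0}
    op_x, op_y = make_ops(state)
    for step in range(1, num_steps + 1):
        op_y(grad)                    # run Y every step
        if step % ACCUM_STEPS == 0:   # run X once every 10 steps
            op_x()
    return state


result = train(25)
```

Running Y before X on the 10th step matters in this sketch: otherwise the 10th gradient would leak into the next accumulation window.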
Is there a straightforward way to achieve this in TensorFlow 2?
A really good improvement would be to integrate the above functionality into variable_mgr. For example, you could use a batch size of 64, aggregate the gradients over 10 steps, and then apply them, so the effective batch size would be 640. This way, you can train with a larger batch size without using more memory. I would like to implement it. Any tips are appreciated!
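The batch-size arithmetic can be checked with a small sketch. It assumes a toy one-parameter linear model with a mean-squared-error loss and equally sized micro-batches (none of which comes from the benchmark code); under those assumptions, averaging 10 gradients from batches of 64 matches one gradient from a batch of 640 up to floating-point rounding.

```python
import math
import random

MICRO_BATCH = 64
ACCUM_STEPS = 10


def mse_grad(w, batch):
    # Gradient of mean((w*x - y)^2) with respect to w.
    return sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)


random.seed(0)
data = [(random.uniform(-1, 1), random.uniform(-1, 1))
        for _ in range(MICRO_BATCH * ACCUM_STEPS)]
w = 0.3

# One gradient over the full 640-example batch.
full_grad = mse_grad(w, data)

# Ten gradients over 64-example micro-batches, accumulated then averaged.
accum = 0.0
for i in range(ACCUM_STEPS):
    micro = data[i * MICRO_BATCH:(i + 1) * MICRO_BATCH]
    accum += mse_grad(w, micro)
accum_grad = accum / ACCUM_STEPS

effective_batch = MICRO_BATCH * ACCUM_STEPS
```

The equivalence relies on the loss being a mean over examples; layers whose output depends on the whole batch, such as batch normalization, would still see only 64 examples at a time.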