Closed wx405557858 closed 3 years ago
I think it's solved by this https://github.com/BVLC/caffe/pull/1663. It will accumulate the gradient for the whole large batch and update. @fchollet If I implement this part with keras source code, which file should I change, thanks a lot!
I solved this by change the optimizer.py.
@wx405557858 I'm curious how you did this. I hacked something together that seems to work, but I'd be interested in a better way. Also it might be useful to have Keras. Here is how I did it below.
Basically accum_switch
turns to 1 every set number of epochs and the updates either update with the old value or the new: self.updates.append(K.update(m, (1-accum_switch)*m + accum_switch*m_t))
This avoids any logic for the backend to deal with at the expense of some unnecessary calculations (g_prime, etc) that are discarded between actual updates.
class NadamAccum(Optimizer):
'''
Nesterov Adam optimizer: Much like Adam is essentially RMSprop with momentum,
Nadam is Adam RMSprop with Nesterov momentum.
Default parameters follow those provided in the paper.
It is recommended to leave the parameters of this optimizer
at their default values.
# Arguments
lr: float >= 0. Learning rate.
beta_1/beta_2: floats, 0 < beta < 1. Generally close to 1.
epsilon: float >= 0. Fuzz factor.
# References
- [Nadam report](http://cs229.stanford.edu/proj2015/054_report.pdf)
- [On the importance of initialization and momentum in deep learning](http://www.cs.toronto.edu/~fritz/absps/momentum.pdf)
'''
def __init__(self, lr=0.002, beta_1=0.9, beta_2=0.999,
epsilon=1e-8, schedule_decay=0.004, accum_iters=1, **kwargs):
super(NadamAccum, self).__init__(**kwargs)
self.__dict__.update(locals())
self.iterations = K.variable(0.)
self.m_schedule = K.variable(1.)
self.lr = K.variable(lr)
self.beta_1 = K.variable(beta_1)
self.beta_2 = K.variable(beta_2)
self.schedule_decay = schedule_decay
self.accum_iters = K.variable(accum_iters)
def get_updates(self, params, constraints, loss):
grads = self.get_gradients(loss, params)
self.updates = [K.update_add(self.iterations, 1)]
t = (self.iterations + 1.)/self.accum_iters
accum_switch = K.floor((self.accum_iters - K.mod(self.iterations + 1., self.accum_iters))/self.accum_iters)
# Due to the recommendations in [2], i.e. warming momentum schedule
momentum_cache_t = self.beta_1 * (1. - 0.5 * (K.pow(0.96, t * self.schedule_decay)))
momentum_cache_t_1 = self.beta_1 * (1. - 0.5 * (K.pow(0.96, (t + 1) * self.schedule_decay)))
m_schedule_new = self.m_schedule * momentum_cache_t
m_schedule_next = self.m_schedule * momentum_cache_t * momentum_cache_t_1
self.updates.append((self.m_schedule, accum_switch*m_schedule_new + (1-accum_switch)*self.m_schedule))
shapes = [x.shape for x in K.batch_get_value(params)]
ms = [K.zeros(shape) for shape in shapes]
vs = [K.zeros(shape) for shape in shapes]
gs = [K.zeros(shape) for shape in shapes]
self.weights = [self.iterations] + ms + vs
for p, gp, m, v, ga in zip(params, grads, ms, vs, gs):
g = (ga + gp)/self.accum_iters
# the following equations given in [1]
g_prime = g / (1. - m_schedule_new)
m_t = self.beta_1 * m + (1. - self.beta_1) * g
m_t_prime = m_t / (1. - m_schedule_next)
v_t = self.beta_2 * v + (1. - self.beta_2) * K.square(g)
v_t_prime = v_t / (1. - K.pow(self.beta_2, t))
m_t_bar = (1. - momentum_cache_t) * g_prime + momentum_cache_t_1 * m_t_prime
self.updates.append(K.update(m, (1-accum_switch)*m + accum_switch*m_t))
self.updates.append(K.update(v, (1-accum_switch)*v + accum_switch*v_t))
self.updates.append(K.update(ga, (1-accum_switch)*(ga + gp)))
p_t = p - self.lr * m_t_bar / (K.sqrt(v_t_prime) + self.epsilon)
new_p = p_t
# apply constraints
if p in constraints:
c = constraints[p]
new_p = c(new_p)
self.updates.append(K.update(p, (1-accum_switch)*p + accum_switch*new_p))
return self.updates
def get_config(self):
config = {'lr': float(K.get_value(self.lr)),
'beta_1': float(K.get_value(self.beta_1)),
'beta_2': float(K.get_value(self.beta_2)),
'epsilon': self.epsilon,
'schedule_decay': self.schedule_decay,
'accum_iters': self.accum_iters}
base_config = super(NadamAccum, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
@the-moliver Yeah, we did exactly the same! I have a flag calculated by (self.iteration % accum_iters) == 0 . It will turn into 1 after accum_iters batches. I think maybe can write a wrapper to wrap every optimizer and change the updates base on accum_iters. Or just implement each optimizer's _accum version. There's only several optimizers.
class Adam_accumulate(Optimizer):
'''Adam accumulate optimizer.
Default parameters follow those provided in the original paper. Wait for several mini-batch to update
# Arguments
lr: float >= 0. Learning rate.
beta_1/beta_2: floats, 0 < beta < 1. Generally close to 1.
epsilon: float >= 0. Fuzz factor.
# References
- [Adam - A Method for Stochastic Optimization](http://arxiv.org/abs/1412.6980v8)
'''
def __init__(self, lr=0.001, beta_1=0.9, beta_2=0.999,
epsilon=1e-8, accum_iters=5, **kwargs):
super(Adam_accumulate, self).__init__(**kwargs)
self.__dict__.update(locals())
self.iterations = K.variable(0)
self.lr = K.variable(lr)
self.beta_1 = K.variable(beta_1)
self.beta_2 = K.variable(beta_2)
self.accum_iters = K.variable(accum_iters)
def get_updates(self, params, constraints, loss):
grads = self.get_gradients(loss, params)
self.updates = [(self.iterations, self.iterations + 1)]
t = self.iterations + 1
print t.eval()
lr_t = self.lr * K.sqrt(1. - K.pow(self.beta_2, t)) / (1. - K.pow(self.beta_1, t))
ms = [K.variable(np.zeros(K.get_value(p).shape)) for p in params]
vs = [K.variable(np.zeros(K.get_value(p).shape)) for p in params]
gs = [K.variable(np.zeros(K.get_value(p).shape)) for p in params]
self.weights = ms + vs
for p, g, m, v, gg in zip(params, grads, ms, vs, gs):
flag = K.equal(self.iterations % self.accum_iters, 0)
gg_t = (1 - flag) * (gg + g)
m_t = (self.beta_1 * m) + (1. - self.beta_1) * (gg + flag * g) / self.accum_iters
v_t = (self.beta_2 * v) + (1. - self.beta_2) * K.square((gg + flag * g) / self.accum_iters)
p_t = p - flag * lr_t * m_t / (K.sqrt(v_t) + self.epsilon)
self.updates.append((m, flag * m_t + (1 - flag) * m))
self.updates.append((v, flag * v_t + (1 - flag) * m))
self.updates.append((gg, gg_t))
new_p = p_t
# apply constraints
if p in constraints:
c = constraints[p]
new_p = c(new_p)
self.updates.append((p, new_p))
# print self.updates
return self.updates
def get_config(self):
config = {'lr': float(K.get_value(self.lr)),
'beta_1': float(K.get_value(self.beta_1)),
'beta_2': float(K.get_value(self.beta_2)),
'epsilon': self.epsilon}
base_config = super(Adam_accumulate, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
@wx405557858 I tried using your code, the loss seems to explode.
@raghakot It works for my model. I assume it should be universal. Would the loss converge with normal Adam optimizer in your case?
Yes. It converges with regular Adam. @the-moliver version seems to work too.
I have to make a tiny change to your code to cast the flag to float32 (fails to run otherwise due to dtype mismatch on arithmetic operations with flag
). This is on bleeding edge keras...maybe something changed? Also, I am using tensorflow backend, if that matters.
@raghakot Thanks for your pointing out. I'm not quite sure what's the exact problem. But it's nice to know the-moliver's solution works for you.
Set flag = K.cast(flag, dtype='float32') and it works. Thanks wx405557858
Thank you for your sharing. I am new here, but I have several trouble at first. What is the relation between accum_iters
and the final batch_size? @wx405557858 @the-moliver
final batch_size = accum_iters * original batch_size
Hi, @wx405557858 ,could you please show your optimizers.py
? I changed the file just like you did, but ValueError: ('Could not interpret optimizer identifier:', <AdamAccum.AdamAccum object at 0x0000015F1E58DB00>)
@soon-will the optimizers.py
is from keras. see here.
@wx405557858 what you had: self.updates.append((v, flag v_t + (1 - flag) m))
shouldn't m be v? self.updates.append((v, flag v_t + (1 - flag) v))
@the-moliver, I am getting an error K.floor doesnt exist on this line:
accum_switch = K.floor((self.accum_iters - K.mod(self.iterations + 1., self.accum_iters))/self.accum_iters)
Was K.floor and K.mod recently removed from Keras backend? Cant find them here: https://github.com/fchollet/keras/tree/master/keras/backend
@jackkwok On my case without the fix that you suggest I get Nan on loss and metrics. Using the fix it works.
This feature extremely useful and must be added in official repository.
Code by @wx405557858 with fixes. I checked it in my project and it seemed to work fine:
from keras.optimizers import Optimizer
from keras import backend as K
import numpy as np
class Adam_accumulate(Optimizer):
'''Adam accumulate optimizer.
Default parameters follow those provided in the original paper. Wait for several mini-batch to update
# Arguments
lr: float >= 0. Learning rate.
beta_1/beta_2: floats, 0 < beta < 1. Generally close to 1.
epsilon: float >= 0. Fuzz factor.
# References
- [Adam - A Method for Stochastic Optimization](http://arxiv.org/abs/1412.6980v8)
'''
def __init__(self, lr=0.001, beta_1=0.9, beta_2=0.999,
epsilon=1e-8, accum_iters=10, **kwargs):
super(Adam_accumulate, self).__init__(**kwargs)
self.__dict__.update(locals())
self.iterations = K.variable(0)
self.lr = K.variable(lr)
self.beta_1 = K.variable(beta_1)
self.beta_2 = K.variable(beta_2)
self.accum_iters = K.variable(accum_iters)
def get_updates(self, params, constraints, loss):
grads = self.get_gradients(loss, params)
self.updates = [(self.iterations, self.iterations + 1)]
t = self.iterations + 1
lr_t = self.lr * K.sqrt(1. - K.pow(self.beta_2, t)) / (1. - K.pow(self.beta_1, t))
ms = [K.variable(np.zeros(K.get_value(p).shape)) for p in params]
vs = [K.variable(np.zeros(K.get_value(p).shape)) for p in params]
gs = [K.variable(np.zeros(K.get_value(p).shape)) for p in params]
self.weights = ms + vs
for p, g, m, v, gg in zip(params, grads, ms, vs, gs):
flag = K.equal(self.iterations % self.accum_iters, 0)
flag = K.cast(flag, dtype='float32')
gg_t = (1 - flag) * (gg + g)
m_t = (self.beta_1 * m) + (1. - self.beta_1) * (gg + flag * g) / self.accum_iters
v_t = (self.beta_2 * v) + (1. - self.beta_2) * K.square((gg + flag * g) / self.accum_iters)
p_t = p - flag * lr_t * m_t / (K.sqrt(v_t) + self.epsilon)
self.updates.append((m, flag * m_t + (1 - flag) * m))
self.updates.append((v, flag * v_t + (1 - flag) * v))
self.updates.append((gg, gg_t))
new_p = p_t
# apply constraints
if p in constraints:
c = constraints[p]
new_p = c(new_p)
self.updates.append((p, new_p))
return self.updates
def get_config(self):
config = {'lr': float(K.get_value(self.lr)),
'beta_1': float(K.get_value(self.beta_1)),
'beta_2': float(K.get_value(self.beta_2)),
'epsilon': self.epsilon}
base_config = super(Adam_accumulate, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
Thanks @ZFTurbo for the fixes.
This is version of code for Keras 2.0.8 with fixed constraints issue and get_updates parameters.
from keras.optimizers import Optimizer
from keras import backend as K
import numpy as np
class Adam_accumulate(Optimizer):
def __init__(self, lr=0.001, beta_1=0.9, beta_2=0.999,
epsilon=1e-8, accum_iters=20, **kwargs):
super(Adam_accumulate, self).__init__(**kwargs)
self.__dict__.update(locals())
self.iterations = K.variable(0)
self.lr = K.variable(lr)
self.beta_1 = K.variable(beta_1)
self.beta_2 = K.variable(beta_2)
self.accum_iters = K.variable(accum_iters)
def get_updates(self, loss, params):
grads = self.get_gradients(loss, params)
self.updates = [(self.iterations, self.iterations + 1)]
t = self.iterations + 1
lr_t = self.lr * K.sqrt(1. - K.pow(self.beta_2, t)) / (1. - K.pow(self.beta_1, t))
ms = [K.variable(np.zeros(K.get_value(p).shape)) for p in params]
vs = [K.variable(np.zeros(K.get_value(p).shape)) for p in params]
gs = [K.variable(np.zeros(K.get_value(p).shape)) for p in params]
self.weights = ms + vs
for p, g, m, v, gg in zip(params, grads, ms, vs, gs):
flag = K.equal(self.iterations % self.accum_iters, 0)
flag = K.cast(flag, dtype='float32')
gg_t = (1 - flag) * (gg + g)
m_t = (self.beta_1 * m) + (1. - self.beta_1) * (gg + flag * g) / self.accum_iters
v_t = (self.beta_2 * v) + (1. - self.beta_2) * K.square((gg + flag * g) / self.accum_iters)
p_t = p - flag * lr_t * m_t / (K.sqrt(v_t) + self.epsilon)
self.updates.append((m, flag * m_t + (1 - flag) * m))
self.updates.append((v, flag * v_t + (1 - flag) * v))
self.updates.append((gg, gg_t))
new_p = p_t
# apply constraints
if getattr(p, 'constraint', None) is not None:
c = constraints[p]
new_p = c(new_p)
self.updates.append((p, new_p))
return self.updates
def get_config(self):
config = {'lr': float(K.get_value(self.lr)),
'beta_1': float(K.get_value(self.beta_1)),
'beta_2': float(K.get_value(self.beta_2)),
'epsilon': self.epsilon}
base_config = super(Adam_accumulate, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
Hi Guys, thanks for the previous code, i have been trying to replicate the same for SGD with nestrov,
class SGDAccum(Optimizer):
"""Stochastic gradient descent optimizer.
Includes support for momentum,
learning rate decay, and Nesterov momentum.
# Arguments
lr: float >= 0. Learning rate.
momentum: float >= 0. Parameter updates momentum.
decay: float >= 0. Learning rate decay over each update.
nesterov: boolean. Whether to apply Nesterov momentum.
"""
def __init__(self, lr=0.01, momentum=0., decay=0.,
nesterov=False, accum_iters=1, **kwargs):
super(SGDAccum, self).__init__(**kwargs)
with K.name_scope(self.__class__.__name__):
self.iterations = K.variable(0, name='iterations')
self.lr = K.variable(lr, name='lr')
self.momentum = K.variable(momentum, name='momentum')
self.decay = K.variable(decay, name='decay')
self.accum_iters = K.variable(accum_iters)
self.initial_decay = decay
self.nesterov = nesterov
@interfaces.legacy_get_updates_support
def get_updates(self, loss, params):
grads = self.get_gradients(loss, params)
self.updates = [K.update_add(self.iterations, 1)]
lr = self.lr
if self.initial_decay > 0:
lr *= (1. / (1. + self.decay * K.cast(self.iterations,
K.dtype(self.decay))))
accum_switch = K.equal(self.iterations % self.accum_iters, 0)
accum_switch = K.cast(accum_switch, dtype='float32')
# momentum
shapes = [K.int_shape(p) for p in params]
moments = [K.zeros(shape) for shape in shapes]
temp_grads = [K.zeros(shape) for shape in shapes]
self.weights = [self.iterations] + moments
for p, cg, m, tg in zip(params, grads, moments, temp_grads):
g = cg + tg
v = self.momentum * m - (lr * g / self.accum_iters) # velocity
self.updates.append(K.update(m, (1 - accum_switch) * m + accum_switch * v))
self.updates.append(K.update(tg, (1 - accum_switch) * g))
if self.nesterov:
new_p = p + self.momentum * v - (lr * g / self.accum_iters)
else:
new_p = p + v
# Apply constraints.
if getattr(p, 'constraint', None) is not None:
new_p = p.constraint(new_p)
self.updates.append(K.update(p, (1 - accum_switch) * p + accum_switch * new_p))
return self.updates
def get_config(self):
config = {'lr': float(K.get_value(self.lr)),
'momentum': float(K.get_value(self.momentum)),
'decay': float(K.get_value(self.decay)),
'nesterov': self.nesterov,
'accum_iters': self.accum_iters}
base_config = super(SGDAccum, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
Can someone please verify that it look's about right ?
@gamers5a your function doesn't work in latest Keras version. There were to much changes in Adam function between 1.2.1 and 2.0.8 versions. Hope someone fix it as well.
@viig99 I believe your functions works just fine. Here is the logs of 3 runs:
SGD (default, batch=32):
Epoch 1/200
1/400 [..............................] - ETA: 4011s - loss: 0.6939 - acc: 0.4648
2/400 [..............................] - ETA: 2864s - loss: 0.6941 - acc: 0.4492
3/400 [..............................] - ETA: 2465s - loss: 0.6940 - acc: 0.4557
4/400 [..............................] - ETA: 2262s - loss: 0.6939 - acc: 0.4561
5/400 [..............................] - ETA: 2136s - loss: 0.6939 - acc: 0.4552
6/400 [..............................] - ETA: 2047s - loss: 0.6938 - acc: 0.4627
7/400 [..............................] - ETA: 1984s - loss: 0.6938 - acc: 0.4687
8/400 [..............................] - ETA: 1932s - loss: 0.6937 - acc: 0.4728
9/400 [..............................] - ETA: 1891s - loss: 0.6936 - acc: 0.4796
10/400 [..............................] - ETA: 1866s - loss: 0.6936 - acc: 0.4827
11/400 [..............................] - ETA: 1842s - loss: 0.6935 - acc: 0.4878
12/400 [..............................] - ETA: 1819s - loss: 0.6934 - acc: 0.4935
13/400 [..............................] - ETA: 1802s - loss: 0.6933 - acc: 0.4980
14/400 [>.............................] - ETA: 1785s - loss: 0.6932 - acc: 0.5041
15/400 [>.............................] - ETA: 1770s - loss: 0.6931 - acc: 0.5088
16/400 [>.............................] - ETA: 1755s - loss: 0.6931 - acc: 0.5149
17/400 [>.............................] - ETA: 1742s - loss: 0.6930 - acc: 0.5188
18/400 [>.............................] - ETA: 1732s - loss: 0.6929 - acc: 0.5242
19/400 [>.............................] - ETA: 1719s - loss: 0.6929 - acc: 0.5288
20/400 [>.............................] - ETA: 1710s - loss: 0.6928 - acc: 0.5337
21/400 [>.............................] - ETA: 1701s - loss: 0.6927 - acc: 0.5397
22/400 [>.............................] - ETA: 1688s - loss: 0.6926 - acc: 0.5461
23/400 [>.............................] - ETA: 1678s - loss: 0.6925 - acc: 0.5517
24/400 [>.............................] - ETA: 1669s - loss: 0.6924 - acc: 0.5575
25/400 [>.............................] - ETA: 1660s - loss: 0.6923 - acc: 0.5634
26/400 [>.............................] - ETA: 1653s - loss: 0.6922 - acc: 0.5693
27/400 [=>............................] - ETA: 1646s - loss: 0.6921 - acc: 0.5746
28/400 [=>............................] - ETA: 1638s - loss: 0.6920 - acc: 0.5790
29/400 [=>............................] - ETA: 1631s - loss: 0.6919 - acc: 0.5850
30/400 [=>............................] - ETA: 1623s - loss: 0.6918 - acc: 0.5903
31/400 [=>............................] - ETA: 1615s - loss: 0.6917 - acc: 0.5958
32/400 [=>............................] - ETA: 1609s - loss: 0.6916 - acc: 0.6015
33/400 [=>............................] - ETA: 1603s - loss: 0.6915 - acc: 0.6067
34/400 [=>............................] - ETA: 1598s - loss: 0.6914 - acc: 0.6125
35/400 [=>............................] - ETA: 1593s - loss: 0.6912 - acc: 0.6177
36/400 [=>............................] - ETA: 1587s - loss: 0.6911 - acc: 0.6230
37/400 [=>............................] - ETA: 1581s - loss: 0.6910 - acc: 0.6276
38/400 [=>............................] - ETA: 1580s - loss: 0.6909 - acc: 0.6315
39/400 [=>............................] - ETA: 1575s - loss: 0.6908 - acc: 0.6358
40/400 [==>...........................] - ETA: 1572s - loss: 0.6907 - acc: 0.6399
SGDAccum (accum_iters=1, batch=32)
1/400 [..............................] - ETA: 3341s - loss: 0.6939 - acc: 0.4648
...
40/400 [==>...........................] - ETA: 1545s - loss: 0.6907 - acc: 0.6399
SGDAccum (accum_iters=2, batch=16)
Epoch 1/200
1/400 [..............................] - ETA: 2258s - loss: 0.6937 - acc: 0.4661
2/400 [..............................] - ETA: 1539s - loss: 0.6939 - acc: 0.4544
3/400 [..............................] - ETA: 1304s - loss: 0.6940 - acc: 0.4523
4/400 [..............................] - ETA: 1184s - loss: 0.6940 - acc: 0.4538
5/400 [..............................] - ETA: 1110s - loss: 0.6940 - acc: 0.4505
6/400 [..............................] - ETA: 1062s - loss: 0.6941 - acc: 0.4466
7/400 [..............................] - ETA: 1020s - loss: 0.6941 - acc: 0.4509
8/400 [..............................] - ETA: 993s - loss: 0.6940 - acc: 0.4544
9/400 [..............................] - ETA: 970s - loss: 0.6940 - acc: 0.4563
10/400 [..............................] - ETA: 956s - loss: 0.6940 - acc: 0.4557
11/400 [..............................] - ETA: 939s - loss: 0.6939 - acc: 0.4614
12/400 [..............................] - ETA: 928s - loss: 0.6938 - acc: 0.4672
13/400 [..............................] - ETA: 916s - loss: 0.6938 - acc: 0.4700
14/400 [>.............................] - ETA: 907s - loss: 0.6938 - acc: 0.4708
15/400 [>.............................] - ETA: 899s - loss: 0.6937 - acc: 0.4703
16/400 [>.............................] - ETA: 892s - loss: 0.6937 - acc: 0.4740
17/400 [>.............................] - ETA: 885s - loss: 0.6937 - acc: 0.4738
18/400 [>.............................] - ETA: 877s - loss: 0.6936 - acc: 0.4766
19/400 [>.............................] - ETA: 874s - loss: 0.6936 - acc: 0.4779
20/400 [>.............................] - ETA: 868s - loss: 0.6936 - acc: 0.4794
21/400 [>.............................] - ETA: 863s - loss: 0.6936 - acc: 0.4820
22/400 [>.............................] - ETA: 856s - loss: 0.6935 - acc: 0.4843
23/400 [>.............................] - ETA: 851s - loss: 0.6935 - acc: 0.4887
24/400 [>.............................] - ETA: 847s - loss: 0.6934 - acc: 0.4909
25/400 [>.............................] - ETA: 842s - loss: 0.6934 - acc: 0.4928
26/400 [>.............................] - ETA: 838s - loss: 0.6934 - acc: 0.4964
27/400 [=>............................] - ETA: 835s - loss: 0.6933 - acc: 0.4986
28/400 [=>............................] - ETA: 830s - loss: 0.6933 - acc: 0.5019
29/400 [=>............................] - ETA: 827s - loss: 0.6933 - acc: 0.5048
30/400 [=>............................] - ETA: 823s - loss: 0.6932 - acc: 0.5073
31/400 [=>............................] - ETA: 820s - loss: 0.6932 - acc: 0.5098
32/400 [=>............................] - ETA: 817s - loss: 0.6931 - acc: 0.5131
33/400 [=>............................] - ETA: 814s - loss: 0.6931 - acc: 0.5156
34/400 [=>............................] - ETA: 811s - loss: 0.6930 - acc: 0.5193
35/400 [=>............................] - ETA: 808s - loss: 0.6930 - acc: 0.5231
36/400 [=>............................] - ETA: 806s - loss: 0.6929 - acc: 0.5263
37/400 [=>............................] - ETA: 802s - loss: 0.6929 - acc: 0.5296
38/400 [=>............................] - ETA: 798s - loss: 0.6928 - acc: 0.5330
39/400 [=>............................] - ETA: 795s - loss: 0.6928 - acc: 0.5366
40/400 [==>...........................] - ETA: 791s - loss: 0.6927 - acc: 0.5401
41/400 [==>...........................] - ETA: 789s - loss: 0.6927 - acc: 0.5434
42/400 [==>...........................] - ETA: 786s - loss: 0.6926 - acc: 0.5464
43/400 [==>...........................] - ETA: 782s - loss: 0.6926 - acc: 0.5506
44/400 [==>...........................] - ETA: 780s - loss: 0.6925 - acc: 0.5537
45/400 [==>...........................] - ETA: 778s - loss: 0.6925 - acc: 0.5563
46/400 [==>...........................] - ETA: 775s - loss: 0.6924 - acc: 0.5601
47/400 [==>...........................] - ETA: 772s - loss: 0.6924 - acc: 0.5631
48/400 [==>...........................] - ETA: 770s - loss: 0.6923 - acc: 0.5662
49/400 [==>...........................] - ETA: 766s - loss: 0.6923 - acc: 0.5690
50/400 [==>...........................] - ETA: 764s - loss: 0.6922 - acc: 0.5713
51/400 [==>...........................] - ETA: 761s - loss: 0.6922 - acc: 0.5733
52/400 [==>...........................] - ETA: 758s - loss: 0.6921 - acc: 0.5763
53/400 [==>...........................] - ETA: 756s - loss: 0.6921 - acc: 0.5784
54/400 [===>..........................] - ETA: 753s - loss: 0.6920 - acc: 0.5812
55/400 [===>..........................] - ETA: 751s - loss: 0.6920 - acc: 0.5838
56/400 [===>..........................] - ETA: 748s - loss: 0.6919 - acc: 0.5864
57/400 [===>..........................] - ETA: 746s - loss: 0.6919 - acc: 0.5894
58/400 [===>..........................] - ETA: 743s - loss: 0.6918 - acc: 0.5920
59/400 [===>..........................] - ETA: 740s - loss: 0.6918 - acc: 0.5948
60/400 [===>..........................] - ETA: 738s - loss: 0.6917 - acc: 0.5978
61/400 [===>..........................] - ETA: 735s - loss: 0.6917 - acc: 0.6001
62/400 [===>..........................] - ETA: 732s - loss: 0.6916 - acc: 0.6029
63/400 [===>..........................] - ETA: 729s - loss: 0.6916 - acc: 0.6054
64/400 [===>..........................] - ETA: 726s - loss: 0.6915 - acc: 0.6079
65/400 [===>..........................] - ETA: 725s - loss: 0.6915 - acc: 0.6105
66/400 [===>..........................] - ETA: 722s - loss: 0.6914 - acc: 0.6124
67/400 [====>.........................] - ETA: 719s - loss: 0.6914 - acc: 0.6151
68/400 [====>.........................] - ETA: 716s - loss: 0.6913 - acc: 0.6175
69/400 [====>.........................] - ETA: 714s - loss: 0.6913 - acc: 0.6208
70/400 [====>.........................] - ETA: 711s - loss: 0.6912 - acc: 0.6227
71/400 [====>.........................] - ETA: 709s - loss: 0.6912 - acc: 0.6252
72/400 [====>.........................] - ETA: 707s - loss: 0.6911 - acc: 0.6272
73/400 [====>.........................] - ETA: 704s - loss: 0.6911 - acc: 0.6294
74/400 [====>.........................] - ETA: 702s - loss: 0.6910 - acc: 0.6315
75/400 [====>.........................] - ETA: 700s - loss: 0.6910 - acc: 0.6339
76/400 [====>.........................] - ETA: 698s - loss: 0.6909 - acc: 0.6357
77/400 [====>.........................] - ETA: 696s - loss: 0.6909 - acc: 0.6381
78/400 [====>.........................] - ETA: 694s - loss: 0.6908 - acc: 0.6402
79/400 [====>.........................] - ETA: 692s - loss: 0.6907 - acc: 0.6416
80/400 [=====>........................] - ETA: 690s - loss: 0.6907 - acc: 0.6438
But there is problem with model.save() method:
TypeError: ('Not JSON Serializable:', SGDAccum/variable)
That will have to be included in the optimizers.py file, in the serialize and de-serialize methods. I would like to point out that batch accumulation is an incredibly useful option and should be provided with the main package, can we improve the visibility on this, or is their a better / preferred way to restructure the code ?
@viig99 may be you can try to add your changes directly in SGD optimizer in official repository as pull request. Because SGDAccum with default accum_iters=1
has the same behavior as standard SGD optimizer.
Hi @viig99, thanks for the SGDAccum code. I am getting the same error as @ZFTurbo when trying model.save():
TypeError: ('Not JSON Serializable:', SGDAccum/variable)
I am using Keras 2.1.1. Can you show your optimizers.py
code, please?
https://www.hastebin.com/efabasizas.py this is the one i was using, i am pretty sure there are better ways of doing things, for now i am saving weights and restarting networks with those weights.
@noagarcia @viig99
I think the reason unable to save is that
'accum_iters': self.accum_iters
should be
'accum_iters': int(K.get_value(self.accum_iters))
However, even I could save the model, when I load the model, it still ended with error: unknown optimizer : SGDAccum
First of all, very happy that I found this thread - great stuff! Thanks all for sharing :)
Wondering - performance wise - isn't it better to use K.switch instead of
self.updates.append(K.update(p, (1 - accum_switch) * p + accum_switch * new_p))
?
For example, something of this spirit:
maybe_assign_params = K.switch(
self.iterations%self.accum_iters == 0,
K.update(p, new_p),
K.update_add(tiny_dummy_param,0) #or some other dummy no-op
)
self.updates.append(maybe_assign_params)
to avoid doing K.update of all parameters into themselves for every n-1/n of the steps.
Can it be used along with batch normalization or do I need to change it a bit??
Using one of these solutions, should the loss not improve until the weights are updated every K batches? I tried @gamers5a 's solution and my loss improves every batch, even when I choose a large value for accum_iters. I'm not sure about this.
Thx guys! I'm using SGD provided by @viig99 and it works nicely!
Imho it should be part of keras itself though.
I try to use the adam
optimizers above, but none of them work for the new version v2.2.2.
I use this for 2.2.2:
from keras.legacy import interfaces
from keras.optimizers import Optimizer
from keras import backend as K
class AdamAccumulate(Optimizer):
def __init__(self, lr=0.001, beta_1=0.9, beta_2=0.999,
epsilon=None, decay=0., amsgrad=False, accum_iters=20, **kwargs):
super(AdamAccumulate, self).__init__(**kwargs)
with K.name_scope(self.__class__.__name__):
self.iterations = K.variable(0, dtype='int64', name='iterations')
self.lr = K.variable(lr, name='lr')
self.beta_1 = K.variable(beta_1, name='beta_1')
self.beta_2 = K.variable(beta_2, name='beta_2')
self.decay = K.variable(decay, name='decay')
if epsilon is None:
epsilon = K.epsilon()
self.epsilon = epsilon
self.initial_decay = decay
self.amsgrad = amsgrad
self.accum_iters = K.variable(accum_iters, dtype='int64')
@interfaces.legacy_get_updates_support
def get_updates(self, loss, params):
grads = self.get_gradients(loss, params)
self.updates = [K.update_add(self.iterations, 1)]
lr = self.lr
if self.initial_decay > 0:
lr = lr * (1. / (1. + self.decay * K.cast(self.iterations,
K.dtype(self.decay))))
t = K.cast(self.iterations, K.floatx()) + 1
lr_t = lr * (K.sqrt(1. - K.pow(self.beta_2, t)) /
(1. - K.pow(self.beta_1, t)))
ms = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
vs = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
gs = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
if self.amsgrad:
vhats = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
else:
vhats = [K.zeros(1) for _ in params]
self.weights = [self.iterations] + ms + vs + vhats
for p, g, m, v, vhat, gg in zip(params, grads, ms, vs, vhats, gs):
flag = K.equal(self.iterations % self.accum_iters, 0)
flag = K.cast(flag, K.floatx())
gg_t = (1 - flag) * (gg + g)
m_t = (self.beta_1 * m) + (1. - self.beta_1) * (gg + flag * g) / K.cast(self.accum_iters, K.floatx())
v_t = (self.beta_2 * v) + (1. - self.beta_2) * K.square((gg + flag * g) / K.cast(self.accum_iters, K.floatx()))
if self.amsgrad:
vhat_t = K.maximum(vhat, v_t)
p_t = p - lr_t * m_t / (K.sqrt(vhat_t) + self.epsilon)
self.updates.append(K.update(vhat, vhat_t))
else:
p_t = p - lr_t * m_t / (K.sqrt(v_t) + self.epsilon)
self.updates.append((m, flag * m_t + (1 - flag) * m))
self.updates.append((v, flag * v_t + (1 - flag) * v))
self.updates.append((gg, gg_t))
new_p = p_t
# Apply constraints.
if getattr(p, 'constraint', None) is not None:
new_p = p.constraint(new_p)
self.updates.append(K.update(p, new_p))
return self.updates
def get_config(self):
config = {'lr': float(K.get_value(self.lr)),
'beta_1': float(K.get_value(self.beta_1)),
'beta_2': float(K.get_value(self.beta_2)),
'decay': float(K.get_value(self.decay)),
'epsilon': self.epsilon,
'amsgrad': self.amsgrad}
base_config = super(AdamAccumulate, self).get_config()
return dict(list(base_config.items()) + list(config.items()))`
With nik-ko's version:
from keras_optim_acc import AdamAccumulate
on model.compile(optimizer='AdamAccumulate' ...) I get,
... python2.7/site-packages/keras/utils/generic_utils.py", line 138, in deserialize_keras_object
': ' + class_name)
ValueError: Unknown optimizer: AdamAccumulate
@phobrain optimizer=AdamAccumulate(), not optimizer='AdamAccumulate'
I had a complaint about something being used twice when using AdamAccumulate with a shared/siamese component of my model. The general setup is here:
https://www.reddit.com/r/MachineLearning/comments/9p9xh4/d_lstm_for_sequence_of_images/
Will reproduce and paste the error when GPU is free. :-)
@nik-ko If I set accum_iters
to, say, 4 - it should update weights only after every 4 batches?
I use this callback and weights are updated after each batch for some reason:
class ModelWeightsCallback(Callback):
def on_batch_end(self, batch, logs=None):
print('\n\nweights:\n')
print(self.model.get_weights())
Or maybe somebody else could check that code?
Weights were updated after each batch, because in that code the flag
was missed here:
if self.amsgrad:
vhat_t = K.maximum(vhat, v_t)
p_t = p - flag * lr_t * m_t / (K.sqrt(vhat_t) + self.epsilon)
self.updates.append(K.update(vhat, vhat_t))
else:
p_t = p - flag * lr_t * m_t / (K.sqrt(v_t) + self.epsilon)
Another thing that is not clear - why Adam
and AdamAccumulate
are getting different results.
For testing I use samples in the same order, don't use shuffle and copy initial wights, then run model.fit()
two times with different optimizers. Adam runned twice reproduces its results almost exactly. But Adam(with batch=32)
and AdamAccumulate(with batch=4, accum_iters=8)
give different results.
Shouldn't they get almost the same results? So I'm not sure if the code of optimizer is correct...
@alexeydevederkin - I also tried to use Adam with accumulated gradients presented here. When I try different experiments I have the same training accuracy but when the model goes through the validation portion the validation results are off. I am not sure if this is expected or not.
In general I wouldn't expect optimizers to even give the same result from run to run, let alone agree, but it would be interesting to build up from a simple net and see if there is more divergence when more params are being initialized. On Friday, November 9, 2018, 3:29:37 AM PST, alexeydevederkin notifications@github.com wrote:
Weights were updated after each batch, because in that code the flag was missed here: if self.amsgrad: vhat_t = K.maximum(vhat, v_t) p_t = p - flag lr_t m_t / (K.sqrt(vhat_t) + self.epsilon) self.updates.append(K.update(vhat, vhat_t)) else: p_t = p - flag lr_t m_t / (K.sqrt(v_t) + self.epsilon) Another thing that is not clear - why Adam and AdamAccumulate are getting different results.
For testing I use samples in the same order, don't use shuffle and copy initial wights, then run model.fit() two times with different optimizers. Adam runned twice reproduces its results almost exactly. But Adam(with batch=32) and AdamAccumulate(with batch=4, accum_iters=8) give different results.
Shouldn't they get almost the same results? So I'm not sure if the code of optimizer is correct...
— You are receiving this because you were mentioned. Reply to this email directly, view it on GitHub, or mute the thread.
@phobrain - I agree with that statement but for a little more context in terms of accuracy for the particular project I'm working on the validation accuracy for regular Adam will be around 0.68 - 0.69 but with Adam accumulation I obtain 0.71 - 0.72. The discrepancy becomes higher the more accumulation rounds I add. I guess my original question is - is this type of discrepancy to high or expected when using accumulation.
Thanks Ryan, is that validation accuracy on a held-out test set using predictions, or just when fitting? The latter I don't consider super meaningful. On Wednesday, November 14, 2018, 2:35:10 PM PST, Ryan de Vera notifications@github.com wrote:
@phobrain - I agree with that statement but for a little more context in terms of accuracy for the particular project I'm working on the validation accuracy for regular Adam will be around 0.68 - 0.69 but with Adam accumulation I obtain 0.71 - 0.72. The discrepancy becomes higher the more accumulation rounds I add. I guess my original question is - is this type of discrepancy to high or expected when using accumulation.
— You are receiving this because you were mentioned. Reply to this email directly, view it on GitHub, or mute the thread.
@phobrain - this is the validation accuracy on a held-out test set using predictions.
@phobrain optimizer=AdamAccumulate(), not optimizer='AdamAccumulate'
I'm getting the "ValueError: ('Could not interpret optimizer identifier:', <main.AdamAccumulate object at 0x00000000FD682E10>)"
Would you happen to know about this one?
same error with @adityaparikh1
@adityaparikh1 I found this error occur when it is checked instance of optimizer. If you use tensorflow.keras library, you should imfort "from tensorflow.keras.optimizers import Optimizer" not "from keras.optimizers import Optimizer". It works for me. But I also have problem that gradient update every epoch.
@rydevera3 @phobrain I am using this code to test optimizer:
import keras.backend as K
import numpy as np
import tensorflow as tf
import random as rn
# Reproducibility
# https://keras.io/getting-started/faq/#how-can-i-obtain-reproducible-results-using-keras-during-development
np.random.seed(42)
rn.seed(12345)
session_conf = tf.ConfigProto(intra_op_parallelism_threads=1,
inter_op_parallelism_threads=1)
tf.set_random_seed(1234)
sess = tf.Session(graph=tf.get_default_graph(), config=session_conf)
K.set_session(sess)
from keras import models, layers
model = models.Sequential()
model.add(layers.Conv2D(8, (3, 3), activation='relu', input_shape=(28, 28, 1)))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(16, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(16, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Flatten())
model.add(layers.Dense(16, activation='relu'))
model.add(layers.Dense(10, activation='softmax'))
from keras.datasets import mnist
from keras.utils import to_categorical
(train_images, train_labels), _ = mnist.load_data()
train_images = train_images.reshape((60000, 28, 28, 1))
train_images = train_images.astype('float32') / 255
train_labels = to_categorical(train_labels)
model_2 = models.clone_model(model)
model_2.set_weights(model.get_weights())
model_3 = models.clone_model(model)
model_3.set_weights(model.get_weights())
optimizer = Adam(lr=0.0001)
model.compile(
optimizer=optimizer,
loss='categorical_crossentropy',
metrics=['accuracy'])
print('\nTraining with Adam, 1st run:')
model.fit(train_images, train_labels, epochs=5, batch_size=32, shuffle=False)
optimizer_2 = Adam(lr=0.0001)
model_2.compile(
optimizer=optimizer_2,
loss='categorical_crossentropy',
metrics=['accuracy'])
print('\nTraining with Adam, 2nd run:')
model_2.fit(train_images, train_labels, epochs=5, batch_size=32, shuffle=False)
optimizer_3 = AdamAccumulate(lr=0.0001, accum_iters=8)
model_3.compile(
optimizer=optimizer_3,
loss='categorical_crossentropy',
metrics=['accuracy'])
print('\nTraining with AdamAccumulate:')
model_3.fit(train_images, train_labels, epochs=5, batch_size=4, shuffle=False)
Also run it with env variables: $ CUDA_VISIBLE_DEVICES="" PYTHONHASHSEED=0 python3 optimizer_test.py
What I got:
Training with Adam, 1st run:
Epoch 1/5
60000/60000 [==============================] - 79s 1ms/step - loss: 1.3168 - acc: 0.6004
Epoch 2/5
60000/60000 [==============================] - 76s 1ms/step - loss: 0.4745 - acc: 0.8595
Epoch 3/5
60000/60000 [==============================] - 79s 1ms/step - loss: 0.3572 - acc: 0.8944
Epoch 4/5
60000/60000 [==============================] - 77s 1ms/step - loss: 0.3018 - acc: 0.9104
Epoch 5/5
60000/60000 [==============================] - 76s 1ms/step - loss: 0.2672 - acc: 0.9201
Training with Adam, 2nd run:
Epoch 1/5
60000/60000 [==============================] - 75s 1ms/step - loss: 1.3168 - acc: 0.6004
Epoch 2/5
60000/60000 [==============================] - 75s 1ms/step - loss: 0.4745 - acc: 0.8595
Epoch 3/5
60000/60000 [==============================] - 78s 1ms/step - loss: 0.3572 - acc: 0.8944
Epoch 4/5
60000/60000 [==============================] - 79s 1ms/step - loss: 0.3018 - acc: 0.9104
Epoch 5/5
60000/60000 [==============================] - 77s 1ms/step - loss: 0.2672 - acc: 0.9201
Training with AdamAccumulate:
Epoch 1/5
60000/60000 [==============================] - 150s 3ms/step - loss: 0.9540 - acc: 0.7108
Epoch 2/5
60000/60000 [==============================] - 161s 3ms/step - loss: 0.4133 - acc: 0.8761
Epoch 3/5
60000/60000 [==============================] - 164s 3ms/step - loss: 0.3300 - acc: 0.9022
Epoch 4/5
60000/60000 [==============================] - 140s 2ms/step - loss: 0.2857 - acc: 0.9147
Epoch 5/5
60000/60000 [==============================] - 155s 3ms/step - loss: 0.2563 - acc: 0.9232
As you can see Adam reproduces itsef exactly, but AdamAccumulate gives different results.
I noticed some mistakes in the code of optimizer, will post my version later, just need to fix some strange behavior. Hard to debug TF code)
Hey everyone, I've corrected some bugs in @nik-ko 's implementation (mainly the learning rate which wasn't adjusting correctly). Here it is:
from keras.legacy import interfaces
from keras.optimizers import Optimizer, Adam
from keras import backend as K
class AdamAccumulate(Optimizer):
def __init__(self, lr=0.001, beta_1=0.9, beta_2=0.999,
epsilon=None, decay=0., amsgrad=False, accum_iters=20, **kwargs):
super(AdamAccumulate, self).__init__(**kwargs)
with K.name_scope(self.__class__.__name__):
self.iterations = K.variable(0, dtype='int64', name='iterations')
self.effective_iterations = K.variable(0, dtype='int64', name='effective_iterations')
self.lr = K.variable(lr, name='lr')
self.beta_1 = K.variable(beta_1, name='beta_1')
self.beta_2 = K.variable(beta_2, name='beta_2')
self.decay = K.variable(decay, name='decay')
if epsilon is None:
epsilon = K.epsilon()
self.epsilon = epsilon
self.initial_decay = decay
self.amsgrad = amsgrad
self.accum_iters = K.variable(accum_iters, dtype='int64')
@interfaces.legacy_get_updates_support
def get_updates(self, loss, params):
grads = self.get_gradients(loss, params)
self.updates = [K.update(self.iterations, self.iterations + 1)]
flag = K.equal(self.iterations % self.accum_iters, self.accum_iters - 1)
flag = K.cast(flag, K.floatx())
self.updates.append(K.update(self.effective_iterations,
self.effective_iterations + K.cast(flag, 'int64')))
lr = self.lr
if self.initial_decay > 0:
lr = lr * (1. / (1. + self.decay * K.cast(self.effective_iterations,
K.dtype(self.decay))))
t = K.cast(self.effective_iterations, K.floatx()) + 1
lr_t = lr * (K.sqrt(1. - K.pow(self.beta_2, t)) /
(1. - K.pow(self.beta_1, t )))
ms = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
vs = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
gs = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
if self.amsgrad:
vhats = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
else:
vhats = [K.zeros(1) for _ in params]
self.weights = [self.iterations] + ms + vs + vhats
for p, g, m, v, vhat, gg in zip(params, grads, ms, vs, vhats, gs):
gg_t = (1 - flag) * (gg + g)
m_t = (self.beta_1 * m) + (1. - self.beta_1) * (gg + flag * g) / K.cast(self.accum_iters, K.floatx())
v_t = (self.beta_2 * v) + (1. - self.beta_2) * K.square((gg + flag * g) / K.cast(self.accum_iters, K.floatx()))
if self.amsgrad:
vhat_t = K.maximum(vhat, v_t)
p_t = p - flag * lr_t * m_t / (K.sqrt(vhat_t) + self.epsilon)
self.updates.append(K.update(vhat, vhat_t))
else:
p_t = p - flag * lr_t * m_t / (K.sqrt(v_t) + self.epsilon)
self.updates.append((m, flag * m_t + (1 - flag) * m))
self.updates.append((v, flag * v_t + (1 - flag) * v))
self.updates.append((gg, gg_t))
new_p = p_t
# Apply constraints.
if getattr(p, 'constraint', None) is not None:
new_p = p.constraint(new_p)
self.updates.append(K.update(p, new_p))
return self.updates
def get_config(self):
config = {'lr': float(K.get_value(self.lr)),
'beta_1': float(K.get_value(self.beta_1)),
'beta_2': float(K.get_value(self.beta_2)),
'decay': float(K.get_value(self.decay)),
'epsilon': self.epsilon,
'amsgrad': self.amsgrad}
base_config = super(AdamAccumulate, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
And using @alexeydevederkin 's test, everything seems to work almost perfectly:
Training with Adam, 1st run:
Epoch 1/1
60000/60000 [==============================] - 24s 402us/step - loss: 1.3166 - acc: 0.6004
Training with Adam, 2nd run:
Epoch 1/1
60000/60000 [==============================] - 24s 408us/step - loss: 1.3166 - acc: 0.6004
Training with AdamAccumulate:
Epoch 1/1
60000/60000 [==============================] - 148s 2ms/step - loss: 1.3139 - acc: 0.6004
With @Dutil 's code, I don't see my earlier-mentioned "complaint about something being used twice," tho other model details are different by now so that could be the cause, and I get reasonable results with my siamese model using keyword vectors, doubling the batch of 1024. In the same siamese model using VGG16, doubling batch of 32, on 1st try my held-back positive test cases all had the same value (0.01187402) which is binary-correct but too fishy. Rerunning, got two creditable epochs with hold-out testing between. But I see about the same run profile as for adagrad, so wondering if it makes sense (blindly QA'ing for now).
adagrad 11/4 15080/15080 3414s 226ms/step
AdamAcc 11/21 15394/15394 3190s 207ms/step
model.compile(optimizer=AdamAccumulate(accum_iters=2),
loss='binary_crossentropy',
metrics=['binary_accuracy']
#options=run_opts
)
Will try @alexeydevederkin 's version next.
My version of Adam optimizer with accumulated gradient (slightly different from @Dutil 's - closer results to Adam
)
import keras.backend as K
from keras.legacy import interfaces
from keras.optimizers import Optimizer
class AdamAccumulate(Optimizer):
def __init__(self, lr=0.001, beta_1=0.9, beta_2=0.999,
epsilon=None, decay=0., amsgrad=False, accum_iters=1, **kwargs):
if accum_iters < 1:
raise ValueError('accum_iters must be >= 1')
super(AdamAccumulate, self).__init__(**kwargs)
with K.name_scope(self.__class__.__name__):
self.iterations = K.variable(0, dtype='int64', name='iterations')
self.lr = K.variable(lr, name='lr')
self.beta_1 = K.variable(beta_1, name='beta_1')
self.beta_2 = K.variable(beta_2, name='beta_2')
self.decay = K.variable(decay, name='decay')
if epsilon is None:
epsilon = K.epsilon()
self.epsilon = epsilon
self.initial_decay = decay
self.amsgrad = amsgrad
self.accum_iters = K.variable(accum_iters, K.dtype(self.iterations))
self.accum_iters_float = K.cast(self.accum_iters, K.floatx())
@interfaces.legacy_get_updates_support
def get_updates(self, loss, params):
grads = self.get_gradients(loss, params)
self.updates = [K.update_add(self.iterations, 1)]
lr = self.lr
completed_updates = K.cast(K.tf.floordiv(self.iterations, self.accum_iters), K.floatx())
if self.initial_decay > 0:
lr = lr * (1. / (1. + self.decay * completed_updates))
t = completed_updates + 1
lr_t = lr * (K.sqrt(1. - K.pow(self.beta_2, t)) / (1. - K.pow(self.beta_1, t)))
# self.iterations incremented after processing a batch
# batch: 1 2 3 4 5 6 7 8 9
# self.iterations: 0 1 2 3 4 5 6 7 8
# update_switch = 1: x x (if accum_iters=4)
update_switch = K.equal((self.iterations + 1) % self.accum_iters, 0)
update_switch = K.cast(update_switch, K.floatx())
ms = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
vs = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
gs = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
if self.amsgrad:
vhats = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
else:
vhats = [K.zeros(1) for _ in params]
self.weights = [self.iterations] + ms + vs + vhats
for p, g, m, v, vhat, tg in zip(params, grads, ms, vs, vhats, gs):
sum_grad = tg + g
avg_grad = sum_grad / self.accum_iters_float
m_t = (self.beta_1 * m) + (1. - self.beta_1) * avg_grad
v_t = (self.beta_2 * v) + (1. - self.beta_2) * K.square(avg_grad)
if self.amsgrad:
vhat_t = K.maximum(vhat, v_t)
p_t = p - lr_t * m_t / (K.sqrt(vhat_t) + self.epsilon)
self.updates.append(K.update(vhat, (1 - update_switch) * vhat + update_switch * vhat_t))
else:
p_t = p - lr_t * m_t / (K.sqrt(v_t) + self.epsilon)
self.updates.append(K.update(m, (1 - update_switch) * m + update_switch * m_t))
self.updates.append(K.update(v, (1 - update_switch) * v + update_switch * v_t))
self.updates.append(K.update(tg, (1 - update_switch) * sum_grad))
new_p = p_t
# Apply constraints.
if getattr(p, 'constraint', None) is not None:
new_p = p.constraint(new_p)
self.updates.append(K.update(p, (1 - update_switch) * p + update_switch * new_p))
return self.updates
def get_config(self):
config = {'lr': float(K.get_value(self.lr)),
'beta_1': float(K.get_value(self.beta_1)),
'beta_2': float(K.get_value(self.beta_2)),
'decay': float(K.get_value(self.decay)),
'epsilon': self.epsilon,
'amsgrad': self.amsgrad}
base_config = super(AdamAccumulate, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
Tests:
Training with Adam, 1st run:
Epoch 1/5
60000/60000 [==============================] - 68s 1ms/step - loss: 1.3168 - acc: 0.6004
Epoch 2/5
60000/60000 [==============================] - 70s 1ms/step - loss: 0.4745 - acc: 0.8595
Epoch 3/5
60000/60000 [==============================] - 69s 1ms/step - loss: 0.3572 - acc: 0.8944
Epoch 4/5
60000/60000 [==============================] - 71s 1ms/step - loss: 0.3018 - acc: 0.9104
Epoch 5/5
60000/60000 [==============================] - 71s 1ms/step - loss: 0.2672 - acc: 0.9201
Training with Adam, 2nd run:
Epoch 1/5
60000/60000 [==============================] - 71s 1ms/step - loss: 1.3168 - acc: 0.6004
Epoch 2/5
60000/60000 [==============================] - 71s 1ms/step - loss: 0.4745 - acc: 0.8595
Epoch 3/5
60000/60000 [==============================] - 67s 1ms/step - loss: 0.3572 - acc: 0.8944
Epoch 4/5
60000/60000 [==============================] - 71s 1ms/step - loss: 0.3018 - acc: 0.9104
Epoch 5/5
60000/60000 [==============================] - 67s 1ms/step - loss: 0.2672 - acc: 0.9201
Training with AdamAccumulate:
Epoch 1/5
60000/60000 [==============================] - 141s 2ms/step - loss: 1.3167 - acc: 0.6004
Epoch 2/5
60000/60000 [==============================] - 141s 2ms/step - loss: 0.4744 - acc: 0.8596
Epoch 3/5
60000/60000 [==============================] - 136s 2ms/step - loss: 0.3572 - acc: 0.8944
Epoch 4/5
60000/60000 [==============================] - 139s 2ms/step - loss: 0.3018 - acc: 0.9105
Epoch 5/5
60000/60000 [==============================] - 138s 2ms/step - loss: 0.2671 - acc: 0.9201
I'm not very familiar with Tensorflow, but maybe it could be further improved (for speed) by using conditional updates instead of updating variables with the same values.
With @alexeydevederkin 's version on the VGG case, python 2.7:
File "../keras_optim_acc2.py", line 34, in get_updates
completed_updates = K.cast(K.tf.floor(self.iterations / self.accum_iters), K.floatx())
File "/home/phobrain/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/gen_math_ops.py", line 2931, in floor
"Floor", x=x, name=name)
File "/home/phobrain/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 609, in _apply_op_helper
param_name=input_name)
File "/home/phobrain/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 60, in _SatisfiesTypeConstraint
", ".join(dtypes.as_dtype(x).name for x in allowed_list)))
TypeError: Value passed to parameter 'x' has DataType int64 not in list of allowed values: bfloat16, float16, float32, float64
Can keras support to update parameters after a relative large batch size which exceed the GPU memory if feeded in one time? My model now can only be feeded batch_size=4 samples a time due to GPU 12G memory. The loss is difficult to decline when batch_size=4. So I want to update the parameters after 32 samples. Will keras be able to support this? It seems that Caffe can support this. Thanks!