I was looking at this code for accumulate_gradient. I usually pass a params['rng_key'] as part of the state, but as written, this code would not feed a different rng_key to each accumulation step.
For example, I was thinking something like this should be done instead:
def acc_grad_and_loss(i, l_and_g_and_key):
  l, g, rng_key = l_and_g_and_key
  imgs = jax.lax.dynamic_slice(images, (i * step_size, 0, 0, 0),
                               (step_size,) + images.shape[1:])
  lbls = jax.lax.dynamic_slice(labels, (i * step_size, 0),
                               (step_size, labels.shape[1]))
  # If the loss has stochasticity, each accumulation step should see a
  # different random key.
  rng_key, step_key = jax.random.split(rng_key)
  params['rng_key'] = step_key
  li, gi = loss_and_grad_fn(params, imgs, lbls)
  return (l + li, jax.tree_map(lambda x, y: x + y, g, gi), rng_key)

l, g, rng_key = jax.lax.fori_loop(1, accum_steps, acc_grad_and_loss,
                                  (l, g, rng_key))
return jax.tree_map(lambda x: x / accum_steps, (l, g))
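To illustrate the underlying issue, here is a minimal, self-contained sketch (all names here are hypothetical, not from the repo): reusing one key produces identical draws at every step, while splitting the key and carrying it through the loop state gives each iteration fresh randomness.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# Reusing the same key at every step yields identical "random" draws,
# which is what happens when rng_key is never advanced in the loop.
same_a = jax.random.normal(key, (3,))
same_b = jax.random.normal(key, (3,))
assert jnp.allclose(same_a, same_b)

# Splitting advances the key, so the next draw differs.
key, step_key = jax.random.split(key)
fresh = jax.random.normal(step_key, (3,))
assert not jnp.allclose(same_a, fresh)

# The same pattern carried through fori_loop: the key is part of the
# loop state, so each iteration draws from a distinct subkey.
def body(i, carry):
  total, k = carry
  k, sub = jax.random.split(k)
  return (total + jax.random.normal(sub, ()), k)

total, _ = jax.lax.fori_loop(0, 4, body, (jnp.float32(0.0), key))
```

An alternative that avoids carrying the key in the loop state is `jax.random.fold_in(rng_key, i)`, which derives a distinct per-step key from the loop index.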