google-research / vision_transformer


Shouldn't accumulate_gradient pass rng_key? #286

Open hrbigelow opened 1 year ago

hrbigelow commented 1 year ago

Hi,

I was looking at this code for accumulate_gradient. I usually pass a params['rng_key'] as part of the state, but as written, accumulate_gradient would not feed a different rng_key to each accumulation step.
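
For reference, the relevant loop in the current implementation looks roughly like this (abridged from memory, so treat it as a sketch rather than the exact source). The same params dict is handed to loss_and_grad_fn on every step, so any rng_key stored in it is reused across all accum_steps:

    # Abridged sketch of the current accumulation loop: params (and any
    # rng key stored inside it) is identical for every accumulation step.
    def acc_grad_and_loss(i, l_and_g):
      imgs = jax.lax.dynamic_slice(images, (i * step_size, 0, 0, 0),
                                   (step_size,) + images.shape[1:])
      lbls = jax.lax.dynamic_slice(labels, (i * step_size, 0),
                                   (step_size, labels.shape[1]))
      li, gi = loss_and_grad_fn(params, imgs, lbls)
      l, g = l_and_g
      return (l + li, jax.tree_map(lambda x, y: x + y, g, gi))

    l, g = jax.lax.fori_loop(1, accum_steps, acc_grad_and_loss, (l, g))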

For example, I was thinking something like this should be done instead:

    def acc_grad_and_loss(i, l_g_rng):
      l, g, rng_key = l_g_rng
      imgs = jax.lax.dynamic_slice(images, (i * step_size, 0, 0, 0),
                                   (step_size,) + images.shape[1:])
      lbls = jax.lax.dynamic_slice(labels, (i * step_size, 0),
                                   (step_size, labels.shape[1]))
      # If the loss has stochasticity, it should see a different random key
      # for each accumulation step.
      rng_key, step_key = jax.random.split(rng_key)
      li, gi = loss_and_grad_fn({**params, 'rng_key': step_key}, imgs, lbls)
      return (l + li, jax.tree_map(lambda x, y: x + y, g, gi), rng_key)

    l, g, rng_key = jax.lax.fori_loop(1, accum_steps, acc_grad_and_loss,
                                      (l, g, rng_key))
    return jax.tree_map(lambda x: x / accum_steps, (l, g))
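
Note that jax.lax.fori_loop only threads state through its carry, so the key has to be part of the carry tuple; mutating the params dict inside the traced body would not give each iteration a fresh key. Splitting the carried key once per step keeps the function purely functional while still giving every accumulation step its own randomness.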