google-research / scenic

Scenic: A Jax Library for Computer Vision Research and Beyond

OWL per-example grad clipping #1102

Closed DesertsP closed 2 months ago

DesertsP commented 2 months ago

It seems that OWL doesn't actually apply gradient clipping? As the code below shows, max_grad_norm is never applied.

  def train_step(train_state, batch):

    def grad_fn(inputs):
      batch, rng = inputs['batch'], inputs['rng']
      def loss_fn(params):
        # Bind the rng to the host/device we are on.
        model_rng = train_utils.bind_rng_to_host_device(
            rng, axis_name='batch', bind_to='device')
        kwargs = {'text_queries': batch['queries']}

        predictions = flax_model.apply(
            {'params': params, **train_state.model_state},
            batch['inputs'],
            train=True,
            debug=debug,
            rngs={'dropout': model_rng},
            **kwargs)

        return loss_and_metrics_fn(predictions, batch, model_params=params)

      compute_gradient_fn = jax.value_and_grad(loss_fn, has_aux=True)
      (_, metrics), grad = compute_gradient_fn(train_state.params)

      # Note: zeroing out frozen gradients changes the L2 norm. Clipping is
      # done inside Optax before zeroing out frozen weights.
      grad = scenic_optax.replace_frozen(config.schedule, grad, 0.)
      metrics['l2_grads_orig'] = (utils.l2_norm(grad), 1)
      return grad, metrics

    if per_example_clipping and max_grad_norm is not None:
      # For per-example clipping we produce per-example rngs.
      rngs = jax.random.split(train_state.rng, num=batch['inputs'].shape[0] + 1)
      new_rng, model_rng = rngs[0], rngs[1:]
      # We add an additional dimension which will serve as the batch dimension
      # for single examples when applying scan or vmap.
      batch = jax.tree_util.tree_map(lambda x: x[:, jnp.newaxis], batch)
      inp = {'batch': batch, 'rng': model_rng}
      grad, metrics = jax.vmap(grad_fn, 0)(inp)
    else:
      # Without per example clipping we can just compute the gradient on the
      # entire batch.
      new_rng, model_rng = jax.random.split(train_state.rng)
      grad, metrics = grad_fn({'batch': batch, 'rng': model_rng})

    new_train_state, g = update_fn(train_state, grad, new_rng)
    metrics['l2_grads'] = (utils.l2_norm(g), 1)
    metrics['l2_params'] = (utils.l2_norm(new_train_state.params), 1)
    return new_train_state, metrics

  return train_step
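
For reference, here is a minimal, self-contained sketch of the per-example pattern the vmap branch above relies on, using a toy linear model and squared-error loss (hypothetical names, not OWL's actual model or loss): each example produces its own gradient, so a per-example norm clip can be applied before the gradients are averaged.

    import jax
    import jax.numpy as jnp

    def loss_fn(params, example):
      # Squared-error loss for a single example (toy stand-in for OWL's loss).
      pred = example['x'] @ params['w']
      return jnp.mean((pred - example['y']) ** 2)

    def per_example_grads(params, batch):
      # vmap over the batch axis while holding params fixed, so each example
      # yields its own gradient pytree.
      return jax.vmap(jax.grad(loss_fn), in_axes=(None, 0))(params, batch)

    def clip_and_average(grads, max_grad_norm):
      # Clip each per-example gradient to max_grad_norm, then average them.
      def clip_one(g):
        norm = jnp.sqrt(sum(jnp.sum(x ** 2) for x in jax.tree_util.tree_leaves(g)))
        scale = jnp.minimum(1.0, max_grad_norm / (norm + 1e-6))
        return jax.tree_util.tree_map(lambda x: x * scale, g)
      clipped = jax.vmap(clip_one)(grads)
      return jax.tree_util.tree_map(lambda x: jnp.mean(x, axis=0), clipped)

    params = {'w': jnp.ones((4,))}
    batch = {'x': jnp.ones((8, 4)), 'y': jnp.zeros((8,))}
    avg_grad = clip_and_average(per_example_grads(params, batch), max_grad_norm=1.0)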
mjlm commented 2 months ago

The config.optimizer object is passed to the scenic_optax.make function, which applies gradient clipping here: https://github.com/google-research/scenic/blob/main/scenic/train_lib/optax.py#L220
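
For context, applying global-norm clipping inside the optimizer is typically done by chaining optax.clip_by_global_norm in front of the update, roughly like this sketch (illustrative values, not Scenic's exact code):

    import optax

    max_grad_norm = 1.0  # placeholder value; in Scenic this comes from the config

    tx = optax.chain(
        optax.clip_by_global_norm(max_grad_norm),  # clip the global gradient norm first
        optax.adamw(learning_rate=1e-4),           # ... then apply the usual update
    )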