It seems that OWL doesn't really apply gradient clipping? As the code below shows, max_grad_norm is never actually applied:
def train_step(train_state, batch):

  def grad_fn(inputs):
    batch, rng = inputs['batch'], inputs['rng']

    def loss_fn(params):
      # Bind the rng to the host/device we are on.
      model_rng = train_utils.bind_rng_to_host_device(
          rng, axis_name='batch', bind_to='device')
      kwargs = {'text_queries': batch['queries']}
      predictions = flax_model.apply(
          {'params': params, **train_state.model_state},
          batch['inputs'],
          train=True,
          debug=debug,
          rngs={'dropout': model_rng},
          **kwargs)
      return loss_and_metrics_fn(predictions, batch, model_params=params)

    compute_gradient_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (_, metrics), grad = compute_gradient_fn(train_state.params)
    # Note: zeroing out frozen gradients changes the L2 norm. Clipping is
    # done inside Optax before zeroing out frozen weights.
    grad = scenic_optax.replace_frozen(config.schedule, grad, 0.)
    metrics['l2_grads_orig'] = (utils.l2_norm(grad), 1)
    return grad, metrics

  if per_example_clipping and max_grad_norm is not None:
    # For per-example clipping we produce per-example rngs.
    rngs = jax.random.split(train_state.rng, num=batch['inputs'].shape[0] + 1)
    new_rng, model_rng = rngs[0], rngs[1:]
    # We add an additional dimension which will serve as the batch dimension
    # for single examples when applying scan or vmap.
    batch = jax.tree_util.tree_map(lambda x: x[:, jnp.newaxis], batch)
    inp = {'batch': batch, 'rng': model_rng}
    grad, metrics = jax.vmap(grad_fn, 0)(inp)
  else:
    # Without per-example clipping we can just compute the gradient on the
    # entire batch.
    new_rng, model_rng = jax.random.split(train_state.rng)
    grad, metrics = grad_fn({'batch': batch, 'rng': model_rng})

  new_train_state, g = update_fn(train_state, grad, new_rng)
  metrics['l2_grads'] = (utils.l2_norm(g), 1)
  metrics['l2_params'] = (utils.l2_norm(new_train_state.params), 1)
  return new_train_state, metrics

return train_step
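
For context, here is a minimal, hypothetical Optax sketch (my own names and values, not the Scenic/OWL code) of where a max_grad_norm would normally take effect if the comment "Clipping is done inside Optax" holds: the clipping transform is chained into the optimizer, so it runs inside the optimizer's update (i.e. inside update_fn) rather than in train_step itself.

  # Hypothetical sketch, not the OWL/Scenic code.
  import jax.numpy as jnp
  import optax

  max_grad_norm = 1.0  # assumed config value

  tx = optax.chain(
      optax.clip_by_global_norm(max_grad_norm),  # clips before the optimizer step
      optax.adamw(learning_rate=1e-4),
  )

  params = {'w': jnp.ones((3,)), 'b': jnp.zeros((3,))}
  opt_state = tx.init(params)

  grads = {'w': jnp.full((3,), 100.0), 'b': jnp.full((3,), 100.0)}  # large grads
  updates, opt_state = tx.update(grads, opt_state, params)  # clipping would happen here
  params = optax.apply_updates(params, updates)

If the transformation chain built by scenic_optax does not include such a transform when max_grad_norm is set, then the value would indeed never be used.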