google / jaxopt

Hardware accelerated, batchable and differentiable optimizers in JAX.
https://jaxopt.github.io
Apache License 2.0

Stopping condition 'madsen-nielsen' incorrect #575

Open Joshuaalbert opened 5 months ago

Joshuaalbert commented 5 months ago

The documentation says something different from the code. Specifically, the `-` in the code is inconsistent with the `+` in the docstring, in the expression `tree_l2_norm(params) - self.xtol`.

Docstring says:

    the convergence is achieved once the
    coeff update satisfies ``||dcoeffs||_2 <= xtol * (||coeffs||_2 + xtol)`` or
    the gradient satisfies ``||grad(f)||_inf <= gtol``.

Code says:

      tree_mul_term = self.xtol * (tree_l2_norm(params) - self.xtol)
      return jnp.all(jnp.array([
        tree_inf_norm(state.gradient) > self.gtol,
        tree_l2_norm(state.delta) > tree_mul_term
      ]))
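
For comparison, here is a minimal sketch of what the check would look like if it followed the docstring (i.e. with `+` instead of `-`); whether the docstring or the code reflects the intended behaviour is for the maintainers to decide:

      # Sketch only: uses the "+" from the docstring, xtol * (||coeffs||_2 + xtol),
      # instead of the "-" currently in the code.
      tree_mul_term = self.xtol * (tree_l2_norm(params) + self.xtol)
      return jnp.all(jnp.array([
        tree_inf_norm(state.gradient) > self.gtol,
        tree_l2_norm(state.delta) > tree_mul_term
      ]))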

Additionally, rather than `jnp.all(jnp.array([...]))`, you could use `jnp.bitwise_and(..., ...)` or the `|`, `&`, and `~` operators; see the sketch below.
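
For instance, a minimal sketch of the same check written with `&` instead of stacking into an array (reusing `tree_inf_norm`, `tree_l2_norm`, `state`, `params`, and `tree_mul_term` from the snippet above):

      # Equivalent to jnp.all(jnp.array([...])) above, but without building a temporary array.
      not_converged = (
        (tree_inf_norm(state.gradient) > self.gtol)
        & (tree_l2_norm(state.delta) > tree_mul_term)
      )
      return not_converged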

My suggestion

Upon reading up on the Madsen-Nielsen stopping condition, it seems there is no single canonical version of it. From my optimisation work I find that incorporating both an absolute and a relative tolerance on the parameter update is quite useful. (Currently it looks like only a relative tolerance is used.)

    import jax.numpy as jnp
    from jax.tree_util import tree_leaves, tree_map

    def leaves_vec(tree_x):
      # Flatten a pytree into a single 1D vector.
      return jnp.concatenate(tree_leaves(tree_map(jnp.ravel, tree_x)))

    # Body of the stopping-condition check; `state`, `params`, and `self` come from the solver.
    # Converged once the update is small in absolute terms, small relative to the parameters,
    # or the gradient inf-norm drops below gtol.
    atol_cond = jnp.all(jnp.abs(leaves_vec(state.delta)) <= self.atol)
    rtol_cond = jnp.all(jnp.abs(leaves_vec(state.delta)) <= self.rtol * jnp.abs(leaves_vec(params)))
    grad_cond = jnp.max(jnp.abs(leaves_vec(state.gradient))) <= self.gtol
    done = atol_cond | rtol_cond | grad_cond
    return ~done

    # Suggested defaults
    atol = 0.   # effectively off unless the user enables it, keeping backward compatibility with the current behaviour
    rtol = 1e-3
    gtol = 1e-3
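
To make the proposal concrete, here is a self-contained sketch of the suggested check as a standalone function. The `keep_iterating` name, its arguments, and the toy pytrees are illustrative only, not the actual `LevenbergMarquardt` attributes:

    import jax.numpy as jnp
    from jax.tree_util import tree_leaves, tree_map


    def leaves_vec(tree_x):
      # Flatten an arbitrary pytree into a single 1D vector.
      return jnp.concatenate(tree_leaves(tree_map(jnp.ravel, tree_x)))


    def keep_iterating(delta, params, gradient, atol=0., rtol=1e-3, gtol=1e-3):
      # True while none of the stopping criteria are met.
      atol_cond = jnp.all(jnp.abs(leaves_vec(delta)) <= atol)
      rtol_cond = jnp.all(jnp.abs(leaves_vec(delta)) <= rtol * jnp.abs(leaves_vec(params)))
      grad_cond = jnp.max(jnp.abs(leaves_vec(gradient))) <= gtol
      done = atol_cond | rtol_cond | grad_cond
      return ~done


    # Toy example: the update is tiny relative to the parameters, so iteration stops.
    params = {'w': jnp.array([1.0, 2.0]), 'b': jnp.array(0.5)}
    delta = {'w': jnp.array([1e-5, 2e-5]), 'b': jnp.array(1e-6)}
    gradient = {'w': jnp.array([1e-2, 3e-2]), 'b': jnp.array(5e-3)}
    print(keep_iterating(delta, params, gradient))  # prints False -> stop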