The documentation says something different from the code. Specifically, the `-` in the code, `(tree_l2_norm(params) - self.xtol)`, is inconsistent with the `+` in the docstring.
Docstring says:
the convergence is achieved once the
coeff update satisfies ``||dcoeffs||_2 <= xtol * (||coeffs||_2 + xtol) `` or
the gradient satisfies ``||grad(f)||_inf <= gtol``.
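For reference, the criterion as the docstring states it could be sketched roughly as follows. This is only an illustration, not the library's actual code: `tree_l2_norm` is redefined locally, and `tree_inf_norm` and `converged` are hypothetical names introduced here.

```python
import jax.numpy as jnp
from jax.tree_util import tree_leaves


def tree_l2_norm(tree):
    # L2 norm over all leaves of a pytree (local stand-in, defined here for illustration).
    return jnp.sqrt(sum(jnp.sum(jnp.square(leaf)) for leaf in tree_leaves(tree)))


def tree_inf_norm(tree):
    # Largest absolute entry over all leaves (hypothetical helper).
    return jnp.max(jnp.stack([jnp.max(jnp.abs(leaf)) for leaf in tree_leaves(tree)]))


def converged(delta, params, grad, xtol, gtol):
    # Docstring version: note the `+ xtol` inside the parentheses.
    step_small = tree_l2_norm(delta) <= xtol * (tree_l2_norm(params) + xtol)
    grad_small = tree_inf_norm(grad) <= gtol
    return step_small | grad_small
```

With `+ xtol` the right-hand side stays positive even when `||coeffs||_2` is zero, whereas with `- xtol` it becomes negative for very small parameters, so the step test could then never pass.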
Additionally, rather than `all(array(...))` you should use `jnp.bitwise_and(..., ...)` or the `|`, `&` and `~` operators.
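As a small illustration of that point, the two spellings below give the same result on boolean scalars, and both trace under `jax.jit`. The function names are made up for this example.

```python
import jax
import jax.numpy as jnp


@jax.jit
def both_stacked(a, b):
    # Builds a temporary 2-element array just to reduce it again.
    return jnp.all(jnp.array([a, b]))


@jax.jit
def both_bitwise(a, b):
    # Elementwise AND on boolean values, no intermediate array needed.
    return a & b


@jax.jit
def neither(a, b):
    # `|` is OR and `~` is NOT for boolean arrays.
    return ~(a | b)
```

The bitwise form reads closer to the logic it expresses and avoids stacking the conditions first.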
My suggestion
From reading up on the Madsen-Nielsen stopping condition, it seems there is no single version of it. In my optimisation work I have found that incorporating both an absolute and a relative tolerance on the parameter changes is quite useful. (Currently it looks like only a relative tolerance is used.)
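import jax.numpy as jnp
from jax.tree_util import tree_leaves, tree_map
def leaves_vec(tree_x):
    # Flatten every leaf of the pytree and join them into one 1-D vector.
    return jnp.concatenate(tree_leaves(tree_map(jnp.ravel, tree_x)))
# Inside the stopping-condition method (uses self, state and params):
atol_cond = jnp.all(jnp.abs(leaves_vec(state.delta)) <= self.atol)
rtol_cond = jnp.all(jnp.abs(leaves_vec(state.delta)) <= self.rtol * jnp.abs(leaves_vec(params)))
grad_cond = jnp.max(jnp.abs(leaves_vec(state.gradient))) <= self.gtol
done = atol_cond | rtol_cond | grad_cond
return ~done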
# defaults
atol = 0.0  # effectively off unless the user enables it, for backward compatibility with the current behaviour
rtol = 1e-3
gtol = 1e-3
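To make the suggestion easy to try outside the solver class, here is a standalone sketch of the same condition. It assumes the step, parameters and gradient are pytrees with matching structure. `should_continue` is a hypothetical free function, not an existing API, and it takes the tolerances as arguments instead of reading them from `self`.

```python
import jax.numpy as jnp
from jax.tree_util import tree_leaves, tree_map


def leaves_vec(tree):
    # Flatten every leaf and concatenate into a single 1-D vector.
    return jnp.concatenate(tree_leaves(tree_map(jnp.ravel, tree)))


def should_continue(delta, params, gradient, atol=0.0, rtol=1e-3, gtol=1e-3):
    # Hypothetical standalone version of the proposed stopping test.
    step = jnp.abs(leaves_vec(delta))
    atol_cond = jnp.all(step <= atol)  # absolute tolerance on each step entry
    rtol_cond = jnp.all(step <= rtol * jnp.abs(leaves_vec(params)))  # relative tolerance
    grad_cond = jnp.max(jnp.abs(leaves_vec(gradient))) <= gtol  # infinity norm of the gradient
    done = atol_cond | rtol_cond | grad_cond
    return ~done


params = {"a": jnp.array([1.0, 2.0]), "b": jnp.array([[3.0]])}
large_grad = tree_map(jnp.ones_like, params)
small_step = tree_map(lambda p: 1e-5 * jnp.ones_like(p), params)
big_step = tree_map(lambda p: 0.5 * jnp.ones_like(p), params)
keep_going_small = should_continue(small_step, params, large_grad)
keep_going_big = should_continue(big_step, params, large_grad)
```

With the default `atol = 0.0`, the absolute test only passes for an exactly zero step, so the behaviour is driven by `rtol` and `gtol` unless the user sets `atol`.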