patrick-kidger / optimistix

Nonlinear optimisation (root-finding, least squares, ...) in JAX+Equinox. https://docs.kidger.site/optimistix/
Apache License 2.0

First step of `GradientDescent` optimizer is a no-op #82

Open eringrant opened 2 months ago

eringrant commented 2 months ago

It seems like the first call to `step` of the `GradientDescent` optimizer doesn't perform the step operation. I didn't check whether this occurs for other optimizers or do any other digging, but can do so if this is not expected behavior and the cause is not immediate. Here is an MWE:

```python
import equinox as eqx
import jax
import jax.numpy as jnp
import optimistix as optx
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--use_optax", action="store_true")
args = parser.parse_args()

if args.use_optax:
  import optax
  optimizer = optx.OptaxMinimiser(optax.sgd(1e-1), rtol=1e-4, atol=1e-4)
else:
  optimizer = optx.GradientDescent(learning_rate=1e-1, rtol=1e-4, atol=1e-4)

N = K = 8

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
w_star = jax.random.normal(k1, (K, N))
w_hat = jax.random.normal(k2, (K, N))

x = jnp.linspace(0, 1, N)[None, ...]
y = jnp.dot(w_star, x.T)

def loss(w, _):
  return jnp.mean((jnp.dot(w, x.T) - y) ** 2), None

options = None
f_struct = jax.ShapeDtypeStruct((), jnp.float32)
aux_struct = None
tags = frozenset()

init = eqx.Partial(
  optimizer.init,
  args=None,
  fn=loss,
  options=options,
  f_struct=f_struct,
  aux_struct=aux_struct,
  tags=tags,
)
step = eqx.Partial(
  optimizer.step,
  args=None,
  fn=loss,
  options=options,
  tags=tags,
)

state = init(y=w_hat)
initial_loss = loss(w_hat, None)[0]
print(f"t = 0 | loss = {initial_loss}.")

w_hat, state, _ = step(y=w_hat, state=state)
one_step_loss = loss(w_hat, None)[0]
print(f"t = 1 | loss = {one_step_loss}.")

w_hat, state, _ = step(y=w_hat, state=state)
two_step_loss = loss(w_hat, None)[0]
print(f"t = 2 | loss = {two_step_loss}.")

if initial_loss == one_step_loss:
  raise ValueError("Loss did not decrease after one step of optimization.")
```

Running with `GradientDescent` gives:

```
$ python test.py
t = 0 | loss = 2.189293384552002.
t = 1 | loss = 2.189293384552002.
t = 2 | loss = 1.8877067565917969.
Traceback (most recent call last):
  File ".../test.py", line 68, in <module>
    raise ValueError("Loss did not decrease after one step of optimization.")
ValueError: Loss did not decrease after one step of optimization.
```

Compare with `OptaxMinimiser(optax.sgd(...), ...)`:

```
$ python test.py --use_optax
t = 0 | loss = 2.189293384552002.
t = 1 | loss = 1.8877067565917969.
t = 2 | loss = 1.6276657581329346.
```

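In terms of iterates: write $y_0$ for the initial `w_hat`, $f$ for `loss`, $\eta = 0.1$ for the learning rate, and $\hat y_k$ for the value returned after $k$ calls to `step`. Plain gradient descent (which is what `optax.sgd` applies) produces the sequence below. The two runs above are consistent with `GradientDescent` returning the previous iterate while `OptaxMinimiser` returns the current one; for example the `GradientDescent` loss at $t = 2$ equals the `OptaxMinimiser` loss at $t = 1$ (1.8877067565917969). This is only read off two printed steps, so treat it as an observation rather than a proof:

$$
y_{k+1} = y_k - \eta \, \nabla f(y_k), \qquad
\hat y_k^{\text{GradientDescent}} = y_{k-1}, \qquad
\hat y_k^{\text{OptaxMinimiser}} = y_k \qquad (k \ge 1).
$$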
patrick-kidger commented 2 months ago

This is expected... but admittedly maybe not great design.

The relevant code is here:

https://github.com/patrick-kidger/optimistix/blob/58348db56dd92e099eeb070f76c597424ebee34f/optimistix/_solver/gradient_methods.py#L156-L214

The way this works is that we actually treat general gradient methods, which typically start by picking a descent direction and then perform a line search in that direction. Once the line search has found an acceptable point at which to stop, that point is used as the start of a new line search.

In the case of `GradientDescent`, the line search is a single step whose size corresponds to the learning rate, and the result is always treated as acceptable. This means that the 'accepted' point returned by each step is the start of the current line search, which is the previous iterate.
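To make the off-by-one concrete, here is a minimal pure-Python sketch of the behaviour described above, on the toy problem $f(y) = y^2$. It is a simplified, hypothetical model and not optimistix's actual code: the names `LineSearchState`, `y_eval`, `init`, `lagged_step` and `eager_step` are illustrative. `lagged_step` accepts the pending trial point and returns it, while the newly proposed point waits in the state; `eager_step` is the usual SGD-style update that returns the new point directly.

```python
from dataclasses import dataclass
from typing import Callable, Tuple

# Hypothetical, simplified model of "pick a descent direction, then line search".
# None of these names come from optimistix; they only illustrate the lag.


@dataclass(frozen=True)
class LineSearchState:
    y_eval: float  # hypothetical: trial point currently proposed by the line search


def init(y0: float) -> LineSearchState:
    # Before any step, the only trial point is the initial guess itself.
    return LineSearchState(y_eval=y0)


def lagged_step(
    grad: Callable[[float], float], state: LineSearchState, lr: float
) -> Tuple[float, LineSearchState]:
    # Plain gradient descent always accepts the pending trial point...
    y_accepted = state.y_eval
    # ...and starts the next "line search" from it: a single step of size lr.
    y_trial = y_accepted - lr * grad(y_accepted)
    # The returned iterate is the accepted point (the start of the new line
    # search), so it lags the newest trial point by one iteration.
    return y_accepted, LineSearchState(y_eval=y_trial)


def eager_step(grad: Callable[[float], float], y: float, lr: float) -> float:
    # Standard SGD-style update: return the new point immediately.
    return y - lr * grad(y)


if __name__ == "__main__":
    grad = lambda y: 2.0 * y  # gradient of f(y) = y**2
    lr, y0 = 0.1, 1.0

    state = init(y0)
    y_eager = y0
    for t in range(1, 4):
        y_lagged, state = lagged_step(grad, state, lr)
        y_eager = eager_step(grad, y_eager, lr)
        print(f"t = {t} | lagged = {y_lagged} | eager = {y_eager}")
```

As in the issue, the first lagged step returns the initial point unchanged, and from then on the lagged sequence trails the eager one by exactly one step.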

Off the top of my head I'm not sure how we'd change this. We might be able to tweak the logic in the above block of code to remove this off-by-one behaviour. (I'm open to suggestions on this one.)