google / evojax


Bug in center_lr_decay_steps when using Adam with PGPE #25

Open · garam-kim1 opened this issue 2 years ago

garam-kim1 commented 2 years ago

Bug

When using Adam with PGPE, this code

self._opt_state = self._opt_update(
            self._t // self._lr_decay_steps, -grad_center, self._opt_state
        )

means that Adam's step counter t only increments once every self._lr_decay_steps updates. As a result, mhat and vhat no longer behave as bias-corrected moving averages, because (1 - jnp.asarray(b1, m.dtype) ** (i + 1)) stays very small for far longer than intended. (Below is the Adam update code.)

def update(i, g, state):
    x, m, v = state
    m = (1 - b1) * g + b1 * m  # First  moment estimate.
    v = (1 - b2) * jnp.square(g) + b2 * v  # Second moment estimate.
    mhat = m / (1 - jnp.asarray(b1, m.dtype) ** (i + 1))  # Bias correction.
    vhat = v / (1 - jnp.asarray(b2, m.dtype) ** (i + 1))
    x = x - step_size(i) * mhat / (jnp.sqrt(vhat) + eps)
    return x, m, v
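
To make the effect concrete, here is a small standalone sketch (not evojax code) of how the bias-correction denominator gets stuck when the step counter is divided by self._lr_decay_steps. The b1 and lr_decay_steps values are illustrative only.

# Minimal sketch: the step index passed to update() stalls at the same value
# for lr_decay_steps consecutive updates, so the bias correction never fades.
b1 = 0.9
lr_decay_steps = 1000
for t in (0, 1, 500, 999, 1000, 5000):
    i = t // lr_decay_steps            # what PGPE currently passes as the step
    correction = 1 - b1 ** (i + 1)     # bias-correction denominator for mhat
    print(t, i, round(correction, 3))
# t = 0..999 all yield correction = 0.1, so mhat = m / 0.1 keeps amplifying m
# by 10x long after the warm-up correction should have faded out.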

Suggestion

I think it would be better to change this code to

step_size=lambda x: self._center_lr * jnp.power(decay_coef, x // self._lr_decay_steps),

and to remove the division by self._lr_decay_steps in

self._opt_state = self._opt_update(
            self._t, -grad_center, self._opt_state
        )
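
Put together, the two changes would look roughly like the following standalone sketch (names and values are illustrative, not evojax's, and it assumes the optimizer is built with jax.example_libraries.optimizers.adam, which is where the update code quoted above comes from):

from jax.example_libraries import optimizers
import jax.numpy as jnp

# Sketch of the proposed fix: the step-wise decay moves into the step-size
# schedule, and the raw step counter goes to the optimizer so the bias
# correction behaves as usual.
center_lr = 0.15
decay_coef = 0.9
lr_decay_steps = 1000

opt_init, opt_update, get_params = optimizers.adam(
    step_size=lambda t: center_lr * jnp.power(decay_coef, t // lr_decay_steps)
)

center = jnp.zeros(8)
opt_state = opt_init(center)
for t in range(3):
    grad_center = jnp.ones(8)                           # placeholder gradient
    opt_state = opt_update(t, -grad_center, opt_state)  # raw t, no division
center = get_params(opt_state)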
dietmarwo commented 1 year ago

Unfortunately we have the default

decay_coef = optimizer_config.get("center_lr_decay_coef", 1.0)

which effectively deactivates decay_coef, and none of the PGPE configs changes this default.
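
Just for illustration, a config would have to override that key explicitly before the decay path does anything (the value below is made up):

# Illustrative only: any value below 1.0 would actually exercise the decay.
optimizer_config = {"center_lr_decay_coef": 0.99}
decay_coef = optimizer_config.get("center_lr_decay_coef", 1.0)  # -> 0.99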

So there is no easy way to check whether the proposed change is a regression. But when I created the C++ PGPE implementation wrapped in fpgpec.py, I implemented the C++ code exactly as you propose now; see pgpe.cpp.

Do you have a benchmark problem actually using decay_coef?