google-deepmind / optax

Optax is a gradient processing and optimization library for JAX.
https://optax.readthedocs.io
Apache License 2.0
1.63k stars 174 forks source link

incompatibility between zero_nans() and MultiSteps #828

Closed TheMr33 closed 6 months ago

TheMr33 commented 6 months ago

Hello,

There seems to be an incompatibility when using optax.zero_nans() with optax.MultiSteps.

I reproduced the issue starting from the `gradient_accumulation` example notebook:

import functools
from typing import Callable, Iterable, Tuple, TypedDict

import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import optax
import chex

# One mini-batch of training data, passed to `loss_fn` and the update step.
MiniBatch = TypedDict(
    'MiniBatch',
    {
        'image': jnp.ndarray,  # input images; leading axis is the batch axis
        'label': jnp.ndarray,  # integer class label for each image
    },
)

# Signature of one optimization step:
#   (params, opt_state, batch) -> (new_params, new_opt_state)
UpdateFn = Callable[
    [hk.Params, optax.OptState, MiniBatch],
    Tuple[hk.Params, optax.OptState],
]

@hk.transform
def net(image: jnp.ndarray) -> jnp.ndarray:
  """Haiku MLP classifier mapping each example to 10 logits.

  Args:
    image: a batch of inputs whose leading axis is the batch dimension.

  Returns:
    Logits of shape (batch, 10).
  """
  batch_size = image.shape[0]
  # Collapse every non-batch axis so the MLP sees one flat vector per example.
  flat_features = image.reshape((batch_size, -1))
  mlp = hk.nets.MLP([32, 32, 10])
  return mlp(flat_features)

def loss_fn(params: hk.Params, batch: MiniBatch) -> jnp.ndarray:
  """Mean softmax cross-entropy of `net` over one mini-batch.

  Args:
    params: network parameters produced by `net.init`.
    batch: mini-batch with 'image' inputs and integer 'label' targets.

  Returns:
    Scalar mean loss over the batch.
  """
  images, labels = batch['image'], batch['label']
  # A fixed PRNG key is supplied to `net.apply` on every call.
  logits = net.apply(params, jax.random.PRNGKey(0), images)
  per_example_loss = optax.softmax_cross_entropy_with_integer_labels(
      logits, labels)
  return jnp.mean(per_example_loss)

def build_update_fn(optimizer: optax.GradientTransformation) -> UpdateFn:
  """Builds a jitted function that executes a single optimization step.

  Args:
    optimizer: the gradient transformation used to turn gradients into
      parameter updates (it may be wrapped, e.g. in `optax.MultiSteps`).

  Returns:
    A function `update(params, opt_state, batch) -> (params, opt_state)`.
  """

  @jax.jit
  def update(params, opt_state, batch):
    # Only the gradients are needed, so the loss value is not computed/kept.
    grads = jax.grad(loss_fn)(params, batch)
    # Forward `params` (an optional argument of optax's `update`) so that
    # optimizers whose update rule depends on the current parameters also work.
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state

  return update

def fit(
    optimizer: optax.GradientTransformation,
    params: hk.Params,
    batches: Iterable[MiniBatch],
) -> hk.Params:
  """Runs one optimizer step per mini-batch and returns the trained params.

  Args:
    optimizer: gradient transformation used for every step.
    params: initial parameters; the optimizer state is initialized from them.
    batches: mini-batches consumed in order, one update step each.

  Returns:
    The parameters after the last batch has been processed.
  """
  step = build_update_fn(optimizer)
  state = optimizer.init(params)

  current_params = params
  for minibatch in batches:
    current_params, state = step(current_params, state, minibatch)

  return current_params

# Use independent PRNG keys for the inputs, the labels and the parameter
# initialization: reusing one key for several random draws correlates them.
_data_key, _label_key, _init_key = jax.random.split(jax.random.PRNGKey(0), 3)

EXAMPLES = jax.random.uniform(_data_key, (9, 28, 28, 1))
LABELS = jax.random.randint(_label_key, (9,), minval=0, maxval=10)

optimizer = optax.sgd(1e-4)  # <=== baseline optimizer
# <=== optimizer reported as failing when wrapped in optax.MultiSteps (#828)
optimizer_nan = optax.chain(optax.zero_nans(), optimizer)

params = net.init(_init_key, EXAMPLES)

# ------------------------------------------------------------------
# ------------------------------------------------------------------
# Without MultiSteps, optimizer and optimizer_nan are consistent.

# The same single full batch is fed to both optimizers.
single_batch = [MiniBatch(image=EXAMPLES, label=LABELS)]

new_params_single_batch = fit(optimizer, params, batches=single_batch)
new_params_single_batch_nan = fit(optimizer_nan, params, batches=single_batch)

chex.assert_trees_all_close(
    new_params_single_batch,
    new_params_single_batch_nan,
    atol=1e-7,
)

# ------------------------------------------------------------------
# ------------------------------------------------------------------
# optax.MultiSteps must give the same result with optimizer and optimizer_nan.
# Before #840 the optimizer_nan run crashed: MultiSteps turned the boolean
# state of zero_nans() into integers, so jax.lax.cond saw mismatched dtypes.

# Accumulate gradients over this many equally sized mini-batches, so that one
# accumulated update covers the whole dataset.
num_minibatches = 3
minibatch_size = len(EXAMPLES) // num_minibatches
accumulation_batches = [
    MiniBatch(
        image=EXAMPLES[i * minibatch_size:(i + 1) * minibatch_size],
        label=LABELS[i * minibatch_size:(i + 1) * minibatch_size],
    )
    for i in range(num_minibatches)
]

new_params_gradient_accumulation = fit(
    optax.MultiSteps(optimizer, every_k_schedule=num_minibatches),
    params,
    batches=accumulation_batches,
)

new_params_gradient_accumulation_nan = fit(
    optax.MultiSteps(optimizer_nan, every_k_schedule=num_minibatches),
    params,
    batches=accumulation_batches,
)

chex.assert_trees_all_close(
    new_params_gradient_accumulation,
    new_params_gradient_accumulation_nan,
    atol=1e-7,
)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[9], line 1
----> 1 new_params_gradient_accumulation_nan = fit(
      2     optax.MultiSteps(optimizer_nan, every_k_schedule=3),
      3     params,
      4     batches=[
      5         MiniBatch(image=EXAMPLES[0:3], label=LABELS[0:3]),
      6         MiniBatch(image=EXAMPLES[3:6], label=LABELS[3:6]),
      7         MiniBatch(image=EXAMPLES[6:9], label=LABELS[6:9]),
      8     ],
      9 )

Cell In[3], line 25, in fit(optimizer, params, batches)
     22 opt_state = optimizer.init(params)
     24 for batch in batches:
---> 25   params, opt_state = update_fn(params, opt_state, batch)
     27 return params

    [... skipping hidden 12 frame]

Cell In[3], line 7, in build_update_fn.<locals>.update(params, opt_state, batch)
      4 @jax.jit
      5 def update(params, opt_state, batch):
      6   loss, grads = jax.value_and_grad(loss_fn)(params, batch)
----> 7   updates, opt_state = optimizer.update(grads, opt_state)
      8   params = optax.apply_updates(params, updates)
      9   return params, opt_state

File /media/thierry/Develop/optax/optax/_src/wrappers.py:435, in MultiSteps.update(self, updates, state, params, **extra_args)
    432   zero_updates = _zeros_tree_like(state.acc_grads)
    433   return zero_updates, multi_state_when_skip
--> 435 new_updates, new_state = jax.lax.cond(
    436     should_skip_update, _skip_update, _do_update, *(updates, state, params)
    437 )
    438 return new_updates, new_state

    [... skipping hidden 3 frame]

File /media/thierry/Develop/optax/venv/lib/python3.10/site-packages/jax/_src/lax/control_flow/common.py:215, in _check_tree_and_avals(what, tree1, avals1, tree2, avals2)
    212 if not all(map(core.typematch, avals1, avals2)):
    213   diff = tree_map(_show_diff, tree_unflatten(tree1, avals1),
    214                   tree_unflatten(tree2, avals2))
--> 215   raise TypeError(f"{what} must have identical types, got\n{diff}.")

TypeError: true_fun and false_fun output must have identical types, got
({'mlp/~/linear_0': {'b': 'ShapedArray(float32[32])', 'w': 'ShapedArray(float32[784,32])'}, 'mlp/~/linear_1': {'b': 'ShapedArray(float32[32])', 'w': 'ShapedArray(float32[32,32])'}, 'mlp/~/linear_2': {'b': 'ShapedArray(float32[10])', 'w': 'ShapedArray(float32[32,10])'}}, MultiStepsState(mini_step='ShapedArray(int32[])', gradient_step='ShapedArray(int32[])', inner_opt_state=(ZeroNansState(found_nan={'mlp/~/linear_0': {'b': 'DIFFERENT ShapedArray(bool[]) vs. ShapedArray(int32[], weak_type=True)', 'w': 'DIFFERENT ShapedArray(bool[]) vs. ShapedArray(int32[], weak_type=True)'}, 'mlp/~/linear_1': {'b': 'DIFFERENT ShapedArray(bool[]) vs. ShapedArray(int32[], weak_type=True)', 'w': 'DIFFERENT ShapedArray(bool[]) vs. ShapedArray(int32[], weak_type=True)'}, 'mlp/~/linear_2': {'b': 'DIFFERENT ShapedArray(bool[]) vs. ShapedArray(int32[], weak_type=True)', 'w': 'DIFFERENT ShapedArray(bool[]) vs. ShapedArray(int32[], weak_type=True)'}}), (EmptyState(), EmptyState())), acc_grads={'mlp/~/linear_0': {'b': 'ShapedArray(float32[32])', 'w': 'ShapedArray(float32[784,32])'}, 'mlp/~/linear_1': {'b': 'ShapedArray(float32[32])', 'w': 'ShapedArray(float32[32,32])'}, 'mlp/~/linear_2': {'b': 'ShapedArray(float32[10])', 'w': 'ShapedArray(float32[32,10])'}}, skip_state=())).
fabianp commented 6 months ago

I think the issue is that `MultiSteps` blends the old and new inner optimizer state arithmetically (https://github.com/google-deepmind/optax/blob/2e92d570c784de4b357dce83180021e658a9210f/optax/_src/wrappers.py#L407). The state of `optax.zero_nans()` is a boolean mask, and that arithmetic converts it to integers. As a result, the two branches of `jax.lax.cond` no longer return identical types.

fabianp commented 6 months ago

One potential solution would be to replace `lambda st, nst: (1 - emit) * st + emit * nst` with `lambda st, nst: ((1 - emit) * st + emit * nst).astype(st.dtype)`, to make sure the dtype doesn't change. However, I wonder whether that would have other unintended consequences.

In particular, I think this solution will fail if the leaves are plain Python scalars instead of arrays, since Python scalars have no `.astype` method.

fabianp commented 6 months ago

Ah, since emit is a boolean, I think there's a better solution.

@TheMr33 can you confirm that #840 solves your issue?

TheMr33 commented 6 months ago

Hi @fabianp, thank you for the answer.

Your first answer (with the lambda) and your latest commit (6de95bf) both seem to solve my problem.

I haven't checked the consistency of the gradient, but the code no longer crashes.

fabianp commented 6 months ago

excellent, thanks! Closing this one :-)