google-deepmind / optax

Optax is a gradient processing and optimization library for JAX.
https://optax.readthedocs.io
Apache License 2.0
1.56k stars, 165 forks

Passing arguments to train multiple models in parallel #932

Open kclauw opened 2 months ago

kclauw commented 2 months ago

Hi,

I want to run a grid search over different arguments, training multiple models in parallel with optax and flax. My first idea was to vmap an initialization function over an array of learning rates, but this raises an error about a side effect inside a JAX transformation.

What is the best way to pass a list of arguments, and can this be solved? The issue seems to be related to the adamw optimizer, which I believe modifies the learning rate parameter.

I have attached a reduced example of my code:


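import hydra
import jax
import jax.numpy as jnp
import optax
from flax.training import train_state
# FCNN_2 (a Flax module) and get_dataloaders are my own helpers and are not shown here.

def calculate_loss_acc(state, params, batch):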
    data_input, labels = batch
    logits = state.apply_fn(params, data_input)
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
    acc = jnp.mean(jnp.argmax(logits, -1) == labels)
    return loss, acc

@jax.jit  # Jit the function for efficiency
def train_step(state, batch):
    # Gradient function

    grad_fn = jax.value_and_grad(calculate_loss_acc,  # Function to calculate the loss
                                 argnums=1,  # Parameters are second argument of the function
                                 has_aux=True  # Function has additional outputs, here accuracy
                                )
    # Determine gradients for current model, parameters and batch
    (loss, acc), grads = grad_fn(state, state.params, batch)

    # Perform parameter update with gradients and optimizer
    state = state.apply_gradients(grads=grads)
    # Return state and any other value we might want
    return state, loss, acc

def initialization(model, learning_rate, input_size, seed):
    rng = jax.random.PRNGKey(seed)

    rng, init_rng = jax.random.split(rng)
    dummy_input = jax.random.normal(init_rng, (8, input_size))  # Dummy batch: 8 examples, input_size features
    params = model.init(init_rng, dummy_input)
    model.apply(params, dummy_input)  # Sanity-check forward pass; the output is unused
    optimizer = optax.adamw(learning_rate=learning_rate)
    model_state = train_state.TrainState.create(apply_fn=model.apply,
                                                params=params,
                                                tx=optimizer)
    return model_state

@hydra.main(version_base=None, config_name="main", config_path="config")
def main(cfg) -> None:
    seed = 0
    num_epochs = 1
    input_size = 194
    output_size = 97
    learning_rates = jnp.array([0.01, 0.1])

    train_dataloader, test_dataloader = get_dataloaders(cfg)
    model = FCNN_2(num_hidden=1000,
                   num_outputs=output_size,
                   activation=cfg.model.parameters.activation)

    parallel_init_fn = jax.vmap(initialization, in_axes=(None, 0, None, None))
    parallel_train_step_fn = jax.vmap(train_step, in_axes=(0, None))

    # Despite its name, params holds a TrainState batched over learning_rates
    params = parallel_init_fn(model, learning_rates, input_size, seed)

    for epoch in range(num_epochs):
        # Run training for one epoch
        for batch in train_dataloader:
            params, loss, acc = parallel_train_step_fn(params, batch)
            print(loss)

vroulet commented 2 months ago

Hello @kclauw,

  1. What error do you get exactly?
  2. Why are you saying that the issue is with adamw? Adamw does not modify the learning rate internally. Have you tried with sgd and did that produce the same error?

Thanks for reaching out

kclauw commented 2 months ago

Hi,

Thanks!

1) I am still learning JAX, coming from PyTorch, but my understanding of the error is that something changes the value of the learning rate parameter in the initialization function:


params, loss, acc = parallel_train_step_fn(params, batch)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jax.errors.UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type int32[] wrapped in a BatchTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.Detail: Different traces at same level: Traced<ShapedArray(int32[], weak_type=True)>with<BatchTrace(level=1/0)> with
  val = Array([0, 0], dtype=int32, weak_type=True)
  batch_dim = 0, BatchTrace(level=1/0)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.

2) The code works with a fixed learning rate. However, when I pass the learning rate through vmap, it raises the error above. Switching to SGD did not resolve the issue. Based on this, I figured the optax optimizer might be changing the learning rate.

3) I also passed a list of seeds as an argument to initialization; the seed is not used by the optimizer. This works fine, so the issue seems to happen only when passing the learning rate to any optax optimizer.

I looked at the code of adamw:


def adamw(
    learning_rate: base.ScalarOrSchedule,
    b1: float = 0.9,
    b2: float = 0.999,
    eps: float = 1e-8,
    eps_root: float = 0.0,
    mu_dtype: Optional[Any] = None,
    weight_decay: float = 1e-4,
    mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None,
    *,
    nesterov: bool = False,
) -> base.GradientTransformation:
  return combine.chain(
      transform.scale_by_adam(
          b1=b1,
          b2=b2,
          eps=eps,
          eps_root=eps_root,
          mu_dtype=mu_dtype,
          nesterov=nesterov,
      ),
      transform.add_decayed_weights(weight_decay, mask),
      transform.scale_by_learning_rate(learning_rate),
  )

def scale_by_learning_rate(
    learning_rate: base.ScalarOrSchedule,
    *,
    flip_sign: bool = True,
) -> base.GradientTransformation:
  m = -1 if flip_sign else 1
  if callable(learning_rate):
    return scale_by_schedule(lambda count: m * learning_rate(count))
  return scale(m * learning_rate)

The problem is due to adamw (and SGD, etc.) changing the learning rate via transform.scale_by_learning_rate(learning_rate) (see scale(m * learning_rate)).

What would be the best way to pass arguments that vary across the vmapped models, if this is even possible? I expect this to also become a problem when passing weight decay values.

Ekundayo39283 commented 2 months ago

When dealing with parameters that vary under vmap, like learning rates or weight decay values, you can use partial function application or closures. This lets you fix certain arguments while leaving others free. For instance, you can write a function whose signature contains only the arguments that stay constant, and then partially apply it with the varying values inside the vmapped function. This way only the arguments that actually vary are passed through vmap, which avoids unexpected tracer errors.

vroulet commented 2 months ago

Hello @kclauw,

Sorry for the delayed answer.

  1. It would help if you could reduce the example to a minimal, self-contained one that reproduces the same error (some dependencies are not defined in what you sent). You may also try tracing the leak as the error message suggests, just to be sure. It's not yet clear to me that the learning rate is really the culprit here.
  2. If this is truly the learning rate, one quick workaround would be to use optax.inject_hyperparams. You would instantiate the optimizer as opt = optax.inject_hyperparams(optax.adamw)(learning_rate=1.) outside the vmap, and inside the vmap you would initialize the optimizer state with state = opt.init(params). In the resulting state, you can then set the learning rate of your choice with state = optax.tree_utils.tree_set(state, learning_rate=your_learning_rate). The optimizer would then run with the learning rate you chose in the vmap. Happy to try it out to be sure, but I'd need a minimal example for that.