keras-team / keras

Deep Learning for humans
http://keras.io/

Docs for Mixed precision training and the use of `LossScaleOptimizer` in custom loops with jax backend #19244

Open · VachanVY opened this issue 4 months ago

VachanVY commented 4 months ago

Hi, should we scale the loss and gradients as described in the tf.keras docs when using the JAX backend, or does `LossScaleOptimizer` take care of it in float16 training? I tried just wrapping `AdamW` in it, thinking it would handle the scaling, but training stops improving after some steps (whereas the same model trained in float32 trains properly).

I checked the Mixed precision policy API, the examples, and `model.fit` to see how it handles this, but I couldn't work it out, and I looked for documentation on `LossScaleOptimizer` for mixed-precision training with JAX but did not find any. Please include the details in the docs. Thanks.
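For reference, the setup I mean is roughly this (a sketch; the policy call, the inner optimizer, and the learning rate are just illustrative placeholders):

```python
import keras

# Run compute in float16 while keeping variables in float32.
keras.mixed_precision.set_global_policy("mixed_float16")

# Wrap the inner optimizer so the loss can be scaled and the gradients
# unscaled, avoiding float16 underflow in the backward pass.
optimizer = keras.optimizers.LossScaleOptimizer(
    keras.optimizers.AdamW(learning_rate=3e-4)  # learning rate is a placeholder
)
```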

SuryanarayanaY commented 4 months ago

Hi @VachanVY ,

Are you using a custom training loop or plain `model.fit`? AFAIK, without a custom training loop, `model.fit` should take care of it. With a custom training loop we need some extra steps, as mentioned for the TensorFlow backend. It seems the documentation for this is not yet available. I will cross-check and update.
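For context, the extra steps described in the tf.keras mixed precision guide boil down to scaling the loss before computing gradients and unscaling the gradients before applying them. A minimal sketch with the `tf.keras` API (`model`, `loss_fn`, and an `optimizer` wrapped in `tf.keras.mixed_precision.LossScaleOptimizer` are assumed to already exist):

```python
import tensorflow as tf

@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        y_pred = model(x, training=True)
        loss = loss_fn(y, y_pred)
        # Scale the loss so small float16 gradients do not underflow.
        scaled_loss = optimizer.get_scaled_loss(loss)
    scaled_grads = tape.gradient(scaled_loss, model.trainable_variables)
    # Undo the scaling before the update.
    grads = optimizer.get_unscaled_gradients(scaled_grads)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss
```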

Thanks!

VachanVY commented 4 months ago

@SuryanarayanaY I'm using a fully custom training loop. Could you please briefly explain how to use it there?

Edit: from the JAX trainer:
https://github.com/keras-team/keras/blob/e6e62405fa1b4444102601636d871610d91e5783/keras/backend/jax/trainer.py#L68
https://github.com/keras-team/keras/blob/e6e62405fa1b4444102601636d871610d91e5783/keras/backend/jax/trainer.py#L106

SuryanarayanaY commented 3 months ago

Hi @VachanVY ,

Could you please share a code snippet showing how you wrote your custom training loop? Maybe it will help us. Thanks!

VachanVY commented 3 months ago

From my repo:

```python
from functools import partial
from typing import Sequence

import jax
from jax import Array

# `model`, `optimizer`, `loss_fn` and `get_accuracy` are defined elsewhere in the repo;
# `optimizer` is a keras.optimizers.LossScaleOptimizer wrapping AdamW.

@partial(jax.jit, static_argnums=-1)
def compute_loss(trainable_variables: list, non_trainable_variables: list, X_batch: Array, y_batch: Array, num_grad_accumulation_steps: int):
    # Stateless forward pass: returns logits plus the updated non-trainable state.
    logits, non_trainable_variables = model.stateless_call(
        trainable_variables, non_trainable_variables, X_batch
    )
    loss = loss_fn(y_batch, logits)
    accuracy = get_accuracy(y_batch, logits)
    # Divide by the number of accumulation steps so summed gradients average out,
    # then scale the loss for float16 training.
    unscaled_loss = loss / num_grad_accumulation_steps
    scaled_loss = optimizer.scale_loss(unscaled_loss)
    return scaled_loss, (unscaled_loss, accuracy, non_trainable_variables)

grad_fn = jax.value_and_grad(compute_loss, has_aux=True)

@partial(jax.jit, static_argnums=-1)
def mini_step(train_state: Sequence[list], X_batch: Array, y_batch: Array, num_grad_accumulation_steps: int):
    trainable_variables, non_trainable_variables = train_state

    # Gradients are taken w.r.t. the scaled loss, so they come back scaled as well.
    (_, aux), scaled_grad = grad_fn(
        trainable_variables, non_trainable_variables, X_batch, y_batch,
        num_grad_accumulation_steps
    )
    unscaled_loss, accuracy, non_trainable_variables = aux
    return scaled_grad, (unscaled_loss, accuracy), (trainable_variables, non_trainable_variables)

@jax.jit
def update_params(grads: list, trainable_variables: list, optimizer_variables: list):
    # stateless_apply also handles scaled grads if LossScaleOptimizer is used.
    trainable_variables, optimizer_variables = optimizer.stateless_apply(
        optimizer_variables, grads, trainable_variables
    )  # returns the updated trainable_variables
    return trainable_variables, optimizer_variables
```
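The outer loop that ties these together looks roughly like this (a sketch rather than the exact code from my repo; `train_iter`, `num_train_steps`, and the accumulation count are placeholders):

```python
# Sketch of the outer loop (step counts and the data iterator are placeholders).
optimizer.build(model.trainable_variables)
trainable_variables = [v.value for v in model.trainable_variables]
non_trainable_variables = [v.value for v in model.non_trainable_variables]
optimizer_variables = [v.value for v in optimizer.variables]

num_grad_accumulation_steps = 4  # placeholder
train_state = (trainable_variables, non_trainable_variables)

for step in range(num_train_steps):  # num_train_steps is a placeholder
    accumulated_grads = None
    for _ in range(num_grad_accumulation_steps):
        X_batch, y_batch = next(train_iter)  # train_iter is a placeholder data iterator
        scaled_grad, (unscaled_loss, accuracy), train_state = mini_step(
            train_state, X_batch, y_batch, num_grad_accumulation_steps
        )
        # Sum the (still scaled) micro-batch gradients; the loss was already divided
        # by num_grad_accumulation_steps, so the sum behaves like a mean gradient.
        if accumulated_grads is None:
            accumulated_grads = scaled_grad
        else:
            accumulated_grads = jax.tree_util.tree_map(
                lambda a, b: a + b, accumulated_grads, scaled_grad
            )
    trainable_variables, non_trainable_variables = train_state
    # The scaled gradients are passed in as-is; stateless_apply is expected to
    # unscale them when the optimizer is a LossScaleOptimizer.
    trainable_variables, optimizer_variables = update_params(
        accumulated_grads, trainable_variables, optimizer_variables
    )
    train_state = (trainable_variables, non_trainable_variables)
```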