Open VachanVY opened 4 months ago
Hi, should we scale the loss and the gradients ourselves, like in the tf.keras mixed-precision docs, when using the jax backend, or does the `LossScaleOptimizer` take care of it in float16 training? I tried just wrapping AdamW in it, thinking it would take care of the scaling, but training stops improving after some steps (while the same model trained in float32 does train properly). I checked the Mixed precision policy API, the examples, and model.fit to see how they handle it, but didn't understand, and I looked for documentation on the `LossScaleOptimizer` for mixed-precision training in jax but did not find it. Please include the details in the docs. Thanks!
Hi @VachanVY,
Are you using a custom training loop or plain model.fit? AFAIK, without a custom training loop, model.fit should take care of it. With a custom training loop we need to do some extra steps, as mentioned for the TensorFlow backend. It seems the documentation for this is not yet available. I will cross-check again and update.
Thanks!
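For reference, the "extra steps" for a custom loop on the TensorFlow backend look roughly like the sketch below, following the tf.keras mixed precision guide; `model`, `loss_fn`, and the data are placeholders, not code from this issue.

```python
import tensorflow as tf

tf.keras.mixed_precision.set_global_policy("mixed_float16")  # compute in float16

# Wrap the inner optimizer so the loss scale is managed dynamically.
optimizer = tf.keras.mixed_precision.LossScaleOptimizer(tf.keras.optimizers.AdamW())

@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        loss = loss_fn(y, logits)
        # Scale the loss up so small float16 gradients don't underflow to zero.
        scaled_loss = optimizer.get_scaled_loss(loss)
    scaled_grads = tape.gradient(scaled_loss, model.trainable_variables)
    # Scale the gradients back down before applying them.
    grads = optimizer.get_unscaled_gradients(scaled_grads)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss
```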
@SuryanarayanaY I'm using a fully custom training loop. Could you briefly explain its usage?
Edit: relevant lines from the jax trainer:
https://github.com/keras-team/keras/blob/e6e62405fa1b4444102601636d871610d91e5783/keras/backend/jax/trainer.py#L68
https://github.com/keras-team/keras/blob/e6e62405fa1b4444102601636d871610d91e5783/keras/backend/jax/trainer.py#L106
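Paraphrasing those lines, the built-in JAX trainer appears to follow roughly this pattern (a sketch, not the exact source; `model`, `loss_fn`, and `optimizer` are assumed to be in scope):

```python
import jax

# Rough paraphrase of the linked trainer code, not the exact source.
def compute_loss_and_updates(trainable_variables, non_trainable_variables, x, y):
    y_pred, non_trainable_variables = model.stateless_call(
        trainable_variables, non_trainable_variables, x
    )
    loss = loss_fn(y, y_pred)
    # No-op for plain optimizers; multiplies by the current loss scale
    # when the optimizer is a LossScaleOptimizer.
    loss = optimizer.scale_loss(loss)
    return loss, non_trainable_variables

grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)

@jax.jit
def train_step(state, x, y):
    trainable_variables, non_trainable_variables, optimizer_variables = state
    (loss, non_trainable_variables), grads = grad_fn(
        trainable_variables, non_trainable_variables, x, y
    )
    # stateless_apply unscales the gradients again (and updates the dynamic
    # loss scale) when a LossScaleOptimizer wraps the inner optimizer.
    trainable_variables, optimizer_variables = optimizer.stateless_apply(
        optimizer_variables, grads, trainable_variables
    )
    return (trainable_variables, non_trainable_variables, optimizer_variables), loss
```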
Hi @VachanVY,
Could you please share a code snippet showing how you have done the custom training loop? Maybe it will help us. Thanks!
```python
from functools import partial
from typing import Sequence

import jax
from jax import Array

# `model`, `loss_fn`, `optimizer` (AdamW wrapped in LossScaleOptimizer)
# and `get_accuracy` are defined elsewhere.

@partial(jax.jit, static_argnums=4)  # num_grad_accumulation_steps is static
def compute_loss(trainable_variables: list, non_trainable_variables: list,
                 X_batch: Array, y_batch: Array, num_grad_accumulation_steps: int):
    logits, non_trainable_variables = model.stateless_call(
        trainable_variables, non_trainable_variables,
        X_batch
    )
    loss = loss_fn(y_batch, logits)
    accuracy = get_accuracy(y_batch, logits)
    unscaled_loss = loss / num_grad_accumulation_steps
    scaled_loss = optimizer.scale_loss(unscaled_loss)
    return scaled_loss, (unscaled_loss, accuracy, non_trainable_variables)

grad_fn = jax.value_and_grad(compute_loss, has_aux=True)

@partial(jax.jit, static_argnums=3)  # num_grad_accumulation_steps is static
def mini_step(train_state: Sequence[list], X_batch: Array, y_batch: Array,
              num_grad_accumulation_steps: int):
    trainable_variables, non_trainable_variables = train_state
    (_, aux), scaled_grad = grad_fn(
        trainable_variables, non_trainable_variables, X_batch, y_batch,
        num_grad_accumulation_steps
    )
    (unscaled_loss, accuracy, non_trainable_variables) = aux
    return scaled_grad, (unscaled_loss, accuracy), (trainable_variables, non_trainable_variables)

@jax.jit
def update_params(grads: list, trainable_variables: list, optimizer_variables: list):
    # also handles scaled grads if LossScaleOptimizer is used
    trainable_variables, optimizer_variables = optimizer.stateless_apply(
        optimizer_variables, grads, trainable_variables
    )  # returns updated trainable_variables
    return trainable_variables, optimizer_variables
```
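A hypothetical driver loop tying `mini_step` and `update_params` together could look like the following; `dataset`, the initial variable lists, `optimizer_variables`, and the accumulation count are assumptions, not part of the snippet above.

```python
NUM_ACCUM_STEPS = 4  # assumed accumulation count
train_state = (trainable_variables, non_trainable_variables)
accum_grads = None  # running sum of scaled gradients

for step, (X_batch, y_batch) in enumerate(dataset):
    scaled_grads, (loss, accuracy), train_state = mini_step(
        train_state, X_batch, y_batch, NUM_ACCUM_STEPS
    )
    # Summing grads of losses already divided by NUM_ACCUM_STEPS yields the
    # mean gradient (still multiplied by the loss scale).
    accum_grads = scaled_grads if accum_grads is None else jax.tree_util.tree_map(
        lambda a, g: a + g, accum_grads, scaled_grads
    )
    if (step + 1) % NUM_ACCUM_STEPS == 0:
        trainable_variables, optimizer_variables = update_params(
            accum_grads, train_state[0], optimizer_variables
        )
        train_state = (trainable_variables, train_state[1])
        accum_grads = None
```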