keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

How to implement train_step with multiple gradient calculations with a JAX backend? Ex. GAN #18881

Open craig-sony opened 10 months ago

craig-sony commented 10 months ago

I have been trying to write a GAN in Keras 3 with the JAX backend using the stateless_call API, but I cannot find a clean way to compute separate gradients for the discriminator and the generator. The closest I have gotten is to build a mapping between the model's trainable_variables/non_trainable_variables lists and the layers they belong to. Before calling stateless_call on a layer, I have to extract that layer's trainable_variables/non_trainable_variables from the lists passed into train_step, and afterwards I have to reinsert the non_trainable_variables that stateless_call returns. It's a mess.
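For context, the two-gradient pattern itself is straightforward in plain JAX; the difficulty is only in threading Keras state through it. Below is a minimal sketch, assuming toy linear "models" stored as parameter dicts. All names here (generator, discriminator, bce_with_logits, the shapes, the learning rate) are hypothetical and not part of the Keras API.

```python
import jax
import jax.numpy as jnp


def generator(gen_params, z):
    # Toy linear generator: latent vector -> sample.
    return z @ gen_params["w"] + gen_params["b"]


def discriminator(disc_params, x):
    # Toy linear discriminator: sample -> logit.
    return x @ disc_params["w"] + disc_params["b"]


def bce_with_logits(logits, labels):
    # Numerically stable binary cross-entropy on logits.
    return jnp.mean(
        jnp.maximum(logits, 0) - logits * labels + jnp.log1p(jnp.exp(-jnp.abs(logits)))
    )


def disc_loss(disc_params, gen_params, real, z):
    # Differentiated w.r.t. disc_params only (first argument).
    fake = generator(gen_params, z)
    logits = jnp.concatenate(
        [discriminator(disc_params, real), discriminator(disc_params, fake)]
    )
    labels = jnp.concatenate(
        [jnp.ones((real.shape[0], 1)), jnp.zeros((z.shape[0], 1))]
    )
    return bce_with_logits(logits, labels), logits


def gen_loss(gen_params, disc_params, z):
    # Differentiated w.r.t. gen_params only (first argument).
    fake = generator(gen_params, z)
    logits = discriminator(disc_params, fake)
    return bce_with_logits(logits, jnp.ones_like(logits)), fake


k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
gen_params = {"w": 0.1 * jax.random.normal(k1, (4, 3)), "b": jnp.zeros((3,))}
disc_params = {"w": 0.1 * jax.random.normal(k2, (3, 1)), "b": jnp.zeros((1,))}
real = jax.random.normal(k3, (8, 3))
z = jax.random.normal(k4, (8, 4))

# Two independent gradient computations, one per sub-model.
# This sketch evaluates both against the same parameters; a real GAN
# loop usually updates the discriminator before the generator step.
(d_loss, d_logits), d_grads = jax.value_and_grad(disc_loss, has_aux=True)(
    disc_params, gen_params, real, z
)
(g_loss, fake), g_grads = jax.value_and_grad(gen_loss, has_aux=True)(
    gen_params, disc_params, z
)

# Plain SGD update, applied separately to each parameter tree.
lr = 0.1
disc_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, disc_params, d_grads)
gen_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, gen_params, g_grads)
```

Each jax.value_and_grad call differentiates only with respect to its first argument, so the generator and discriminator gradients stay separate without any index bookkeeping.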

Can you please update the following example for Keras 3? https://keras.io/examples/generative/conditional_gan/

Thanks.

fchollet commented 10 months ago

I think you can use a StatelessScope and then just write a stateful train_step, which is 10x easier.

Probably something like

import jax
import keras

# Sketch of a train_step method on a keras.Model subclass that owns self.gen and self.disc.
def train_step(self, state, data):
    (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
        metrics_variables,
    ) = state
    # One gradient function per sub-model. has_aux=True because each loss
    # function also returns (y_pred, non_trainable_variables).
    grad_fn_gen = jax.value_and_grad(self.compute_loss_and_updates_gen, has_aux=True)
    grad_fn_disc = jax.value_and_grad(self.compute_loss_and_updates_disc, has_aux=True)
    # Bind every model variable to the value passed in through `state`.
    state_mapping = list(zip(self.trainable_variables, trainable_variables)) + list(
        zip(self.non_trainable_variables, non_trainable_variables)
    )
    with keras.StatelessScope(state_mapping) as scope:
        # gen_x, gen_y, disc_x, disc_y are placeholders; deriving them from `data` is omitted.
        (loss_gen, (y_pred_gen, non_trainable_variables_gen)), grads_gen = grad_fn_gen(
            self.gen.trainable_variables,
            self.gen.non_trainable_variables,
            gen_x,
            gen_y,
            training=True,
        )
        (loss_disc, (y_pred_disc, non_trainable_variables_disc)), grads_disc = grad_fn_disc(
            self.disc.trainable_variables,
            self.disc.non_trainable_variables,
            disc_x,
            disc_y,
            training=True,
        )
        ...
    # After the scope exits, read back the updated value of every variable.
    trainable_variables = [scope.get_current_value(w) for w in self.trainable_variables]
    non_trainable_variables = [scope.get_current_value(w) for w in self.non_trainable_variables]

You get the idea. Just set variable values with the scope and then you can use self.gen.variables, etc. At the scope exit you collect back the updated variable values and you return those.

fchollet commented 10 months ago

For a real-world example see how we handle stateful metrics in the JAX backend: https://github.com/keras-team/keras/blob/master/keras/backend/jax/trainer.py#L130-L145

In general, working with JAX statelessness is pretty terrible, so the solution is to open a StatelessScope and pretend everything is stateful 👍

github-actions[bot] commented 10 months ago

This issue is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.

craig-sony commented 10 months ago

Thanks for the info, any chance of getting the example updated though?