mbarbetti / pidgan

:package: GAN-based models to flash-simulate the LHCb PID detectors

Enabling JAX as backend for the GAN training step #8

Open mbarbetti opened 5 days ago

mbarbetti commented 5 days ago

Starting from the v0.2.0 release, PIDGAN is compatible with the new multi-backend Keras 3.

Keras 3 is a full rewrite of Keras that enables you to run your Keras workflows on top of either JAX, TensorFlow, or PyTorch, and that unlocks brand new large-scale model training and deployment capabilities.
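For completeness, the backend used by Keras 3 is selected through the KERAS_BACKEND environment variable (or the ~/.keras/keras.json configuration file) and has to be set before keras is imported:

import os
os.environ["KERAS_BACKEND"] = "jax"  # "tensorflow", "jax" or "torch"

import keras
print(keras.backend.backend())  # -> "jax"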

At the moment, training GAN models is only possible with the TensorFlow backend. Looking at lines 173-183 of the Keras3-based GAN class, the train_step() method currently dispatches as follows:

def train_step(self, *args, **kwargs):
  if keras.backend.backend() == "tensorflow":
    return self._tf_train_step(*args, **kwargs)
  elif keras.backend.backend() == "torch":
    raise NotImplementedError("`train_step()` not implemented for the PyTorch backend")
  elif keras.backend.backend() == "jax":
    raise NotImplementedError("`train_step()` not implemented for the Jax backend")

The goal of this issue is to implement train_step() for the JAX backend as well. In addition to the "plain" training step, the Lipschitz regularization functions should also be adapted to rely on the JAX backend.
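For reference, under the JAX backend Keras 3 expects training logic to be written in a purely functional (stateless) style: forward passes go through Model.stateless_call(), gradients are taken with jax.value_and_grad(), and optimizer updates go through Optimizer.stateless_apply(). The snippet below is only a minimal sketch of what a discriminator update could look like under these constraints; it is not the PIDGAN implementation, and the toy discriminator, d_optimizer, loss and batch shapes are assumptions used purely for illustration.

import os
os.environ["KERAS_BACKEND"] = "jax"  # must be set before importing keras

import jax
import jax.numpy as jnp
import keras
from keras import ops

# Hypothetical stand-ins for the PIDGAN discriminator and its optimizer
discriminator = keras.Sequential([
    keras.layers.Dense(64, activation="relu"),
    keras.layers.Dense(1, activation="sigmoid"),
])
discriminator.build(input_shape=(None, 4))
d_optimizer = keras.optimizers.Adam(1e-4)
d_optimizer.build(discriminator.trainable_variables)

def d_loss_fn(d_trainable, d_non_trainable, x_real, x_fake):
    # Stateless forward passes: variables are passed in and returned
    # explicitly, since no in-place update is allowed under JAX
    y_real, d_non_trainable = discriminator.stateless_call(
        d_trainable, d_non_trainable, x_real, training=True
    )
    y_fake, d_non_trainable = discriminator.stateless_call(
        d_trainable, d_non_trainable, x_fake, training=True
    )
    # Vanilla GAN discriminator loss (assumes sigmoid outputs)
    loss = -ops.mean(ops.log(y_real + 1e-8) + ops.log(1.0 - y_fake + 1e-8))
    return loss, d_non_trainable

@jax.jit
def d_train_step(d_trainable, d_non_trainable, d_opt_vars, x_real, x_fake):
    grad_fn = jax.value_and_grad(d_loss_fn, has_aux=True)
    (loss, d_non_trainable), grads = grad_fn(
        d_trainable, d_non_trainable, x_real, x_fake
    )
    # Purely functional optimizer step: returns updated copies of the
    # trainable and optimizer variables instead of mutating them
    d_trainable, d_opt_vars = d_optimizer.stateless_apply(
        d_opt_vars, grads, d_trainable
    )
    return loss, d_trainable, d_non_trainable, d_opt_vars

# Toy usage with dummy "real" and "generated" batches
state = (
    discriminator.trainable_variables,
    discriminator.non_trainable_variables,
    d_optimizer.variables,
)
loss, *state = d_train_step(*state, jnp.ones((8, 4)), jnp.zeros((8, 4)))

Along the same lines, a Lipschitz regularization term (e.g. a WGAN-GP-style gradient penalty) would be obtained by differentiating the critic output with respect to the interpolated inputs via jax.grad(), instead of using tf.GradientTape. A rough sketch under the same assumptions:

def gradient_penalty(d_trainable, d_non_trainable, x_real, x_fake, key):
    # Interpolate between real and generated samples
    eps = jax.random.uniform(key, shape=(x_real.shape[0], 1))
    x_hat = eps * x_real + (1.0 - eps) * x_fake

    def critic_sum(x):
        y, _ = discriminator.stateless_call(
            d_trainable, d_non_trainable, x, training=True
        )
        # Summing keeps per-sample gradients independent, since each
        # output row depends only on its own input row
        return ops.sum(y)

    grads = jax.grad(critic_sum)(x_hat)
    grad_norm = ops.sqrt(ops.sum(grads ** 2, axis=-1) + 1e-12)
    return ops.mean((grad_norm - 1.0) ** 2)

Here key is a jax.random.PRNGKey that the caller has to thread through the training step explicitly, since JAX has no implicit global random state.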