Starting from the v0.2.0 release, PIDGAN is compatible with the new multi-backend Keras 3.
Keras 3 is a full rewrite of Keras that enables you to run your Keras workflows on top of either JAX, TensorFlow, or PyTorch, and that unlocks brand new large-scale model training and deployment capabilities.
At the moment, GAN models can only be trained with the TensorFlow backend. For example, lines 173-183 of the Keras 3-based GAN class read:
def train_step(self, *args, **kwargs):
    if keras.backend.backend() == "tensorflow":
        return self._tf_train_step(*args, **kwargs)
    elif keras.backend.backend() == "torch":
        raise NotImplementedError("`train_step()` not implemented for the PyTorch backend")
    elif keras.backend.backend() == "jax":
        raise NotImplementedError("`train_step()` not implemented for the Jax backend")
The goal of this issue is to implement train_step() for the JAX backend as well. Beyond the "plain" training step, the Lipschitz regularization functions should also be adapted to run on the JAX backend.
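
As a starting point for discussion, here is a minimal pure-JAX sketch of what a discriminator update with a Lipschitz term could look like. It is not taken from the PIDGAN code base: the two-layer critic, the names `init_params`, `critic`, `gradient_penalty`, `d_loss` and `d_step`, the plain SGD update, and the choice of a WGAN-GP-style two-sided gradient penalty are all assumptions made for illustration. PIDGAN's actual regularizers may differ.

```python
import jax
import jax.numpy as jnp


def init_params(key, in_dim, hidden=16):
    """Hypothetical two-layer MLP critic standing in for a real discriminator."""
    k1, k2 = jax.random.split(key)
    return {
        "w1": jax.random.normal(k1, (in_dim, hidden)) * 0.1,
        "b1": jnp.zeros(hidden),
        "w2": jax.random.normal(k2, (hidden, 1)) * 0.1,
        "b2": jnp.zeros(1),
    }


def critic(params, x):
    """Map a batch of shape (n, in_dim) to scores of shape (n,)."""
    h = jnp.tanh(x @ params["w1"] + params["b1"])
    return (h @ params["w2"] + params["b2"])[:, 0]


def gradient_penalty(params, x_real, x_fake, key, target=1.0):
    """Penalize deviations of the critic's input-gradient norm from `target`."""
    # Random points on the segments joining real and generated samples.
    eps = jax.random.uniform(key, (x_real.shape[0], 1))
    x_hat = eps * x_real + (1.0 - eps) * x_fake

    # Gradient of the scalar score w.r.t. one input, vectorized over the batch.
    def score_one(x):
        return critic(params, x[None, :])[0]

    grads = jax.vmap(jax.grad(score_one))(x_hat)
    # Small constant keeps the square root differentiable at zero.
    norms = jnp.sqrt(jnp.sum(grads**2, axis=-1) + 1e-12)
    return jnp.mean((norms - target) ** 2)


def d_loss(params, x_real, x_fake, key, lam=10.0):
    """Wasserstein-style critic loss plus the weighted gradient penalty."""
    wdist = jnp.mean(critic(params, x_fake)) - jnp.mean(critic(params, x_real))
    return wdist + lam * gradient_penalty(params, x_real, x_fake, key)


@jax.jit
def d_step(params, x_real, x_fake, key, lr=1e-3):
    """One stateless SGD step: parameters in, updated parameters and loss out."""
    loss, grads = jax.value_and_grad(d_loss)(params, x_real, x_fake, key)
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss
```

The step is written as a pure function because JAX transformations such as `jax.jit` and `jax.value_and_grad` require it. The Keras 3 JAX backend follows the same convention: a custom train_step() receives the model and optimizer variables as an explicit state argument and returns the updated state, rather than mutating variables in place as the TensorFlow path does.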