ageron / handson-ml2

A series of Jupyter notebooks that walk you through the fundamentals of Machine Learning and Deep Learning in Python using Scikit-Learn, Keras and TensorFlow 2.
Apache License 2.0

[QUESTION] Why use `keras.backend` functions instead of `tf` functions? #591

Open rmurphy2718 opened 2 years ago

rmurphy2718 commented 2 years ago

In the book and notebooks, functions from `keras.backend` are sometimes used instead of `tf` functions. For example, we might see `K = keras.backend` followed by `K.mean` instead of `tf.reduce_mean`. I am asking generally, but here is an example from `19_training_and_deploying_at_scale.ipynb`:

```python
# ...
def train_step():
    def step_fn(inputs):
        X, y = inputs
        with tf.GradientTape() as tape:
            Y_proba = model(X)
            loss = K.sum(keras.losses.sparse_categorical_crossentropy(y, Y_proba)) / batch_size

# ...
```
1. Why is this choice made?
2. Are there important reasons to use `keras.backend` functions instead of TensorFlow reductions such as `tf.reduce_sum` or `tf.reduce_mean`?

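
For illustration (this is not from the notebooks), a minimal sketch comparing the two styles side by side. It assumes a TensorFlow 2.x install where `tf.keras.backend` still exposes the legacy `mean`/`sum` helpers, as in the Keras 2 releases the second edition targets. The tensors, `y`, `Y_proba` and `batch_size` below are made-up stand-ins for the notebook's variables.

```python
import tensorflow as tf
from tensorflow import keras

K = keras.backend  # the alias used in the second-edition notebooks

x = tf.constant([[1.0, 2.0], [3.0, 4.0]])

# Whole-tensor mean: keras.backend helper vs. the TensorFlow op
k_mean = K.mean(x)
tf_mean = tf.reduce_mean(x)

# Column sums: same comparison with an explicit axis
k_sum_axis = K.sum(x, axis=0)
tf_sum_axis = tf.reduce_sum(x, axis=0)

# The loss line from the question, written both ways
# (hypothetical labels and predicted probabilities for a batch of 2)
y = tf.constant([0, 1])
Y_proba = tf.constant([[0.9, 0.1], [0.2, 0.8]])
batch_size = 2
per_example = keras.losses.sparse_categorical_crossentropy(y, Y_proba)
loss_k = K.sum(per_example) / batch_size
loss_tf = tf.reduce_sum(per_example) / batch_size

print(float(k_mean), float(tf_mean))            # both 2.5
print(k_sum_axis.numpy(), tf_sum_axis.numpy())  # both [4. 6.]
print(float(loss_k), float(loss_tf))
```

In the TensorFlow-backed Keras used here, these `K` helpers simply delegate to the corresponding `tf.reduce_*` ops, so both columns should match.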
ageron commented 2 years ago

Hi @rmurphy2718 ,

Thanks for your question.

In the 3rd edition (coming out in October), I no longer use `keras.backend`. You can see the code in ageron/handson-ml3.

The reason I used it in the second edition is that I wanted the code to be as portable as possible to other Keras implementations, such as those based on Theano or MXNet rather than TensorFlow. Sadly, these other implementations no longer exist, and Keras is currently TensorFlow-only. There have been discussions about creating a JAX-based implementation of Keras, but AFAIK that work hasn't started yet.
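
To make the portability argument concrete, here is a toy, pure-Python sketch of the dispatch idea behind a backend abstraction layer: model code calls `mean` and `sum_`, and a registry decides which implementation actually runs. Every name in it (`BACKENDS`, `set_backend`, `sum_`, `batch_loss`, and the two "backends") is hypothetical; this is not how Keras is implemented, only the pattern in miniature.

```python
import math
import statistics

# Hypothetical registry: each "backend" supplies its own version of the same ops.
BACKENDS = {
    "listmath": {
        "sum": lambda xs: math.fsum(xs),
        "mean": lambda xs: math.fsum(xs) / len(xs),
    },
    "statistics": {
        "sum": lambda xs: sum(xs),  # the builtin sum, which sum_ below does not shadow
        "mean": lambda xs: statistics.fmean(xs),
    },
}

_active = "listmath"


def set_backend(name):
    """Select which registered backend the ops below dispatch to."""
    global _active
    if name not in BACKENDS:
        raise ValueError(f"unknown backend: {name!r}")
    _active = name


def backend():
    """Return the name of the active backend."""
    return _active


def sum_(xs):
    """Backend-agnostic sum (named sum_ to avoid shadowing the builtin)."""
    return BACKENDS[_active]["sum"](xs)


def mean(xs):
    """Backend-agnostic mean."""
    return BACKENDS[_active]["mean"](xs)


def batch_loss(per_example_losses, batch_size):
    # "Model code" written only against the abstraction layer: it never names
    # a backend, so switching backends does not require editing it.
    return sum_(per_example_losses) / batch_size


set_backend("statistics")
print(backend(), batch_loss([1.0, 2.0, 3.0], 3), mean([1.0, 2.0, 3.0]))
```

The indirection only pays off when more than one real backend sits behind it; with a single backend it is an extra layer without extra portability.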

As long as TensorFlow is the only backend for Keras, there's not much point in using `keras.backend` instead of calling the TensorFlow functions directly.
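
Applied to the loss line from the question, calling TensorFlow directly just means swapping `K.sum` for `tf.reduce_sum`. Below is a minimal, self-contained sketch of such a training step; the tiny model, the random data and `batch_size = 4` are hypothetical stand-ins, and the distribution-strategy code that surrounds `step_fn` in the notebook is left out.

```python
import tensorflow as tf
from tensorflow import keras

# Hypothetical stand-ins for the notebook's model, optimizer and data
batch_size = 4
model = keras.Sequential([
    keras.Input(shape=(2,)),
    keras.layers.Dense(3, activation="softmax"),
])
optimizer = keras.optimizers.SGD(learning_rate=0.1)

X = tf.random.normal((batch_size, 2), seed=42)
y = tf.constant([0, 1, 2, 1])


def train_step(X, y):
    with tf.GradientTape() as tape:
        Y_proba = model(X, training=True)
        # tf.reduce_sum replaces K.sum; the rest of the expression is unchanged
        loss = tf.reduce_sum(
            keras.losses.sparse_categorical_crossentropy(y, Y_proba)) / batch_size
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss


loss = train_step(X, y)
print(float(loss))
```

When the step does run under a `tf.distribute` strategy, TensorFlow also provides `tf.nn.compute_average_loss`, which sums per-example losses and divides by a global batch size, as an alternative to writing the division by hand.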

Hope this helps.