keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

Tensorflow is still imported when using jax or torch backend #18455

Open ageron opened 1 year ago

ageron commented 1 year ago

When I import keras_core, it imports TensorFlow even when I set the backend to jax or torch:

>>> import os
>>> os.environ["KERAS_BACKEND"] = "jax"
>>> import keras_core
2023-06-12 11:14:46.431809: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
Using JAX backend.
>>> keras_core.backend.backend()
'jax'

TensorFlow takes up to 3-4 seconds to load on my machine, so it would be nice to avoid that. It would also be nice not to have to install it at all when using another backend, since it's quite a big beast and uses a lot of disk space.
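To check the report on a given install, one approach is to import the package in a fresh interpreter and then inspect sys.modules. The sketch below does that with the standard library only. The helper name modules_loaded_after_import and the "LOADED:" marker are hypothetical choices for illustration, not part of Keras. A fresh subprocess is used so that nothing the current session already imported skews the result.

```python
import os
import subprocess
import sys

# Marker so the probe's result line can be found even if the imported
# package also prints to stdout (e.g. a "Using JAX backend." banner).
_MARKER = "LOADED:"


def modules_loaded_after_import(package, backend, watch=("tensorflow",)):
    """Import `package` in a fresh Python process with KERAS_BACKEND=backend
    and return the subset of module names in `watch` found in sys.modules."""
    probe = (
        "import importlib, sys\n"
        f"importlib.import_module({package!r})\n"
        f"print({_MARKER!r} + ','.join(m for m in {list(watch)!r} if m in sys.modules))\n"
    )
    # Copy the current environment and override only the backend choice.
    env = dict(os.environ, KERAS_BACKEND=backend)
    result = subprocess.run(
        [sys.executable, "-c", probe],
        env=env,
        capture_output=True,
        text=True,
        check=True,  # raise if the import itself fails
    )
    for line in result.stdout.splitlines():
        if line.startswith(_MARKER):
            return {name for name in line[len(_MARKER):].split(",") if name}
    raise RuntimeError("probe did not report any results")
```

For example, print(modules_loaded_after_import("keras_core", "jax")) repeats the check from the session above. An empty set would mean TensorFlow was not imported; the actual result depends on the installed version.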

fchollet commented 1 year ago

I agree that we need to remove the TF required dependency. There are several reasons why this is not possible right now:

Eventually we'll make TF optional, but it will take some time.

vulkomilev commented 1 year ago

Is tf.nest still used to process nested Python structures?
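For context, the nested-structure handling in question is pure bookkeeping on Python containers: flatten a structure into its leaves, then rebuild the same nesting from a flat list. tf.nest exposes this as flatten and pack_sequence_as, and it flattens dicts in sorted-key order. The sketch below is a minimal pure-Python illustration of that idea, not tf.nest's implementation. The names flatten and pack_like are hypothetical, and it handles only lists, tuples, and dicts with mutually sortable keys (no namedtuples or other containers).

```python
from typing import Any

_MISSING = object()  # sentinel for "iterator exhausted"


def flatten(structure: Any) -> list:
    """Return the leaves of nested dicts/lists/tuples, depth first.
    Dict values are visited in sorted-key order, as tf.nest does."""
    if isinstance(structure, dict):
        return [leaf for key in sorted(structure) for leaf in flatten(structure[key])]
    if isinstance(structure, (list, tuple)):
        return [leaf for item in structure for leaf in flatten(item)]
    return [structure]


def pack_like(template: Any, leaves: list) -> Any:
    """Inverse of flatten: rebuild `template`'s nesting around `leaves`."""
    remaining = iter(leaves)

    def build(node: Any) -> Any:
        if isinstance(node, dict):
            # Consume leaves in the same sorted-key order flatten used...
            packed = {key: build(node[key]) for key in sorted(node)}
            # ...then restore the template's original key order.
            return {key: packed[key] for key in node}
        if isinstance(node, list):
            return [build(item) for item in node]
        if isinstance(node, tuple):
            return tuple(build(item) for item in node)
        value = next(remaining, _MISSING)
        if value is _MISSING:
            raise ValueError("fewer leaves than the template has slots")
        return value

    result = build(template)
    if next(remaining, _MISSING) is not _MISSING:
        raise ValueError("more leaves than the template has slots")
    return result
```

Nothing here depends on TensorFlow, which is why this kind of utility can be replaced by a backend-neutral equivalent.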

mehtamansi29 commented 6 days ago

Hi @ageron -

In Keras 3, you no longer need tf.nest to handle deeply nested inputs in functional models. You can pass a dictionary, or a nested dictionary (more than one level deep), directly as the model's inputs. Here you can find more detail about it.

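import keras

# A functional model whose inputs are a two-level nested dict.
inputs = {
    "foo": keras.Input(shape=(1,), name="foo"),
    "bar": {
        "baz": keras.Input(shape=(1,), name="baz"),
    },
}
# Outputs can reference any leaf of the nested input structure.
outputs = inputs["foo"] + inputs["bar"]["baz"]
model = keras.Model(inputs, outputs)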

This nested input also works with the JAX and torch backends. I've attached a gist for your reference.