keras-team / tf-keras

The TensorFlow-specific implementation of the Keras API, which was the default Keras from 2019 to 2023.
Apache License 2.0

Cannot Build Intermediate Model to Nested Layers #643

Open weidler opened 2 years ago

weidler commented 2 years ago

System information.

Describe the problem. Note that I previously reported this issue here for TF 2.0. Back then, the TensorFlow team suggested a workaround that worked under 2.0 but no longer works.

Here is the problem: using the functional API, one can build an intermediate model that starts and ends at any of the original model's layers. This does not work, however, when layers are encapsulated in an inner model (let's say, some tf.keras.Sequential). The graph will differ due to the additional Input layer, but the computations should be the same. Yet when trying to build an intermediate model of a nested model up to an inner layer, a "Graph disconnected" error is thrown (see below). Previously, one could circumvent this by building the model not to final_model.get_layer("inner_model").get_layer("id_1").output but to final_model.get_layer("inner_model").get_layer("id_1").get_output_at(1) (see the full example below).

Standalone code to reproduce the issue.

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # set before importing TensorFlow so it takes effect
import tensorflow as tf

# NOT NESTED
inp = tf.keras.Input((4,))
y = tf.keras.layers.Dense(4, name="od_1")(inp)
y = tf.keras.layers.Dense(2, name="od_2")(y)
y = tf.keras.layers.Dense(4, name="id_1")(y)
y = tf.keras.layers.Dense(10, name="od_3")(y)
y = tf.keras.layers.Dense(10, name="od_4")(y)
final_model = tf.keras.Model(inputs=[inp], outputs=[y])
final_model.summary()

sub_model = tf.keras.Model(inputs=[final_model.input], outputs=[final_model.get_layer("id_1").output])
sub_model.summary()

# NESTED
inp_1 = tf.keras.Input(shape=(2,))
x = tf.keras.layers.Dense(4, name="id_1")(inp_1)
inner_model = tf.keras.Model(inputs=[inp_1], outputs=[x], name="inner_model")

inp_outer = tf.keras.Input((4,))
y = tf.keras.layers.Dense(4, name="od_1")(inp_outer)
y = tf.keras.layers.Dense(2, name="od_2")(y)
y = inner_model(y)
y = tf.keras.layers.Dense(10, name="od_3")(y)
y = tf.keras.layers.Dense(10, name="od_4")(y)
final_model = tf.keras.Model(inputs=[inp_outer], outputs=[y])
final_model.summary()

# raises ValueError: Graph disconnected (see below)
sub_model = tf.keras.Model(inputs=[final_model.input], outputs=[final_model.get_layer("inner_model").get_layer("id_1").output])
# worked under TF 2.0; now raises ValueError: Asked to get output at node 1 (see below)
previously_working_sub_model = tf.keras.Model(
    inputs=[final_model.input],
    outputs=[final_model.get_layer("inner_model").get_layer("id_1").get_output_at(1)])

The previously_working_sub_model statement throws ValueError: Asked to get output at node 1, but the layer has only 1 inbound nodes., whereas the sub_model statement throws ValueError: Graph disconnected: cannot obtain value for tensor KerasTensor(type_spec=TensorSpec(shape=(None, 2), dtype=tf.float32, name='input_2'), name='input_2', description="created by layer 'input_2'") at layer "id_1". The following previous layers were accessed without issue: []

Expected behavior. To allow for accessing intermediate activations, it is crucial to be able to build intermediate models to (and preferably from) anywhere within the model.
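A workaround that appears to work for the nested model above, offered as a sketch that assumes TF 2.x tf.keras with KerasTensors (the behaviour reported here): rather than asking id_1 for its existing output, which lives only in inner_model's own graph, call id_1 again on the outer tensor that feeds inner_model, here the output of od_2. This adds a new node to id_1 that reuses its weights and is connected to final_model.input. It reaches id_1 directly only because id_1 is the first layer of inner_model; a deeper target would need every inner layer up to it re-applied the same way. The variable names feed and activations are illustrative.

```python
import numpy as np
import tensorflow as tf

# The nested model from the report above (single-tensor inputs/outputs).
inp_1 = tf.keras.Input(shape=(2,))
x = tf.keras.layers.Dense(4, name="id_1")(inp_1)
inner_model = tf.keras.Model(inputs=inp_1, outputs=x, name="inner_model")

inp_outer = tf.keras.Input((4,))
y = tf.keras.layers.Dense(4, name="od_1")(inp_outer)
y = tf.keras.layers.Dense(2, name="od_2")(y)
y = inner_model(y)
y = tf.keras.layers.Dense(10, name="od_3")(y)
y = tf.keras.layers.Dense(10, name="od_4")(y)
final_model = tf.keras.Model(inputs=inp_outer, outputs=y)

# The tensor that feeds inner_model inside the *outer* graph.
feed = final_model.get_layer("od_2").output
# Re-apply the inner layer to it: this adds a new node to id_1 (same weights)
# that is connected to final_model.input.
id_1 = final_model.get_layer("inner_model").get_layer("id_1")
sub_model = tf.keras.Model(inputs=final_model.input, outputs=id_1(feed))

activations = sub_model(np.random.rand(3, 4).astype("float32"))
```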

weidler commented 2 years ago

@qlzh727 as you suggested a workaround for this issue previously, are you maybe aware of a new solution since this workaround got patched out?

sushreebarsa commented 2 years ago

@jvishnuvardhan I was able to reproduce this issue on Colab using TF v2.8.0 and tf-nightly; please find the gist here. Thanks!

isaacgerg commented 2 years ago

@weidler You might be able to get what you want by using .call() syntax and some cleverness.

For example, suppose you have a model that contains some preprocessing steps (e.g. preprocess_input) and a pretrained model. Calling model.summary() will show efficientnetb7 (or whatever your pretrained model is) but will not expand it. Suppose further that you want to access a layer inside efficientnetb7. Here's what you can do.

  1. Create a submodel of your efficientnetb7 with the output you want.
  2. Create a prefix model which has just the processing.
  3. Stitch them together using .call() syntax.

Here's an example.

model.summary()
Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            [(None, 256, 256, 1) 0                                            
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, 256, 256, 3)  0           input_1[0][0]                    
                                                                 input_1[0][0]                    
                                                                 input_1[0][0]                    
__________________________________________________________________________________________________
tf.math.multiply (TFOpLambda)   (None, 256, 256, 3)  0           concatenate[0][0]                
__________________________________________________________________________________________________
efficientnetb7 (Functional)     (None, 8, 8, 2560)   64097687    tf.math.multiply[0][0]           
__________________________________________________________________________________________________
global_average_pooling2d (Globa (None, 2560)         0           efficientnetb7[0][0]             
__________________________________________________________________________________________________
dense (Dense)                   (None, 1)            2561        global_average_pooling2d[0][0]   
==================================================================================================
Total params: 64,100,248
Trainable params: 63,789,521
Non-trainable params: 310,727
__________________________________________________________________________________________________

Suppose I want the output of "stem_conv" layer inside of efficientnetb7. I can obtain this by doing:

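# 1. prefix model: input -> preprocessing (output of tf.math.multiply)
pre_model = tf.keras.models.Model(model.input, model.layers[2].output)
# 2. submodel of efficientnetb7, from its own input to 'stem_conv'
eff_net = tf.keras.models.Model(model.layers[3].input, model.layers[3].get_layer('stem_conv').output)
# 3. stitch them together with .call()
tmp_model = tf.keras.models.Model(pre_model.input, eff_net.call(pre_model.output))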

Now, when we look at tmp_model, it has everything we want:

tmp_model.summary()
Model: "model_10"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            [(None, 256, 256, 1) 0                                            
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, 256, 256, 3)  0           input_1[0][0]                    
                                                                 input_1[0][0]                    
                                                                 input_1[0][0]                    
__________________________________________________________________________________________________
tf.math.multiply (TFOpLambda)   (None, 256, 256, 3)  0           concatenate[0][0]                
__________________________________________________________________________________________________
rescaling (Rescaling)           (None, 256, 256, 3)  0           tf.math.multiply[0][0]           
__________________________________________________________________________________________________
normalization (Normalization)   (None, 256, 256, 3)  7           rescaling[4][0]                  
__________________________________________________________________________________________________
stem_conv_pad (ZeroPadding2D)   (None, 257, 257, 3)  0           normalization[4][0]              
__________________________________________________________________________________________________
stem_conv (Conv2D)              (None, 128, 128, 64) 1728        stem_conv_pad[4][0]              
==================================================================================================
Total params: 1,735
Trainable params: 1,728
Non-trainable params: 7
__________________________________________________________________________________________________
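The three steps above can be tried without downloading EfficientNet. In this sketch, a small hypothetical Functional model named backbone stands in for efficientnetb7 and a Dense layer named pre stands in for the preprocessing; all layer names are made up. It uses the regular call syntax backbone_prefix(...) instead of .call(), on the assumption (consistent with the REPL output later in this thread, not a documented guarantee) that both create the same connected node here.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for efficientnetb7: a small Functional model.
b_in = tf.keras.Input((8,))
h = tf.keras.layers.Dense(6, name="b1")(b_in)
h = tf.keras.layers.Dense(5, name="b2")(h)
h = tf.keras.layers.Dense(4, name="b3")(h)
backbone = tf.keras.Model(b_in, h, name="backbone")

# Outer model: "preprocessing" -> backbone -> head.
inp = tf.keras.Input((8,))
x = tf.keras.layers.Dense(8, name="pre")(inp)
x = backbone(x)
model = tf.keras.Model(inp, tf.keras.layers.Dense(1, name="head")(x))

# 1. Submodel of the backbone, from its own input to the inner layer we want.
backbone_prefix = tf.keras.Model(backbone.input, backbone.get_layer("b2").output)
# 2. Prefix model containing only the preprocessing.
pre_model = tf.keras.Model(model.input, model.get_layer("pre").output)
# 3. Stitch: calling backbone_prefix on pre_model.output creates a new node
#    that is connected to model.input, so there is no "Graph disconnected".
stitched_model = tf.keras.Model(pre_model.input, backbone_prefix(pre_model.output))
```

The stitched model reuses the original layer objects, so its weights are the ones trained in model.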
weidler commented 2 years ago

Thanks for the suggestion @isaacgerg. In essence, this is rebuilding the model with the functional API. Since I need it to work for any general model, it's sadly not that straightforward. For many use cases, though, this is a good solution.

Since my purpose in building intermediate models was to get intermediate outputs, I turned to this solution instead (linking it here for others searching for a solution to that specific application).

ricvo commented 2 years ago

Just tested on 2.9.1

Let's say I have a composed model made of two model components, model1 and model2.

Now I would like to get a feature model that extracts the output of the 3rd layer of the 2nd model, starting from the input. Based on @isaacgerg's suggestion, I would expect the following code to work.

import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, Activation
from tensorflow.keras.models import Model, Sequential

model1 = Sequential([
    Conv2D(2, (3, 3), padding='same'),
    Conv2D(2, (3, 3), padding='same'),
    Conv2D(2, (3, 3), padding='same'),
    Conv2D(2, (3, 3), padding='same'),
    Activation('relu')
])

model2 = Sequential([
    Conv2D(3, (3, 3), padding='same'),
    Conv2D(3, (3, 3), padding='same'),
    Conv2D(3, (3, 3), padding='same'),
    Conv2D(3, (3, 3), padding='same'),
    Activation('softmax')
])

inp = tf.keras.Input((32, 32, 1))
x = model1(inp)
x = model2(x)

# this works
joint_model = Model(inputs=inp, outputs=x)

joint_model(inp)

# the pre-model goes up to the output of layers[1], which is model1 at this point
pre_models = tf.keras.models.Model(joint_model.input, joint_model.layers[1].output)

But instead I get *** ValueError: Graph disconnected: cannot obtain value for tensor KerasTensor(type_spec=TensorSpec(shape=(None, 4, 32, 32, 1), dtype=tf.float32, name='conv2d_input'), name='conv2d_input', description="created by layer 'conv2d_input'") at layer "conv2d". The following previous layers were accessed without issue: []

I am trying to get the output of the first model from the joint model as the pre-model, and then proceed to get the output of the second convolution inside the second model. Am I misunderstanding something? Otherwise I would have to explicitly rebuild all the layers in model1 and model2 up to the output I want, which sadly does not generalize easily to an arbitrary intermediate output.

Would you have any suggestion as a workaround?

Edit 29/9/2022: changed a shape to have a single batch dimension, as is more common.
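For the chain case described here, one possible way to avoid rebuilding layers by hand is to expand the nested models into their leaf layers and re-apply them, in order, to a fresh Input until the target layer is reached. This is a sketch under a strong assumption: the model is a pure chain (each layer has exactly one input and one output), so the order of model.layers is the execution order; it does not handle branching graphs. The helpers leaf_layers and chain_model_up_to are hypothetical, not part of tf.keras.

```python
import numpy as np
import tensorflow as tf

# ricvo's example, written more compactly.
model1 = tf.keras.Sequential(
    [tf.keras.layers.Conv2D(2, (3, 3), padding="same") for _ in range(4)]
    + [tf.keras.layers.Activation("relu")])
model2 = tf.keras.Sequential(
    [tf.keras.layers.Conv2D(3, (3, 3), padding="same") for _ in range(4)]
    + [tf.keras.layers.Activation("softmax")])
inp = tf.keras.Input((32, 32, 1))
joint_model = tf.keras.Model(inputs=inp, outputs=model2(model1(inp)))


def leaf_layers(layer):
    """Recursively expand nested models into a flat list of non-model layers."""
    if isinstance(layer, tf.keras.Model):
        flat = []
        for sub in layer.layers:
            flat.extend(leaf_layers(sub))
        return flat
    if isinstance(layer, tf.keras.layers.InputLayer):
        return []  # Input layers carry no computation.
    return [layer]


def chain_model_up_to(model, target):
    """Hypothetical helper: re-apply the layers of a *linear* model up to target."""
    new_inp = tf.keras.Input(tuple(model.input.shape[1:]))
    x = new_inp
    for layer in leaf_layers(model):
        x = layer(x)  # new node on the same layer object, so weights are shared
        if layer is target:
            return tf.keras.Model(new_inp, x)
    raise ValueError(f"{target.name} not found in {model.name}")


# Output of the 2nd convolution inside model2, starting from the joint input.
feature_model = chain_model_up_to(joint_model, model2.layers[1])
```

Because the same layer objects are called again, feature_model shares its weights with joint_model rather than copying them.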

gowthamkpr commented 2 years ago

@ricvo You can visualize the graph with tf.keras.utils.plot_model(model, show_shapes=True, show_dtype=True). It could help to find disconnected parts of the graph or errors in the model architecture.

ricvo commented 2 years ago

Hi @gowthamkpr, thanks for your answer. Actually, my point is that the graph should not be disconnected in this situation. Please take a look at the code I posted above: joint_model is the composition of model1 and model2. I do understand that the symbolic output of model1 (aka joint_model.layers[1]) is connected only to the input of the inner model; in fact, the following works

tf.keras.models.Model(joint_model.layers[1].input, joint_model.layers[1].output)

but it is useless in the situation above. Because inner models are built separately, in their own graph under the hood, I am forced to rebuild each layer or model one by one up to the node of interest. Fine, but not ideal. It seems that models inside other models are built separately as tf.functions, which makes it hard to get intermediate results.

One solution would be to forbid the use of models as layers of other models, but I don't think this is satisfactory: there are often situations in which you have an architecture that you use standalone (hence you define it as a Model), and sometimes you would also like to compose it with other pieces to create a new Model. I am not sure about the solution; maybe tf.Models inside other tf.Models should build their computational graph inside the outer tf.Model, behaving as if they were tf.Layers in this respect?

In TF1 the same thing was very easy: with a single computational graph, you just had to get the node you wanted from the graph by its path and then call sess.run(node, input). I believe this should be equally easy in TF2. The fact that the code above does not work for getting subgraphs of nested tf.Models is counterintuitive to me.

I hope I explained my point more clearly; that's why I still consider this a bug in TF2.

Edit: a small addition to clarify further. For tf.Layers inside an outer tf.Model, the code seems to work as expected; indeed, the following is perfectly fine (this line has to be run after the code snippet I posted in my previous comment)

tf.keras.models.Model(joint_model.layers[1].input, joint_model.layers[1].layers[3].output)

So for layers inside a model there seems to be no issue; the behaviour is fine. The problem is models inside other models.

isaacgerg commented 2 years ago

@ricvo "Based on @isaacgerg suggestion, I would expect the following code to work."

You must have misread my explanation, because for my proposal to work you have to use the .call() function, which your example does not use. This might be why your example does not work.

ricvo commented 2 years ago

@isaacgerg You are right. That phrase is misleading; I did not express myself properly there. Let me clarify. The following alternatives both work:

>>> tf.keras.models.Model(model1.input, model2.call(model1.output))
<keras.engine.functional.Functional object at 0x7f8a902b7400>

>>> tf.keras.models.Model(model1.input, model2(model1.output))
<keras.engine.functional.Functional object at 0x7f8a90280190>

These instead do not work:

>>> tf.keras.models.Model(joint_model.input, model2.call(joint_model.layers[1].output))
*** ValueError: Graph disconnected: cannot obtain value for tensor KerasTensor(type_spec=TensorSpec(shape=(None, 32, 32, 1), dtype=tf.float32, name='conv2d_input'), name='conv2d_input', description="created by layer 'conv2d_input'") at layer "conv2d". The following previous layers were accessed without issue: []
>>> tf.keras.models.Model(joint_model.input, model2(joint_model.layers[1].output))
*** ValueError: Graph disconnected: cannot obtain value for tensor KerasTensor(type_spec=TensorSpec(shape=(None, 32, 32, 1), dtype=tf.float32, name='conv2d_input'), name='conv2d_input', description="created by layer 'conv2d_input'") at layer "conv2d". The following previous layers were accessed without issue: []

But actually, even where it works, we are rebuilding the network up to the node of interest (the original joint_model could have many other layers after model2, and we should rebuild the network only up to model2). Rebuilding parts of the network is exactly what I was trying to avoid. That approach is useful in several cases of interest, but it is not easy to apply if one wants to get intermediate nodes in a general way.

In summary, I believe the code snippet I posted above in https://github.com/keras-team/tf-keras/issues/643 should work, or there should be an alternative way to get the nodes of interest from that specific graph (the one built by joint_model, between joint_model.input and joint_model.output) without having to rebuild parts of the network. I found some ways to navigate the graph backwards; when the graph is a chain this is easy to follow, but they feel a bit hackish and do not generalize beyond a chain. They could be extended, but I was not sure how to perform a path search in the graph behind Keras, and since it felt like overkill I quickly dropped the idea.

It would be nice if tf.keras offered better support for accessing intermediate nodes in the situation depicted above. In TF1 this was easy, since you could ask the graph for an output while feeding inputs in an arbitrary manner, and nodes could be retrieved generally by their path. I was not able to find this kind of flexibility in TF2, but I would be very glad to be corrected on this :)

Hope this clarifies.
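On the concern about rebuilding: re-applying existing layers or models creates new graph nodes but reuses the same layer objects, so a stitched model shares its weights with the original instead of copying them; only the graph wiring is rebuilt, and only up to the chosen point. A small check of that, on a hypothetical toy model (inner, pre, post, d1 and d2 are made-up names), assuming TF 2.x tf.keras:

```python
import numpy as np
import tensorflow as tf

# Hypothetical toy model: pre -> inner (Sequential) -> post.
inner = tf.keras.Sequential(
    [tf.keras.layers.Dense(3, name="d1"), tf.keras.layers.Dense(2, name="d2")],
    name="inner")
inp = tf.keras.Input((4,))
x = tf.keras.layers.Dense(4, name="pre")(inp)
outer = tf.keras.Model(inp, tf.keras.layers.Dense(1, name="post")(inner(x)))

# "Rebuild" up to the inner model by calling it on the outer prefix tensor.
stitched = tf.keras.Model(outer.input, inner(outer.get_layer("pre").output))

# Every variable in the stitched model is one of outer's variables (no copies).
outer_ids = {id(v) for v in outer.weights}
shared = all(id(v) in outer_ids for v in stitched.weights)

# Zeroing d2 through the original model is visible in the stitched model.
d2 = inner.get_layer("d2")
d2.set_weights([np.zeros_like(w) for w in d2.get_weights()])
result = stitched(np.ones((2, 4), dtype="float32"))
```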

isaacgerg commented 2 years ago

@ricvo Great, it sounds like we are in agreement after your clarification.

It is an odd bug that graph building can't operate beyond the borders of a model. It would seem that this constraint would prevent the compilation procedure from working correctly (how would it resolve the tree to build the model?). This makes me suspect there is a bug somewhere in the tf.keras code.

gowthamkpr commented 2 years ago

@qlzh727 Can you PTAL? Thanks!