sicara / tf-explain

Interpretability Methods for tf.keras models with Tensorflow 2.x
https://tf-explain.readthedocs.io
MIT License
1.02k stars 111 forks

Compatibility with recursive models #63

Open sayakpaul opened 5 years ago

sayakpaul commented 5 years ago

Hi. First of all, kudos for the amazing initiative. Having something like this as a Callback option in tf.keras is just amazing. However, as there are currently no demo notebooks available to help an interested developer try out tf-explain, I was preparing one with a fine-tuned model on the CIFAR10 dataset. Due to the lack of clarity on what exactly to specify in the layers_name argument when defining a callback, I am getting an error after one epoch (error trace attached: error_trace.txt).

callbacks = [
    ActivationsVisualizationCallback(
        validation_data=(X_val, y_val),
        layers_name=["block5_conv1"],
        output_dir="logs",
    ),
] 

My notebook is available via Colab if you want to take a look. Let me know why this is happening, if possible, and I will then prepare the notebook.

RaphaelMeudec commented 5 years ago

Hey @sayakpaul, thanks for the kind words. The trick is that ActivationsVisualizationCallback should be used to visualize the outputs of a convolutional layer (you are using it on a dense layer, so it cannot create an image from the layer's activations). You can check the examples/callbacks. Your usage of layers_name is the right one, by the way: it is indeed a list of the target layers' names.

Hope this answer helps, I might make the examples more accessible in the README. Thanks again for your feedback!
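For context, the reason a dense layer fails here is dimensional: convolutional activations are 4-D tensors whose feature maps can be tiled into an image, while dense activations are flat vectors with no spatial layout. A minimal sketch (the model and layer names are made up) illustrating the difference:

```python
import numpy as np
import tensorflow as tf

# Hypothetical tiny model: one conv layer followed by a dense head.
inputs = tf.keras.Input(shape=(32, 32, 3))
x = tf.keras.layers.Conv2D(8, 3, name="conv_target")(inputs)
x = tf.keras.layers.Flatten()(x)
outputs = tf.keras.layers.Dense(10, name="dense_head")(x)
model = tf.keras.Model(inputs, outputs)

batch = np.zeros((1, 32, 32, 3), dtype="float32")

# Conv activations are (batch, height, width, filters): each filter is a
# 2-D map that can be rendered as a grid of images.
conv_model = tf.keras.Model(inputs, model.get_layer("conv_target").output)
print(conv_model(batch).shape)  # (1, 30, 30, 8)

# Dense activations are (batch, units): there is nothing spatial to draw.
dense_model = tf.keras.Model(inputs, model.get_layer("dense_head").output)
print(dense_model(batch).shape)  # (1, 10)
```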

sayakpaul commented 5 years ago

Hi @RaphaelMeudec, if you check the code block I provided, I specified a conv layer only, and still it failed. The example you provided:

tf_explain.callbacks.ActivationsVisualizationCallback(validation_class_zero, 'grad_cam_target'),

I am assuming layers_name is an optional argument, then?

RaphaelMeudec commented 5 years ago

There was a small typo in the examples; it is now fixed. layers_name is not optional and should be a List[str]. The issue with your example might come from the fact that you're nesting a base model inside a model. The library depends on the tf.keras method model.get_layer(), which I think is not recursive. I'll look into it, thanks for pointing it out!
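To see the non-recursive behaviour concretely, here is a small sketch (layer and model names invented) of get_layer() missing a layer that lives inside a nested submodel:

```python
import tensorflow as tf

# A submodel nested inside another model, as happens with transfer learning.
inner = tf.keras.Sequential(
    [tf.keras.layers.Conv2D(4, 3, input_shape=(32, 32, 3), name="inner_conv")],
    name="backbone",
)
outer = tf.keras.Sequential([
    inner,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(10),
])

# get_layer() only searches the top level, so the nested layer is not found.
try:
    outer.get_layer("inner_conv")
except ValueError:
    print("inner_conv is not visible from the outer model")

# Reaching it requires going through the submodel explicitly.
layer = outer.get_layer("backbone").get_layer("inner_conv")
print(layer.name)  # inner_conv
```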

sayakpaul commented 5 years ago

@RaphaelMeudec I think making it recursive would actually help a lot of practitioners and researchers, since transfer learning and model fine-tuning are used so heavily.

RaphaelMeudec commented 5 years ago

@sayakpaul I've started looking into it, and TensorFlow makes it a bit hard to create a subgraph between the input of a model and a layer of a submodel. The solution to this is to flatten the submodel, so that the main model doesn't contain any submodels.
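A sketch of what flattening can look like (names are hypothetical): instead of wrapping the backbone as a single layer in a Sequential, build a new functional model from the backbone's own input tensor, so every backbone layer surfaces at the top level:

```python
import tensorflow as tf

# Hypothetical backbone standing in for a pretrained network.
inputs = tf.keras.Input(shape=(32, 32, 3))
features = tf.keras.layers.Conv2D(8, 3, name="block_conv")(inputs)
backbone = tf.keras.Model(inputs, features, name="backbone")

# Nested: the backbone shows up as a single opaque layer.
nested = tf.keras.Sequential([
    backbone,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(10),
])
print("block_conv" in [l.name for l in nested.layers])  # False

# Flattened: trace new outputs from the backbone's own graph, so its
# internal layers sit directly inside the resulting model.
x = tf.keras.layers.GlobalAveragePooling2D()(backbone.output)
outputs = tf.keras.layers.Dense(10)(x)
flat = tf.keras.Model(inputs=backbone.input, outputs=outputs)
print("block_conv" in [l.name for l in flat.layers])  # True
```

With the flat model, model.get_layer("block_conv") resolves directly, which is what the callbacks need.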

sayakpaul commented 5 years ago

Thanks @RaphaelMeudec. Could you provide an example?

Rubenkl commented 4 years ago

@RaphaelMeudec how would you proceed with accessing a 'subgraph' by flattening the model?

I proceeded by trying: `explainer.explain(validation_data=(X, Y), model=model.layers[0], layer_name='conv5_block16_2_conv', class_index=1)`, where model.layers[0] is a pretrained network block. However, the GradCAM returns empty..

Edit:

I finally managed to flatten the model by using the Keras Functional API (instead of Sequential).

Model:

    backbone = DenseNet121(input_shape=input_shape, include_top=False)
    lastLayer = backbone.output
    predictions = K.layers.Dense(num_classes, activation='softmax')(lastLayer)
    model = K.Model(inputs=backbone.input, outputs=predictions)
gabarlacchi commented 4 years ago

Any updates on this? I am using a network such as:

    inc = InceptionV3(weights='imagenet', include_top=False, pooling='max',
                      input_shape=(shape[0], shape[1], 3))
    for layer in inc.layers:
        layer.trainable = False

    model = Sequential()
    model.add(TimeDistributed(inc, input_shape=(chunkframes, shape[0], shape[1], dataset.n_channels)))
    model.add(Dropout(0.2))
    model.add(LSTM(128, activation='relu', return_sequences=False,
                   kernel_initializer=he_uniform(), bias_initializer='zeros', dropout=0.5))
    model.add(BatchNormalization())
    model.add(Dense(4, activation='sigmoid', kernel_initializer=glorot_normal(), bias_initializer='zeros'))

I would like to visualize, as a heatmap, the activations for every frame that is fed into the LSTM. Is that possible?