albermax / innvestigate

A toolbox to iNNvestigate neural networks' predictions!

[BUG] `AnalyzerNetworkBase` analyzers error when using `BatchNormalization` layers #292

Open adrhill opened 1 year ago

adrhill commented 1 year ago

On iNNvestigate v2.0.1, creating an analyzer inheriting from AnalyzerNetworkBase errors when the model contains a BatchNormalization layer, e.g.:

tensorflow.python.framework.errors_impl.InvalidArgumentError: You must feed a value for placeholder tensor 'dense_2_input' with dtype float and shape [?,50]

This might be due to BatchNormalization layers keeping moving averages of the mean and variance of the training data, which cause problems with the Keras history when the computational graph is reversed in iNNvestigate's create_analyzer_model (a quick check of these variables is shown after the example below).

Minimal example reproducing the issue

import numpy as np
import tensorflow as tf
from keras.layers import BatchNormalization, Dense
from keras.models import Sequential

import innvestigate

tf.compat.v1.disable_eager_execution()

input_shape = (50,)
x = np.random.rand(100, *input_shape)
y = np.random.rand(100, 2)

model1 = Sequential()
model1.add(Dense(10, input_shape=input_shape))
model1.add(Dense(2))

model2 = Sequential()
model2.add(Dense(10, input_shape=input_shape))
model2.add(BatchNormalization())
model2.add(Dense(2))

def run_analysis(model):
    model.compile(optimizer="adam", loss="mse")
    model.fit(x, y, epochs=10, verbose=0)

    analyzer = innvestigate.create_analyzer("gradient", model)
    analyzer.analyze(x)

print("Model without BatchNormalization:")  # passes
run_analysis(model1)
print("Model with BatchNormalization:")     # errors
run_analysis(model2)
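
For reference, the moving statistics mentioned above are stored on the layer as non-trainable weights, which is presumably what the graph reversal stumbles over. A quick way to inspect them on model2 (variable names may differ between Keras versions):

```
bn = model2.layers[1]  # the BatchNormalization layer from the example above
print([w.name for w in bn.non_trainable_weights])
# expected output along the lines of:
# ['batch_normalization/moving_mean:0', 'batch_normalization/moving_variance:0']
```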

Full stacktrace

```
Model with BatchNormalization:
/Users/funks/Library/Caches/pypoetry/virtualenvs/innvestigate-issues-W0iScZgu-py3.10/lib/python3.10/site-packages/tensorflow/python/client/session.py:1480: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  ret = tf_session.TF_SessionRunCallable(self._session._session,
Traceback (most recent call last):
  File "/Users/funks/Developer/innvestigate-issues/open/issue_238_v3", line 35, in <module>
    run_analysis(model2)
  File "/Users/funks/Developer/innvestigate-issues/open/issue_238_v3", line 29, in run_analysis
    analyzer.analyze(x)
  File "/Users/funks/Library/Caches/pypoetry/virtualenvs/innvestigate-issues-W0iScZgu-py3.10/lib/python3.10/site-packages/innvestigate/analyzer/network_base.py", line 250, in analyze
    self.create_analyzer_model()
  File "/Users/funks/Library/Caches/pypoetry/virtualenvs/innvestigate-issues-W0iScZgu-py3.10/lib/python3.10/site-packages/innvestigate/analyzer/network_base.py", line 196, in create_analyzer_model
    self._analyzer_model = kmodels.Model(
  File "/Users/funks/Library/Caches/pypoetry/virtualenvs/innvestigate-issues-W0iScZgu-py3.10/lib/python3.10/site-packages/tensorflow/python/training/tracking/base.py", line 629, in _method_wrapper
    result = method(self, *args, **kwargs)
  File "/Users/funks/Library/Caches/pypoetry/virtualenvs/innvestigate-issues-W0iScZgu-py3.10/lib/python3.10/site-packages/keras/engine/functional.py", line 146, in __init__
    self._init_graph_network(inputs, outputs)
  File "/Users/funks/Library/Caches/pypoetry/virtualenvs/innvestigate-issues-W0iScZgu-py3.10/lib/python3.10/site-packages/tensorflow/python/training/tracking/base.py", line 629, in _method_wrapper
    result = method(self, *args, **kwargs)
  File "/Users/funks/Library/Caches/pypoetry/virtualenvs/innvestigate-issues-W0iScZgu-py3.10/lib/python3.10/site-packages/keras/engine/functional.py", line 181, in _init_graph_network
    base_layer_utils.create_keras_history(self._nested_outputs)
  File "/Users/funks/Library/Caches/pypoetry/virtualenvs/innvestigate-issues-W0iScZgu-py3.10/lib/python3.10/site-packages/keras/engine/base_layer_utils.py", line 175, in create_keras_history
    _, created_layers = _create_keras_history_helper(tensors, set(), [])
  File "/Users/funks/Library/Caches/pypoetry/virtualenvs/innvestigate-issues-W0iScZgu-py3.10/lib/python3.10/site-packages/keras/engine/base_layer_utils.py", line 253, in _create_keras_history_helper
    processed_ops, created_layers = _create_keras_history_helper(
  File "/Users/funks/Library/Caches/pypoetry/virtualenvs/innvestigate-issues-W0iScZgu-py3.10/lib/python3.10/site-packages/keras/engine/base_layer_utils.py", line 253, in _create_keras_history_helper
    processed_ops, created_layers = _create_keras_history_helper(
  File "/Users/funks/Library/Caches/pypoetry/virtualenvs/innvestigate-issues-W0iScZgu-py3.10/lib/python3.10/site-packages/keras/engine/base_layer_utils.py", line 253, in _create_keras_history_helper
    processed_ops, created_layers = _create_keras_history_helper(
  [Previous line repeated 3 more times]
  File "/Users/funks/Library/Caches/pypoetry/virtualenvs/innvestigate-issues-W0iScZgu-py3.10/lib/python3.10/site-packages/keras/engine/base_layer_utils.py", line 251, in _create_keras_history_helper
    constants[i] = backend.function([], op_input)([])
  File "/Users/funks/Library/Caches/pypoetry/virtualenvs/innvestigate-issues-W0iScZgu-py3.10/lib/python3.10/site-packages/keras/backend.py", line 4275, in __call__
    fetched = self._callable_fn(*array_vals,
  File "/Users/funks/Library/Caches/pypoetry/virtualenvs/innvestigate-issues-W0iScZgu-py3.10/lib/python3.10/site-packages/tensorflow/python/client/session.py", line 1480, in __call__
    ret = tf_session.TF_SessionRunCallable(self._session._session,
tensorflow.python.framework.errors_impl.InvalidArgumentError: You must feed a value for placeholder tensor 'dense_2_input' with dtype float and shape [?,50]
	 [[{{node dense_2_input}}]]
```
yap231995 commented 1 year ago

Hello, I am also encountering this issue. How do I work around it? I saw that you can try switching to a Dense layer. Is there code I could reference?

adrhill commented 1 year ago

The workaround using a Dense layer is described here: https://github.com/albermax/innvestigate/issues/283#issuecomment-1276112045
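
For convenience, here is a minimal sketch of that idea, assuming the workaround amounts to folding the trained BatchNormalization statistics into an equivalent Dense layer. The helper `batchnorm_to_dense` and the layer indexing below are illustrative, not part of iNNvestigate, so please double-check against the linked comment:

```
import numpy as np
from keras.layers import Dense
from keras.models import Sequential

def batchnorm_to_dense(bn):
    """Build a Dense layer that reproduces a fitted BatchNormalization layer.

    At inference time BatchNormalization computes the affine map
        y = gamma * (x - moving_mean) / sqrt(moving_var + eps) + beta,
    which a Dense layer with a diagonal kernel can express exactly.
    """
    gamma, beta, mean, var = bn.get_weights()  # assumes center=True, scale=True (defaults)
    scale = gamma / np.sqrt(var + bn.epsilon)
    dense = Dense(len(scale))
    dense.build((None, len(scale)))
    dense.set_weights([np.diag(scale), beta - mean * scale])
    return dense

# Rebuild model2 from the example above with the BatchNormalization layer folded away:
folded = Sequential()
folded.add(Dense(10, input_shape=input_shape))
folded.add(batchnorm_to_dense(model2.layers[1]))
folded.add(Dense(2))
folded.layers[0].set_weights(model2.layers[0].get_weights())
folded.layers[2].set_weights(model2.layers[2].get_weights())

analyzer = innvestigate.create_analyzer("gradient", folded)
analyzer.analyze(x)
```

Since BatchNormalization is purely affine at inference time, the folded model should produce the same predictions as the trained model2, so no retraining is needed.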