albermax / innvestigate

A toolbox to iNNvestigate neural networks' predictions!

[BUG] LRPZ Flatten #286

Closed by tydickinson29 1 year ago

tydickinson29 commented 1 year ago

Hi! I would like to use your package to explain my CNN, specifically with the LRP and input*gradient methods. The architecture is based on a U-Net and is shown below:

import tensorflow as tf
from tensorflow import keras
import innvestigate

def double_convolution(x, num_filters, **kwargs):
    x = keras.layers.Conv2D(num_filters, activation='relu', **kwargs)(x)
    x = keras.layers.Conv2D(num_filters, activation='relu', **kwargs)(x)
    return x

def downsample(x, num_filters, **kwargs):
    convs = double_convolution(x, num_filters, **kwargs)
    pooler = keras.layers.MaxPool2D(2, padding=kwargs['padding'],
                                    data_format=kwargs['data_format'])(convs)
    drop = keras.layers.Dropout(0.3)(pooler)
    return convs, drop

def upsample(x, num_filters, concat_data, **kwargs):
    # channel axis: last for channels_last, 1 (right after the batch axis) for channels_first
    axis = -1 if kwargs['data_format'] == 'channels_last' else 1
    rev_conv = keras.layers.Conv2DTranspose(num_filters, strides=2, **kwargs)(x)
    concat = keras.layers.concatenate([rev_conv, concat_data], axis=axis)
    drop = keras.layers.Dropout(0.3)(concat)
    convs = double_convolution(drop, num_filters, **kwargs)
    return convs

def build_unet(**kwargs):
    kwargs.setdefault('kernel_size', 3)
    kwargs.setdefault('padding', 'same')
    kwargs.setdefault('data_format', 'channels_last')
    kwargs.setdefault('kernel_initializer', 'he_normal')

    #input
    inputs = keras.layers.Input(shape=(64,128,18))

    #encoder
    conv1, layer1 = downsample(inputs, 64, **kwargs)
    conv2, layer2 = downsample(layer1, 128, **kwargs)
    conv3, layer3 = downsample(layer2, 256, **kwargs)
    conv4, layer4 = downsample(layer3, 512, **kwargs)

    #bottleneck
    layer5 = double_convolution(layer4, 1024, **kwargs)

    #decoder
    layer6 = upsample(layer5, 512, conv4, **kwargs)
    layer7 = upsample(layer6, 256, conv3, **kwargs)
    layer8 = upsample(layer7, 128, conv2, **kwargs)
    layer9 = upsample(layer8, 64, conv1, **kwargs)

    #output
    kwargs.pop('kernel_size')
    output = keras.layers.Conv2D(filters=1, kernel_size=1, activation='sigmoid', **kwargs)(layer9)

    model = keras.Model(inputs=inputs, outputs=output, name='my_unet')
    return model

model = build_unet()
model.compile(optimizer=tf.keras.optimizers.Adam(),
              loss='binary_crossentropy',
              metrics=[tf.keras.metrics.AUC(curve='PR')])

#I load in my data here, but omitted. 
#The predictors have shape (n, 64, 128, 18) and the predictand has shape (n, 64, 128, 1) where n is my sample size
history = model.fit(x=train_data,
                    y=train_labels,
                    batch_size=52,
                    epochs=100,
                    verbose=2,
                    callbacks=[checkpoint, csv],
                    validation_split=0.1) #using ModelCheckpoint() and CSVLogger()

#I train in a separate script, but wanted to show process above
model = keras.models.load_model(path_to_saved_model)
lrp = innvestigate.analyzer.relevance_based.relevance_analyzer.LRPZ(model)
output = lrp.analyze(predictors[:1])  # slice keeps the batch dimension: shape (1, 64, 128, 18)

However, I get the following error:

AttributeError: Exception encountered when calling layer "flatten" (type Flatten).

'list' object has no attribute 'shape'

Call arguments received:
  • inputs=['tf.Tensor(shape=(None, 64, 128, 1), dtype=float32)']

I am slightly confused, especially since my architecture does not have any Flatten layers to begin with. Any insight into the origin of the error, and anything I can do to work around it, would be greatly appreciated. Thank you!

Faranehhad commented 1 year ago

Hello, I have the same error. Did you find a solution? Thanks.

michaelmontalbano commented 1 year ago

Hi, I ran into this same problem, but got past it by changing line 119 of analyzer/network_base.py from

model_output = klayers.Flatten()(model_output) 

to

model_output = klayers.Flatten()(model.outputs[0])

I think the problem may be that model_output is a list[Tensor] (see earlier in the prepare_model function), whereas Flatten expects a single tensor. However, there are still problems further on that I have not yet worked out. Hope this helps.
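The failure mode described here can be sketched without TensorFlow. In the sketch below, `FakeTensor` and `unwrap_single_output` are hypothetical names that are not part of iNNvestigate or Keras. The only point is that code reading `.shape` breaks when it receives a one-element list instead of the tensor itself, which matches the message in the traceback.

```python
from dataclasses import dataclass
from typing import Any, Tuple


@dataclass
class FakeTensor:
    """Hypothetical stand-in for a symbolic tensor; it only carries a shape."""
    shape: Tuple[Any, ...]


def unwrap_single_output(outputs):
    """Return the lone element of a one-item list or tuple, else the input unchanged."""
    if isinstance(outputs, (list, tuple)) and len(outputs) == 1:
        return outputs[0]
    return outputs


# A model with a single output can still expose that output wrapped in a list.
model_outputs = [FakeTensor(shape=(None, 64, 128, 1))]

try:
    model_outputs.shape  # what a shape-reading layer effectively attempts
except AttributeError as err:
    message = str(err)

tensor = unwrap_single_output(model_outputs)
print(message)       # the same AttributeError text as in the report
print(tensor.shape)  # the shape of the unwrapped tensor
```

Indexing with `model.outputs[0]`, as in the change above, has the same effect as the hypothetical helper for a single-output model.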

tydickinson29 commented 1 year ago

Thanks @michaelmontalbano for the comment! It is indeed a shape issue. I built a new keras.Model that appends a Reshape layer to my model's output, so the output is already flat:

model = keras.models.load_model(path_to_saved_model)
orig_shape = model.output_shape
new_shape = orig_shape[1]*orig_shape[2]
reshp = tf.keras.layers.Reshape((new_shape,), input_shape=orig_shape)(model.layers[-1].output)
new_model = keras.Model(inputs=model.inputs, outputs=[reshp])
lrp = innvestigate.analyzer.relevance_based.relevance_analyzer.LRPZ(new_model)

Now, running lrp.analyze() works as expected!

Note: I also ended up adding tf.compat.v1.disable_eager_execution() at the beginning of the program.
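The reshape workaround above multiplies only the two spatial dimensions, which is enough here because the output has a single channel. As a minimal sketch, assuming a Keras-style output shape whose first entry is the batch dimension, the target size can be computed from every non-batch dimension instead. `flattened_size` is a hypothetical helper name, not part of Keras or iNNvestigate.

```python
from math import prod


def flattened_size(output_shape):
    """Multiply every non-batch dimension of a Keras-style output shape."""
    dims = output_shape[1:]  # drop the leading batch dimension (usually None)
    if any(d is None for d in dims):
        raise ValueError(f"non-batch dimensions must be known, got {output_shape}")
    return prod(dims)


print(flattened_size((None, 64, 128, 1)))  # 8192, the same as 64 * 128
print(flattened_size((None, 64, 128, 3)))  # 24576, covers a multi-channel output
```

The result could be passed to tf.keras.layers.Reshape((size,)) in place of orig_shape[1]*orig_shape[2].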