sicara / tf-explain

Interpretability Methods for tf.keras models with TensorFlow 2.x
https://tf-explain.readthedocs.io
MIT License
1.02k stars, 111 forks

Support for Binary Classification Models #73

Closed: EmanueleGhelfi closed this issue 5 years ago

EmanueleGhelfi commented 5 years ago

Hi, first of all, thank you for tf-explain.

Currently I'm trying to use tf-explain with a model like this one:

import tensorflow as tf
# Assumed values for illustration; initial_filters and kernel_size are not given in the issue.
initial_filters = 16
kernel_size = (3, 3)
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Conv2D(initial_filters, kernel_size, activation='relu', input_shape=(256, 256, 3), padding="same"))
model.add(tf.keras.layers.MaxPooling2D((2, 2))) # 128, 128
model.add(tf.keras.layers.Conv2D(initial_filters*2, kernel_size, activation='relu', padding="same"))
model.add(tf.keras.layers.MaxPooling2D((2, 2))) # 64, 64
model.add(tf.keras.layers.Conv2D(initial_filters*4, kernel_size, activation='relu', padding="same"))
model.add(tf.keras.layers.MaxPooling2D((2, 2))) # 32, 32
model.add(tf.keras.layers.Conv2D(initial_filters*8, kernel_size, activation='relu', padding="same"))
model.add(tf.keras.layers.MaxPooling2D((2, 2))) # 16, 16
model.add(tf.keras.layers.Conv2D(initial_filters*16, kernel_size, activation='relu', padding="same"))
model.add(tf.keras.layers.MaxPooling2D((2, 2))) # 8, 8
model.add(tf.keras.layers.Conv2D(initial_filters*32, kernel_size, activation='relu', padding="same"))
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(64, activation="relu"))
model.add(tf.keras.layers.Dense(1))  # single unit, no activation: one logit for the binary task

This model is used for binary classification on the cats vs. dogs dataset. With the tf-explain GradCAM callback, it does not seem to produce correct results.

I think this is due to the following line in the code:

https://github.com/sicara/tf-explain/blob/master/tf_explain/core/grad_cam.py#L85

where you basically take the output at the index of the selected class. A better approach would be to check the shape of the model output and handle a single-unit output (as in a binary classifier) differently from a multi-class output.
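A minimal Grad-CAM sketch in plain TensorFlow, showing one way such a shape check could work. `select_class_score` and `grad_cam` are hypothetical names, not tf-explain's API, and treating the single logit as the class-1 score (and its negation as the class-0 score) is an assumption about how a Dense(1) binary model is labeled. It also assumes a built Keras model whose `inputs` and `outputs` are defined.

```python
import tensorflow as tf


def select_class_score(predictions: tf.Tensor, class_index: int) -> tf.Tensor:
    """Hypothetical helper: the per-example score Grad-CAM differentiates.

    predictions has shape (batch, n_outputs).
    """
    if predictions.shape[-1] == 1:
        # Assumption: a single logit scores class 1; its negation scores
        # class 0, since log(p0 / p1) = -log(p1 / p0).
        if class_index not in (0, 1):
            raise ValueError("a single-output model only has classes 0 and 1")
        logit = predictions[:, 0]
        return logit if class_index == 1 else -logit
    # Multi-class output: keep the column of the requested class.
    return predictions[:, class_index]


def grad_cam(model: tf.keras.Model, images: tf.Tensor, layer_name: str,
             class_index: int) -> tf.Tensor:
    """Minimal Grad-CAM sketch returning (batch, H, W) heatmaps in [0, 1]."""
    # Model that exposes both the chosen conv layer and the final output.
    grad_model = tf.keras.Model(
        inputs=model.inputs,
        outputs=[model.get_layer(layer_name).output, model.outputs[0]],
    )
    with tf.GradientTape() as tape:
        conv_outputs, predictions = grad_model(images)
        score = select_class_score(predictions, class_index)
    # Gradient of the class score w.r.t. the conv feature maps.
    grads = tape.gradient(score, conv_outputs)
    # Channel weights: spatially averaged gradients, shape (batch, channels).
    weights = tf.reduce_mean(grads, axis=(1, 2))
    # Weighted sum of feature maps, keeping only positive evidence.
    cam = tf.nn.relu(tf.einsum("bhwc,bc->bhw", conv_outputs, weights))
    # Scale each map to [0, 1]; the epsilon avoids division by zero.
    cam_max = tf.reduce_max(cam, axis=(1, 2), keepdims=True)
    return cam / (cam_max + 1e-8)
```

Negating the score also works if the output layer uses a sigmoid, because the gradient of 1 - p is the negated gradient of p.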

What do you think about this issue and this (possible) fix?

RaphaelMeudec commented 5 years ago

@EmanueleGhelfi Hi Emanuele! I think this would add complexity to the code for one particular case. Maybe you could change your final Dense(1) layer to a Dense(2, activation='softmax')? Then you would be able to select either class (class_index 0 or 1).
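A sketch of that suggestion, assuming the same architecture as in the issue. `build_two_class_model` is a hypothetical helper, its default `initial_filters` and `kernel_size` values are assumptions, and the loss assumes integer labels 0 and 1 (for example cat = 0, dog = 1).

```python
import tensorflow as tf


def build_two_class_model(initial_filters=16, kernel_size=(3, 3),
                          input_shape=(256, 256, 3)):
    """Same architecture as in the issue, but with a 2-unit softmax head."""
    model = tf.keras.models.Sequential()
    model.add(tf.keras.Input(shape=input_shape))
    filters = initial_filters
    # Five conv + max-pool stages, doubling the filters each time.
    for _ in range(5):
        model.add(tf.keras.layers.Conv2D(filters, kernel_size, activation="relu",
                                         padding="same"))
        model.add(tf.keras.layers.MaxPooling2D((2, 2)))
        filters *= 2
    # Last conv stage (initial_filters * 32 filters), no pooling.
    model.add(tf.keras.layers.Conv2D(filters, kernel_size, activation="relu",
                                     padding="same"))
    model.add(tf.keras.layers.Flatten())
    model.add(tf.keras.layers.Dense(64, activation="relu"))
    # One output per class, so class indices 0 and 1 both select a column.
    model.add(tf.keras.layers.Dense(2, activation="softmax"))
    # Integer labels (0 or 1) pair with sparse categorical cross-entropy.
    model.compile(optimizer="adam",
                  loss="sparse_categorical_crossentropy",
                  metrics=["accuracy"])
    return model
```

With two softmax outputs each class has its own column, so tf-explain's GradCAM can be pointed at either one through its `class_index` argument.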

EmanueleGhelfi commented 5 years ago

Yes, sure. I think you could at least raise an exception and add this to the documentation, if it's not already there.

(Email reply to Raphael Meudec's comment of Thu, 29 Aug 2019, 14:00.)

RaphaelMeudec commented 5 years ago

I'm trying to avoid raising exceptions so that training doesn't break just because of a callback. I might add a warning, though. Thanks!
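A sketch of the warning approach, assuming the check runs on the model's predictions right before the class score is selected. `check_class_index` is a hypothetical helper, not part of tf-explain; it uses Python's `warnings` module instead of raising, so training keeps running.

```python
import warnings

import tensorflow as tf


def check_class_index(predictions: tf.Tensor, class_index: int) -> bool:
    """Hypothetical check: warn (never raise) when class_index does not map
    cleanly to an output column. Returns True if the index is usable as-is."""
    n_outputs = predictions.shape[-1]
    if n_outputs == 1:
        # Single-unit output: indexing by class is ambiguous for binary models.
        warnings.warn(
            "The model has a single output unit (binary classifier); "
            f"class_index={class_index} may not select the class you expect. "
            "Consider a Dense(2, activation='softmax') output layer.",
            UserWarning,
        )
        return False
    if not 0 <= class_index < n_outputs:
        # Out-of-range index: warn so the caller can skip this explanation.
        warnings.warn(
            f"class_index={class_index} is out of range for {n_outputs} outputs.",
            UserWarning,
        )
        return False
    return True
```

A callback could call this first and simply skip the Grad-CAM image for that epoch when it returns False.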