sicara / tf-explain

Interpretability Methods for tf.keras models with Tensorflow 2.x
https://tf-explain.readthedocs.io
MIT License

Example or support for .fit_generator #100

Open twsl opened 4 years ago

twsl commented 4 years ago

I'm currently using tensorflow.keras.preprocessing.image.ImageDataGenerator with .flow_from_directory, which generates batches of data so that I can use .fit_generator instead of .fit, and I'm not sure how to build the validation subset. An example would be awesome!

twsl commented 4 years ago

A possible solution could look like this:

valid_batches = valid_datagen.flow_from_directory(...)
(x, y) = valid_batches.next()  # one batch of images and one-hot labels

# keep up to five images whose one-hot label has its argmax at class 0
validation_class_zero = (np.array([
    el for el, label in zip(x, y)
    if np.all(np.argmax(label) == 0)
][0:5]), None)

But when I try to select the first conv layer in InceptionResNetV2, I get the following error: ValueError: Input 0 of layer conv2d is incompatible with the layer: expected ndim=4, found ndim=1. Full shape received: [0]. I'm not sure whether this is related to my workaround or to something else.

RaphaelMeudec commented 4 years ago

Terribly sorry for not answering earlier. My guess is that the condition np.all(np.argmax(label) == 0) is not met when you use the flow_from_directory, so your validation class zero is empty and hence has shape [0]. To validate this, could you tell what y looks like?

I'm looking for the next release to support ImageDataGenerator and tf.data.Dataset, which should solve this.
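In the meantime, one workaround (a sketch, not part of tf-explain) is to keep drawing batches from the generator until enough class-zero samples have been collected, instead of filtering a single batch that may contain none. The generator below is simulated with random numpy data purely for illustration; with a real Keras DirectoryIterator you would pass the iterator itself in place of `fake_batches()`:

```python
import numpy as np

def fake_batches(batch_size=4, num_classes=3, seed=0):
    """Stand-in for valid_datagen.flow_from_directory(...): yields (x, y) batches
    of images and one-hot labels, indefinitely."""
    rng = np.random.default_rng(seed)
    while True:
        x = rng.random((batch_size, 300, 300, 3))
        y = np.eye(num_classes)[rng.integers(0, num_classes, batch_size)]
        yield x, y

def collect_class_samples(batches, target_class=0, n_samples=5):
    """Draw batches until n_samples images of target_class are found."""
    collected = []
    for x, y in batches:
        for el, label in zip(x, y):
            if np.argmax(label) == target_class:
                collected.append(el)
                if len(collected) == n_samples:
                    # same (images, None) tuple shape as in the snippet above
                    return (np.array(collected), None)
    return (np.array(collected), None)

validation_class_zero = collect_class_samples(fake_batches())
print(validation_class_zero[0].shape)  # (5, 300, 300, 3)
```

Because the loop spans multiple batches, it avoids the empty-array case even when a single batch happens to contain no class-zero images.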

twsl commented 4 years ago

Terribly sorry for not answering earlier 😄 Looks like your thought was on point; the condition np.all is apparently not always met.

(x,y) = valid_batches.next()
for el, label in zip(x,y):
    print(el.shape)
    print(label.shape)
    print(np.argmax(label))

validation_class_zero = (np.array([
    el for el, label in zip(x,y)
    if np.all(np.argmax(label) == 0)
][0:5]), None)
print(validation_class_zero)

(300, 300, 3)
(3,)
0
(300, 300, 3)
(3,)
0
(300, 300, 3)
(3,)
1
(300, 300, 3)
(3,)
1
(300, 300, 3)
(3,)
1
(300, 300, 3)
(3,)
1
(array([], dtype=float64), None)

This is probably because I rely on transfer learning and do not have a sufficient number of samples.
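One quick way to check whether a batch contains enough samples of each class before filtering is to sum the one-hot labels column-wise (a plain numpy check, independent of tf-explain; the `y` below is hypothetical example data in the shape flow_from_directory returns with categorical class mode):

```python
import numpy as np

# Hypothetical one-hot labels for a batch of four images over three classes
y = np.array([
    [1, 0, 0],
    [0, 1, 0],
    [0, 1, 0],
    [0, 0, 1],
])

# Column-wise sum gives the number of samples per class in the batch
per_class_counts = y.sum(axis=0).astype(int)
print(per_class_counts)  # [1 2 1]
```

If the count for the target class is zero, any filter on that batch will produce an empty array with shape [0], which matches the ndim=1 error above.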

VictorW96 commented 4 years ago

I have the same problem. Is there support for the ImageDataGenerator in version 2.1.0? I haven't found it yet.