marcotcr / lime

Lime: Explaining the predictions of any machine learning classifier
BSD 2-Clause "Simplified" License

Lime image: change tutorial so that preprocessing function is folded inside predict_fn #69

Closed marcotcr closed 5 years ago

marcotcr commented 7 years ago

Different neural networks may preprocess images differently (e.g. resnet vs inception). Our image explainer assumes the image is in a format that skimage can understand. In the tutorial, we should remove inception-specific stuff and just use a normal image as input, with preprocessing + prediction inside the prediction function.
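
A minimal sketch of the pattern described above, assuming the explainer hands the prediction function a batch of plain images shaped (N, H, W, 3) in the same value range as the image given to the explainer. `make_predict_fn`, `toy_preprocess`, and `toy_model` are hypothetical names used only for illustration; a real setup would pass the network's own preprocessing function and predict method instead.

```python
import numpy as np

def make_predict_fn(model_predict, preprocess):
    """Wrap a model so that it accepts raw image batches.

    model_predict and preprocess are placeholders for the network's own
    predict method and preprocessing function (assumptions, not LIME API).
    """
    def predict_fn(images):
        # Copy and cast so model-specific preprocessing never mutates
        # the arrays owned by the caller.
        batch = np.array(images, dtype=np.float64)
        return model_predict(preprocess(batch))
    return predict_fn

# Hypothetical stand-ins so the sketch runs without a deep learning framework.
def toy_preprocess(batch):
    return batch / 127.5 - 1.0  # scale 0..255 to -1..1

def toy_model(batch):
    # Two fake "class probabilities" derived from mean brightness.
    p = (batch.mean(axis=(1, 2, 3)) + 1.0) / 2.0
    return np.stack([p, 1.0 - p], axis=1)

predict_fn = make_predict_fn(toy_model, toy_preprocess)
raw_batch = np.random.randint(0, 256, size=(4, 8, 8, 3)).astype(np.uint8)
probs = predict_fn(raw_batch)
print(probs.shape)
```

With this shape of wrapper, swapping resnet for inception would only change the two arguments passed to `make_predict_fn`; the image given to the explainer stays an ordinary image.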

hainguyenct commented 6 years ago

Hi @marcotcr, if I use an image preprocessed for ResNet50:

from keras.applications.resnet50 import ResNet50
from keras.preprocessing import image
from keras.applications.resnet50 import preprocess_input, decode_predictions
model = ResNet50(weights='imagenet')
img_path = 'cat.jpg'
img = image.load_img(img_path, target_size=(224, 224))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)
#preds = model.predict(x)
print x
print('Predicted:', decode_predictions(preds, top=3)[0])
plt.imshow(x[0]/255)

The code above gives an error like this: [screenshot: screen shot 2018-02-11 at 15 23 01]

How do I visualize this image with plt.imshow? Thank you

marcotcr commented 6 years ago

That code works for me, but the image looks funny because of the preprocessing. If you run imshow before preprocess_input, it looks fine. The way to use LIME with resnet is to have a predict_fn like the following:

def predict_fn(x):
    # Return the predictions; without the return, LIME would receive None.
    return model.predict(preprocess_input(x))
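
On why the preprocessed image looks funny: as far as I know, ResNet50's preprocess_input in Keras flips RGB to BGR and subtracts per-channel ImageNet means without rescaling, so the array no longer holds displayable pixel values. A rough sketch of undoing that for plotting follows. The mean values and the helper names `caffe_preprocess` and `undo_caffe_preprocess` are assumptions written out here, not calls into Keras.

```python
import numpy as np

# Per-channel ImageNet means in BGR order (assumed to match Keras 'caffe' mode).
BGR_MEAN = np.array([103.939, 116.779, 123.68])

def caffe_preprocess(rgb):
    """Stand-in for the preprocessing step: RGB -> BGR, then subtract means."""
    bgr = np.asarray(rgb, dtype=np.float64)[..., ::-1]
    return bgr - BGR_MEAN

def undo_caffe_preprocess(x):
    """Invert the step above and return a uint8 RGB image for display."""
    bgr = np.asarray(x, dtype=np.float64) + BGR_MEAN
    rgb = bgr[..., ::-1]
    # Round away floating point noise before casting back to pixel values.
    return np.clip(np.rint(rgb), 0, 255).astype(np.uint8)

original = np.random.randint(0, 256, size=(224, 224, 3)).astype(np.uint8)
restored = undo_caffe_preprocess(caffe_preprocess(original))
print(np.array_equal(original, restored))
```

Passing `restored` to plt.imshow should then show the picture normally.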

doccrate commented 6 years ago

I am also struggling to understand precisely what should be used as the image. I am using some of the TensorFlow mobilenet retraining code. My predict function takes output from the following:

def read_rawimage_from_image_file(file_name, input_height=299, input_width=299, input_mean=0, input_std=255):
    input_name = "file_reader"
    output_name = "image_reader"
    file_reader = tf.read_file(file_name, input_name)
    if file_name.endswith(".png"):
        image_reader = tf.image.decode_png(file_reader, channels = 3, name='png_reader')
    elif file_name.endswith(".gif"):
        image_reader = tf.squeeze(tf.image.decode_gif(file_reader, name='gif_reader'))
    elif file_name.endswith(".bmp"):
        image_reader = tf.image.decode_bmp(file_reader, name='bmp_reader')
    else:
        image_reader = tf.image.decode_jpeg(file_reader, channels = 3, name='jpeg_reader')
    sess = tf.Session()
    result = sess.run(image_reader)

    return result

e.g. a tensorflow decoded imagefile. However, I get the following error from LIME:

Traceback (most recent call last):
  File "/home/cait/anaconda3/lib/python3.6/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/home/cait/anaconda3/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/cait/tensorflow-for-poets-2/scripts/label_image_fcn.py", line 243, in <module>
    explanation = explainer.explain_instance(image_raw, prediction_fn, top_labels=5, hide_color=0, num_samples=1000)
  File "/home/cait/anaconda3/lib/python3.6/site-packages/lime/lime_image.py", line 187, in explain_instance
    batch_size=batch_size)
  File "/home/cait/anaconda3/lib/python3.6/site-packages/lime/lime_image.py", line 248, in data_labels
    preds = classifier_fn(np.array(imgs))
  File "/home/cait/tensorflow-for-poets-2/scripts/label_image_fcn.py", line 153, in prediction_fn
    input_std=input_std)
  File "/home/cait/tensorflow-for-poets-2/scripts/label_image_fcn.py", line 91, in read_tensor_from_rawimage
    resized = tf.image.resize_bilinear(dims_expander, [input_height, input_width])
  File "/home/cait/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/gen_image_ops.py", line 2372, in resize_bilinear
    align_corners=align_corners, name=name)
  File "/home/cait/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/home/cait/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3392, in create_op
    op_def=op_def)
  File "/home/cait/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1734, in __init__
    control_input_ops)
  File "/home/cait/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1570, in _create_c_op
    raise ValueError(str(e))
ValueError: Shape must be rank 4 but is rank 5 for 'ResizeBilinear_1' (op: 'ResizeBilinear') with input shapes: [1,10,14,128,3], [2].

Can you please tell me what I should use as input instead of the raw image binary, so that I can use both the mobilenet scripts and LIME? Thanks.

marcotcr commented 6 years ago

I'm guessing you called explain_instance with an image that is [1, *, *, *], when it should be [*, *, *].
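
A small sketch of that shape fix, assuming the image was loaded with a leading batch axis. `drop_batch_axis` is a hypothetical helper, not part of LIME. Note also that, per the traceback above, the classifier function is called with np.array(imgs), so it already receives a 4-D batch and probably should not add another axis itself.

```python
import numpy as np

def drop_batch_axis(image):
    """Return an (H, W, C) array, removing a leading axis of size 1 if present."""
    image = np.asarray(image)
    if image.ndim == 4 and image.shape[0] == 1:
        return image[0]
    if image.ndim != 3:
        raise ValueError("expected (H, W, C) or (1, H, W, C), got %r" % (image.shape,))
    return image

batched = np.zeros((1, 299, 299, 3), dtype=np.uint8)
single = drop_batch_axis(batched)
print(single.shape)  # (299, 299, 3)
```

The result of `drop_batch_axis` is what would be passed as the first argument to explain_instance.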