thomasp85 / lime

Local Interpretable Model-Agnostic Explanations (R port of original Python package)
https://lime.data-imaginist.com/

Error: explain function on CNN keras model #169

Closed KhawlaSeddiki closed 4 years ago

KhawlaSeddiki commented 4 years ago

I am using a Keras model that contains convolutional layers.

Model
_____________________________________________________________________________________________________
Layer (type)                                                       Output Shape                                               Param #                
=====================================================================================================
conv1d_8 (Conv1D)                                                  (None, 1896, 64)                                           384                    
_____________________________________________________________________________________________________________________________________________________
batch_normalization_8 (BatchNormalization)                         (None, 1896, 64)                                           256                    
_____________________________________________________________________________________________________________________________________________________
leaky_re_lu_8 (LeakyReLU)                                          (None, 1896, 64)                                           0                      
_____________________________________________________________________________________________________________________________________________________
dropout_10 (Dropout)                                               (None, 1896, 64)                                           0                      
_____________________________________________________________________________________________________________________________________________________
conv1d_9 (Conv1D)                                                  (None, 1886, 32)                                           22560                  
_____________________________________________________________________________________________________________________________________________________
batch_normalization_9 (BatchNormalization)                         (None, 1886, 32)                                           128                    
_____________________________________________________________________________________________________________________________________________________
leaky_re_lu_9 (LeakyReLU)                                          (None, 1886, 32)                                           0                      
_____________________________________________________________________________________________________________________________________________________
dropout_11 (Dropout)                                               (None, 1886, 32)                                           0                      
_____________________________________________________________________________________________________________________________________________________
conv1d_10 (Conv1D)                                                 (None, 1866, 16)                                           10768                  
_____________________________________________________________________________________________________________________________________________________
batch_normalization_10 (BatchNormalization)                        (None, 1866, 16)                                           64                     
_____________________________________________________________________________________________________________________________________________________
leaky_re_lu_10 (LeakyReLU)                                         (None, 1866, 16)                                           0                      
_____________________________________________________________________________________________________________________________________________________
dropout_12 (Dropout)                                               (None, 1866, 16)                                           0                      
_____________________________________________________________________________________________________________________________________________________
conv1d_11 (Conv1D)                                                 (None, 1826, 8)                                            5256                   
_____________________________________________________________________________________________________________________________________________________
batch_normalization_11 (BatchNormalization)                        (None, 1826, 8)                                            32                     
_____________________________________________________________________________________________________________________________________________________
leaky_re_lu_11 (LeakyReLU)                                         (None, 1826, 8)                                            0                      
_____________________________________________________________________________________________________________________________________________________
dropout_13 (Dropout)                                               (None, 1826, 8)                                            0                      
_____________________________________________________________________________________________________________________________________________________
flatten_2 (Flatten)                                                (None, 14608)                                              0                      
_____________________________________________________________________________________________________________________________________________________
dense_4 (Dense)                                                    (None, 100)                                                1460900                
_____________________________________________________________________________________________________________________________________________________
dropout_14 (Dropout)                                               (None, 100)                                                0                      
_____________________________________________________________________________________________________________________________________________________
dense_5 (Dense)                                                    (None, 5)                                                  505                    
=====================================================================================================================================================
Total params: 1,500,853
Trainable params: 1,500,613
Non-trainable params: 240
model %>% fit(
  my_train, y_train,
  epochs = 3, verbose = 1, batch_size = 256
)

My training set (my_train) and test set (my_test) are 3D arrays of shape (#samples, 1900, 1), reshaped from the flat matrices train and test of shape (#samples, 1900).

I get the following errors when I call the predict_model and lime::explain functions on either the 3D array or the flat matrix:

### Flat matrix
predict_model(x = model, newdata = test, type = 'raw')
 Error in py_call_impl(callable, dots$args, dots$keywords) : 
  ValueError: Error when checking input: expected conv1d_input to have 3 dimensions, but got array with shape (19, 1900) 
### 3D array
predict_model(x = model, newdata = my_test, type = 'raw')
 Error in py_call_impl(callable, dots$args, dots$keywords) : 
  ValueError: Error when checking input: expected conv1d_input to have 3 dimensions, but got array with shape (36100, 1) 

What do I have to change to make this compatible with a Keras CNN model?
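
The shapes in the two errors are consistent with the input being coerced to a 2D matrix before it reaches Keras, since 19 × 1900 = 36100. The base R sketch below (no lime or keras required) shows that as.matrix() collapses a 3D array into a single column, while adding a trailing dimension of 1 to the flat matrix gives the (samples, 1900, 1) shape a Conv1D input expects. Whether lime performs exactly this coercion internally is an assumption and is not confirmed anywhere in this thread; the data values are placeholders.

```r
# Stand-in data with the shapes from the error messages (values are arbitrary).
n_samples <- 19L
n_points  <- 1900L
test <- matrix(0, nrow = n_samples, ncol = n_points)  # flat matrix (19, 1900)

# Add a trailing channel dimension: (19, 1900) -> (19, 1900, 1).
my_test <- array(test, dim = c(n_samples, n_points, 1L))
dim(my_test)    # 19 1900 1

# as.matrix() on a 3D array discards the dimensions and returns one column,
# which matches the (36100, 1) shape reported in the second error.
collapsed <- as.matrix(my_test)
dim(collapsed)  # 36100 1
```

If this is what happens, passing the 3D array directly cannot work, and the reshape has to happen after lime hands the data to the prediction step.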

MichaelPeibo commented 4 years ago

Hi @khawkhaa, I am facing a similar problem with the input data shape of a CNN model. Would you mind sharing how you fixed it? Thanks!

KhawlaSeddiki commented 4 years ago

Hi @MichaelPeibo, I haven't found a solution. I'm still waiting for an answer from the authors.

MichaelPeibo commented 4 years ago

Hi @khawkhaa, I got it working by reshaping my input data. Check this
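
A minimal sketch of what such a reshaping step could look like, offered as an assumption rather than the commenter's actual code. It relies on lime's documented extension point of writing S3 methods for its model_type() and predict_model() generics. The class name conv1d_wrapper and the helpers to_conv1d_input(), wrap_conv1d(), and stub_predict() are hypothetical names introduced only for illustration, and a stub stands in for the trained Keras model so that the snippet runs without keras installed.

```r
# Reshape tabular input (data.frame or matrix, one row per sample) into the
# 3D array (samples, timesteps, 1) that a Conv1D input layer expects.
to_conv1d_input <- function(x) {
  x <- as.matrix(x)
  array(x, dim = c(nrow(x), ncol(x), 1L))
}

# Hypothetical wrapper class: "conv1d_wrapper" is NOT part of lime or keras.
# predict_fn defaults to stats::predict and can be replaced by a stub.
wrap_conv1d <- function(model, predict_fn = stats::predict) {
  structure(list(model = model, predict_fn = predict_fn),
            class = "conv1d_wrapper")
}

# S3 methods for lime's generics model_type() and predict_model().
model_type.conv1d_wrapper <- function(x, ...) "classification"

predict_model.conv1d_wrapper <- function(x, newdata, type, ...) {
  # Reshape first, then predict; return one column per class.
  probs <- x$predict_fn(x$model, to_conv1d_input(newdata))
  as.data.frame(probs)
}

# Stub in place of a trained Keras model: uniform scores over 5 classes.
stub_predict <- function(model, input) {
  matrix(1 / 5, nrow = dim(input)[1], ncol = 5)
}

wrapped <- wrap_conv1d(model = NULL, predict_fn = stub_predict)
flat <- as.data.frame(matrix(rnorm(19 * 1900), nrow = 19))
out <- predict_model.conv1d_wrapper(wrapped, flat, type = "prob")
dim(out)  # 19 5
```

With a real model, the idea would be to pass the wrapped object and a data frame built from the flat matrix to lime::lime() and lime::explain(), so that lime only ever sees tabular data and the reshape happens inside the prediction method. That end-to-end usage is not tested here.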