raghakot / keras-vis

Neural network visualization toolkit for keras
https://raghakot.github.io/keras-vis
MIT License

'NoneType' error #180

Closed parthnatekar closed 5 years ago

parthnatekar commented 5 years ago

I get this error when running the visualize_cam example at https://github.com/raghakot/keras-vis/blob/master/examples/vggnet/activation_maximization.ipynb

I upgraded to the latest version and still get the same error.

from vis.visualization import visualize_cam

for modifier in [None, 'guided', 'relu']:
    plt.figure()
    f, ax = plt.subplots(1, 2)
    plt.suptitle("vanilla" if modifier is None else modifier)
    for i, img in enumerate([img1, img2]):    
        # 20 is the imagenet index corresponding to `ouzel`
        grads = visualize_cam(model, layer_idx, filter_indices=1, 
                              seed_input=img, backprop_modifier=modifier)        
        # Lets overlay the heatmap onto original image.    
        jet_heatmap = np.uint8(cm.jet(grads)[..., :3] * 255)
        ax[i].imshow(overlay(jet_heatmap, img))
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-19-d5fe349b4e57> in <module>()
      9         # 20 is the imagenet index corresponding to `ouzel`
     10         grads = visualize_cam(model, layer_idx, filter_indices=1, 
---> 11                               seed_input=img, backprop_modifier=modifier)        
     12         # Lets overlay the heatmap onto original image.
     13         jet_heatmap = np.uint8(cm.jet(grads)[..., :3] * 255)

1 frames
/usr/local/lib/python3.6/dist-packages/vis/visualization/saliency.py in visualize_cam_with_losses(input_tensor, losses, seed_input, penultimate_layer, grad_modifier)
    178     # Generate heatmap by computing weight * output over feature maps
    179     output_dims = utils.get_img_shape(penultimate_output)[2:]
--> 180     heatmap = np.zeros(shape=output_dims, dtype=K.floatx())
    181     for i, w in enumerate(weights):
    182         if channel_idx == -1:

TypeError: 'NoneType' object cannot be interpreted as an integer

parthnatekar commented 5 years ago

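It seems the problem arises from the model not having a fixed input shape. When the spatial dimensions are undefined (None), output_dims in visualize_cam_with_losses contains None, and np.zeros(shape=output_dims, ...) fails with the TypeError above.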
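
A quick way to check for this before calling visualize_cam is to look for None entries in the model's shape tuple. The sketch below is only an illustration under assumptions: `undefined_dims` is a hypothetical helper (not part of keras-vis), and the two shape tuples are made-up examples of what `model.input_shape` might return for a variable-size and a fixed-size model.

```python
import operator


def undefined_dims(shape, skip_batch=True):
    """Return the indices of dimensions that are None in a Keras-style shape tuple.

    Hypothetical helper, not part of keras-vis. The batch dimension (index 0)
    is normally None, so it is skipped by default.
    """
    start = 1 if skip_batch else 0
    return [i for i, dim in enumerate(shape) if i >= start and dim is None]


# Assumed example shapes: a model built without fixed spatial dimensions
# versus one built for fixed 224x224 RGB input.
variable_shape = (None, None, None, 3)
fixed_shape = (None, 224, 224, 3)

print(undefined_dims(variable_shape))  # [1, 2]
print(undefined_dims(fixed_shape))     # []

# A None dimension cannot be converted to an integer size. This is the same
# kind of TypeError that the traceback shows coming out of np.zeros.
try:
    operator.index(variable_shape[1])
except TypeError as exc:
    error_message = str(exc)
    print(error_message)
```

With a real model, the same check would be `undefined_dims(model.input_shape)`. If it returns a non-empty list, rebuilding the model with a concrete input size (for example, height and width fixed to the size of the seed images) is one way to give the CAM code integer dimensions to work with.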
