keisen / tf-keras-vis

Neural network visualization toolkit for tf.keras
https://keisen.github.io/tf-keras-vis-docs/
MIT License

AttributeError: 'NoneType' object has no attribute 'ndim' #101

Open returnv01d opened 10 months ago

returnv01d commented 10 months ago

Hello, I have a problem using the attention visualizations (GradCAM++ in this case). My model has a preprocessing layer that rescales images like this: layers.Rescaling(1./255, input_shape=(constants.IMG_HEIGHT, constants.IMG_WIDTH, 3)). The problem is that when I load images and pass them to the visualizer directly, I get the following error:

Traceback (most recent call last):
  File "predict_2.py", line 52, in <module>
    cam = gradcam(score, X, penultimate_layer=-1)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "lib/python3.11/site-packages/tf_keras_vis/gradcam_plus_plus.py", line 106, in __call__
    score_values = tf.reshape(score_values, score_values.shape + (1, ) * (grads.ndim - 1))
                                                                          ^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'ndim'

But when I rescale the images before visualizing, I get no error. However, the visualizations don't seem to work correctly: the heatmap is flat and doesn't show any attention (see the example at the bottom). Here is my code for loading the images for visualization (adapted from the examples in the docs):

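import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from tensorflow.keras.utils import load_img  # older TF versions: tensorflow.keras.preprocessing.image

import constants  # project module providing SHAPE, IMG_HEIGHT, IMG_WIDTH

# gradcam (a GradcamPlusPlus instance) and score are defined elsewhere (not shown)
image_titles = ['Dog', 'Cat', 'Cat']
img1 = load_img("./kaggle/test/18.jpg", target_size=constants.SHAPE)
img2 = load_img("./kaggle/test/19.jpg", target_size=constants.SHAPE)
img3 = load_img("./kaggle/test/20.jpg", target_size=constants.SHAPE)
# np.asarray on a PIL image gives a uint8 array of shape (H, W, 3)
img1 = np.asarray(img1)
img2 = np.asarray(img2)
img3 = np.asarray(img3)
images = np.asarray([img1, img2, img3])

X = images

cam = gradcam(score, X, penultimate_layer=-1)

# Overlay each heatmap on its source image
f, ax = plt.subplots(nrows=1, ncols=3, figsize=(12, 4))
for i, title in enumerate(image_titles):
    heatmap = np.uint8(cm.jet(cam[i])[..., :3] * 255)
    ax[i].set_title(title, fontsize=16)
    ax[i].imshow(images[i])
    ax[i].imshow(heatmap, cmap='jet', alpha=0.5)
    ax[i].axis('off')
plt.tight_layout()
plt.show()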

And when I use X = images/255 there is no error, but the heatmap is flat, like this: [image]
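A flat heatmap is consistent with the images being scaled twice: X = images/255 divides once, and the model's Rescaling(1./255) layer divides again, so the network sees pixel values of at most 1/255. The NumPy-only sketch below illustrates just that arithmetic. The random `images` array is a hypothetical stand-in for the loaded photos, and `rescaling` is a hypothetical helper that only mimics the multiply-by-1/255 behaviour of the Keras layer; it is not the layer itself.

```python
import numpy as np

# Hypothetical stand-in for np.asarray([img1, img2, img3]): three 8x8 RGB images
# stored as uint8, the dtype np.asarray gives for PIL images.
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(3, 8, 8, 3)).astype(np.uint8)


def rescaling(x):
    """Hypothetical helper mimicking layers.Rescaling(1./255): scale pixels by 1/255."""
    return x * (1.0 / 255)


# X = images/255, then the model's Rescaling layer divides by 255 a second time.
double_scaled = rescaling(images / 255)

# Only the dtype changes outside the model; the Rescaling layer scales exactly once.
single_scaled = rescaling(images.astype(np.float32))

print(images.dtype)  # uint8
print(double_scaled.max() <= 1 / 255)  # True: the network sees near-black inputs
print(single_scaled.min(), single_scaled.max())  # spans the usual [0, 1] range
```

Under this reading, converting the uint8 batch to a float type without dividing by 255 would leave the scaling to the model, as it was during training.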

Should I stop using the Rescaling layer, or is there another way to load the images in this case?

uxdiin commented 1 month ago

You have to pass a tensor object as the gradcam input, not a NumPy array. I made the same mistake.
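One plausible reading of the NoneType error, assuming tf-keras-vis passes the input array's dtype through to its gradient tape: np.asarray on a PIL image yields uint8, and TensorFlow does not compute gradients with respect to integer tensors, so the gradient comes back as None. The sketch below uses a one-line stand-in for a model that starts with Rescaling(1./255) (not the poster's model) and a hypothetical helper `grad_wrt_input`. It suggests that the tensor passed in should be float32, with the values left in 0-255 so the model's own Rescaling layer does the scaling.

```python
import numpy as np
import tensorflow as tf

# Hypothetical uint8 batch shaped like np.asarray([img1, img2, img3]).
images = np.random.default_rng(0).integers(0, 256, size=(3, 8, 8, 3)).astype(np.uint8)


def grad_wrt_input(x):
    """Hypothetical helper: gradient of a stand-in 'model' output w.r.t. its input."""
    with tf.GradientTape() as tape:
        tape.watch(x)
        # Stand-in for a model whose first layer is Rescaling(1./255).
        y = tf.reduce_sum(tf.cast(x, tf.float32) / 255.0)
    return tape.gradient(y, x)


uint8_input = tf.convert_to_tensor(images)  # dtype stays uint8
float_input = tf.convert_to_tensor(images.astype(np.float32))  # float32, values still 0-255

print(grad_wrt_input(uint8_input))  # None: no gradient w.r.t. an integer tensor
print(grad_wrt_input(float_input).shape)  # one gradient value per input pixel
```

Under that assumption, X = tf.convert_to_tensor(images.astype(np.float32)) would be the input to try with gradcam(score, X, penultimate_layer=-1); note that a tensor built directly from the uint8 array would still be uint8.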