PaddlePaddle / InterpretDL

InterpretDL: Interpretation of Deep Learning Models,基于『飞桨』的模型可解释性算法库。
https://interpretdl.readthedocs.io
Apache License 2.0
237 stars 38 forks source link

运行example_grad_cam_cv.ipynb时报错 #40

Closed littletomatodonkey closed 2 years ago

littletomatodonkey commented 2 years ago

(optional) Briefly introduce yourself.


🐛 Bug

运行example_grad_cam_cv.ipynb时会报错

To Reproduce

Steps to reproduce the behavior:

  1. 直接运行该notebook
  2. 报错信息如下
---------------------------------------------------------------------------
AxisError                                 Traceback (most recent call last)
/paddle/temp/ipykernel_14605/90004086.py in <module>
      9         label=None,
     10         visual=True,
---> 11         save_path=None)

/paddle/code/gry/InterpretDL/interpretdl/interpreter/gradient_cam.py in interpret(self, inputs, target_layer_name, label, resize_to, crop_to, visual, save_path)
     99 
    100         # the core algorithm
--> 101         cam_weights = np.mean(g, (2, 3), keepdims=True)
    102         heatmap = cam_weights * f
    103         heatmap = heatmap.mean(1)

<__array_function__ internals> in mean(*args, **kwargs)

/usr/local/python3.7.0/lib/python3.7/site-packages/numpy/core/fromnumeric.py in mean(a, axis, dtype, out, keepdims, where)
   3439 
   3440     return _methods._mean(a, axis=axis, dtype=dtype,
-> 3441                           out=out, **kwargs)
   3442 
   3443 

/usr/local/python3.7.0/lib/python3.7/site-packages/numpy/core/_methods.py in _mean(a, axis, dtype, out, keepdims, where)
    165     is_float16_result = False
    166 
--> 167     rcount = _count_reduce_items(arr, axis, keepdims=keepdims, where=where)
    168     if rcount == 0 if where is True else umr_any(rcount == 0, axis=None):
    169         warnings.warn("Mean of empty slice.", RuntimeWarning, stacklevel=2)

/usr/local/python3.7.0/lib/python3.7/site-packages/numpy/core/_methods.py in _count_reduce_items(arr, axis, keepdims, where)
     74         items = nt.intp(1)
     75         for ax in axis:
---> 76             items *= arr.shape[mu.normalize_axis_index(ax, arr.ndim)]
     77     else:
     78         # TODO: Optimize case when `where` is broadcast along a non-reduction

报错的命令为

image
holyseven commented 2 years ago

Hi,

报错原因是输入给numpy.mean的数据格式不对。

我这边没有出现这个问题,有可能是图片读入的时候出问题了。可以查看一下以下代码在jupyter下,是否能正常显示图片?

img_path = 'assets/catdog.png'
x = Image.fromarray(read_image(img_path)[0])
x

image

如果图片正常显示的话,麻烦将python环境信息贴一下,我看看是否能复现上述问题。

littletomatodonkey commented 2 years ago

图片可以正常显示的,我刚才测试了一下,paddle2.3.1可以正常使用,我自己编译的paddle无法正常显示(commit id : a635a8a5c4ec914f57058bdcccf854d620ce5f42 , 约8月29号编译的)

image

image

littletomatodonkey commented 2 years ago

更新一下,paddle2.3.2也可以,可能与我编译的包仅用来做推理有关系,多谢