WongKinYiu / yolov7

Implementation of paper - YOLOv7: Trainable bag-of-freebies sets new state-of-the-art for real-time object detectors
GNU General Public License v3.0

heatmap / Grad-CAM? #130

Open qutyyds opened 2 years ago

qutyyds commented 2 years ago

I tried to view the heatmap but failed. Could you add a visualization option?

WongKinYiu commented 2 years ago

The heatmap is the prediction of objectness. OK, I will add an ipynb demo.
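Because the heatmap is the sigmoid of the objectness logit, it can be reproduced from a raw head output in a few lines. The sketch below assumes the layout implied by the notebook's indexing `output[1][0][0, 0, :, :, 4]`, namely (batch, anchors, grid_y, grid_x, 5 + num_classes) with objectness at channel 4. The helper name `objectness_heatmaps` is hypothetical, and a random NumPy array stands in for a real model output.

```python
import numpy as np


def objectness_heatmaps(head_output):
    """Turn one raw detection-head output into per-anchor objectness maps.

    Assumed layout (from the notebook's indexing):
    (batch, anchors, grid_y, grid_x, 5 + num_classes), objectness logit at channel 4.
    """
    logits = head_output[0, :, :, :, 4]      # first image: (anchors, grid_y, grid_x)
    return 1.0 / (1.0 + np.exp(-logits))     # sigmoid -> probabilities in (0, 1)


# Synthetic stand-in for a real head output: 1 image, 3 anchors, 20x20 grid, 80 classes.
rng = np.random.default_rng(0)
fake_head = rng.normal(size=(1, 3, 20, 20, 85)).astype(np.float32)
maps = objectness_heatmaps(fake_head)
print(maps.shape)  # (3, 20, 20): one heatmap per anchor
```

Each `maps[a]` can then be shown with `plt.imshow`, which is how the objectness heatmaps in the notebook are displayed.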

qutyyds commented 2 years ago

Thank you very much for your reply and contribution. When you finish this, could you also add instructions describing the modifications? I want to implement this feature in YOLOR. Recently I tried to modify it following YOLOv5, but it failed. I think seeing the YOLOv7 changes may help me.

WongKinYiu commented 2 years ago

https://github.com/WongKinYiu/yolov7/blob/main/tools/visualization.ipynb

alaap001 commented 2 years ago

Hey, thank you for providing this. I tried it, and when running the second-to-last cell I initially got this error: `RuntimeError: Can't call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead.`

Then I changed lines such as `obj1 = output[1][0][0, 0, :, :, 4].sigmoid().cpu().numpy()` to `obj1 = output[1][0][0, 0, :, :, 4].sigmoid().cpu().detach().numpy()`.

This fixed that error, but I then got an error while plotting in the last cell of the notebook. Here is the stack trace:

```
File ~/anaconda3/envs/env_py38/lib/python3.8/site-packages/IPython/core/formatters.py:339, in BaseFormatter.__call__(self, obj)
    337     pass
    338 else:
--> 339     return printer(obj)
    340 # Finally look for special method names
    341 method = get_real_method(obj, self.print_method)

File ~/anaconda3/envs/env_py38/lib/python3.8/site-packages/IPython/core/pylabtools.py:151, in print_figure(fig, fmt, bbox_inches, base64, **kwargs)
    148     from matplotlib.backend_bases import FigureCanvasBase
    149     FigureCanvasBase(fig)
--> 151 fig.canvas.print_figure(bytes_io, **kw)
    152 data = bytes_io.getvalue()
    153 if fmt == 'svg':

File ~/anaconda3/envs/env_py38/lib/python3.8/site-packages/matplotlib/backend_bases.py:2100, in FigureCanvasBase.print_figure(self, filename, dpi, facecolor, edgecolor, orientation, format, bbox_inches, **kwargs)
   2096 ctx = (renderer._draw_disabled()
   2097        if hasattr(renderer, '_draw_disabled')
   2098        else suppress())
   2099 with ctx:
-> 2100     self.figure.draw(renderer)
   2101 bbox_artists = kwargs.pop("bbox_extra_artists", None)
   2102 bbox_inches = self.figure.get_tightbbox(renderer,
   2103         bbox_extra_artists=bbox_artists)

File ~/anaconda3/envs/env_py38/lib/python3.8/site-packages/matplotlib/artist.py:38, in allow_rasterization.<locals>.draw_wrapper(artist, renderer, *args, **kwargs)
     35     if artist.get_agg_filter() is not None:
     36         renderer.start_filter()
---> 38     return draw(artist, renderer, *args, **kwargs)
     39 finally:
     40     if artist.get_agg_filter() is not None:

File ~/anaconda3/envs/env_py38/lib/python3.8/site-packages/matplotlib/figure.py:1735, in Figure.draw(self, renderer)
   1732             # ValueError can occur when resizing a window.
   1734     self.patch.draw(renderer)
-> 1735     mimage._draw_list_compositing_images(
   1736         renderer, self, artists, self.suppressComposite)
   1738     renderer.close_group('figure')
   1739 finally:

File ~/anaconda3/envs/env_py38/lib/python3.8/site-packages/matplotlib/image.py:137, in _draw_list_compositing_images(renderer, parent, artists, suppress_composite)
    135 if not_composite or not has_images:
    136     for a in artists:
--> 137         a.draw(renderer)
    138 else:
    139     # Composite any adjacent images together
    140     image_group = []

File ~/anaconda3/envs/env_py38/lib/python3.8/site-packages/matplotlib/artist.py:38, in allow_rasterization.<locals>.draw_wrapper(artist, renderer, *args, **kwargs)
     35     if artist.get_agg_filter() is not None:
     36         renderer.start_filter()
---> 38     return draw(artist, renderer, *args, **kwargs)
     39 finally:
     40     if artist.get_agg_filter() is not None:

File ~/anaconda3/envs/env_py38/lib/python3.8/site-packages/matplotlib/axes/_base.py:2630, in _AxesBase.draw(self, renderer, inframe)
   2627         a.draw(renderer)
   2628     renderer.stop_rasterizing()
-> 2630 mimage._draw_list_compositing_images(renderer, self, artists)
   2632 renderer.close_group('axes')
   2633 self.stale = False

File ~/anaconda3/envs/env_py38/lib/python3.8/site-packages/matplotlib/image.py:137, in _draw_list_compositing_images(renderer, parent, artists, suppress_composite)
    135 if not_composite or not has_images:
    136     for a in artists:
--> 137         a.draw(renderer)
    138 else:
    139     # Composite any adjacent images together
    140     image_group = []

File ~/anaconda3/envs/env_py38/lib/python3.8/site-packages/matplotlib/artist.py:38, in allow_rasterization.<locals>.draw_wrapper(artist, renderer, *args, **kwargs)
     35     if artist.get_agg_filter() is not None:
     36         renderer.start_filter()
---> 38     return draw(artist, renderer, *args, **kwargs)
     39 finally:
     40     if artist.get_agg_filter() is not None:

File ~/anaconda3/envs/env_py38/lib/python3.8/site-packages/matplotlib/image.py:625, in _ImageBase.draw(self, renderer, *args, **kwargs)
    623     self._draw_unsampled_image(renderer, gc)
    624 else:
--> 625     im, l, b, trans = self.make_image(
    626         renderer, renderer.get_image_magnification())
    627     if im is not None:
    628         renderer.draw_image(gc, l, b, im)

File ~/anaconda3/envs/env_py38/lib/python3.8/site-packages/matplotlib/image.py:914, in AxesImage.make_image(self, renderer, magnification, unsampled)
    912 bbox = Bbox(np.array([[x1, y1], [x2, y2]]))
    913 transformed_bbox = TransformedBbox(bbox, trans)
--> 914 return self._make_image(
    915     self._A, bbox, transformed_bbox,
    916     self.get_clip_box() or self.axes.bbox,
    917     magnification, unsampled=unsampled)

File ~/anaconda3/envs/env_py38/lib/python3.8/site-packages/matplotlib/image.py:478, in _ImageBase._make_image(self, A, in_bbox, out_bbox, clip_bbox, magnification, unsampled, round_to_pixel_border)
    476 A_scaled += 0.1
    477 # resample the input data to the correct resolution and shape
--> 478 A_resampled = _resample(self, A_scaled, out_shape, t)
    479 # done with A_scaled now, remove from namespace to be sure!
    480 del A_scaled

File ~/anaconda3/envs/env_py38/lib/python3.8/site-packages/matplotlib/image.py:197, in _resample(image_obj, data, out_shape, transform, resample, alpha)
    195 if resample is None:
    196     resample = image_obj.get_resample()
--> 197 _image.resample(data, out, transform,
    198                 _interpd_[interpolation],
    199                 resample,
    200                 alpha,
    201                 image_obj.get_filternorm(),
    202                 image_obj.get_filterrad())
    203 return out

ValueError: Unsupported dtype

<Figure size 1152x864 with 12 Axes>
```

Is it due to the changes I made?

WongKinYiu commented 2 years ago

Try `.detach().sigmoid().cpu().numpy()` or `.sigmoid().detach().cpu().numpy()`.
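Both orders give the same values. `detach()` only cuts the tensor out of the autograd graph, so it can come before or after `sigmoid()`; detaching first simply avoids recording the `sigmoid` op. A minimal sketch follows, using a synthetic tensor as a stand-in for `output[1][0]` (not the real model). It reproduces the `RuntimeError` and shows two ways around it: `.detach()`, or running the forward pass under `torch.no_grad()` so that no graph is built at all.

```python
import torch

torch.manual_seed(0)

# Synthetic stand-in for a head output computed with autograd enabled:
# multiplying by a tensor that requires grad makes the result require grad too.
scale = torch.ones(1, requires_grad=True)
head = torch.randn(1, 3, 20, 20, 85) * scale

raised = False
try:
    head[0, 0, :, :, 4].sigmoid().cpu().numpy()
except RuntimeError as err:
    raised = True
    print("fails:", err)

# Fix 1: detach before .numpy(); detaching before or after sigmoid() gives the same values.
obj1 = head[0, 0, :, :, 4].sigmoid().cpu().detach().numpy()
obj1b = head[0, 0, :, :, 4].detach().sigmoid().cpu().numpy()

# Fix 2: run the forward pass under torch.no_grad(), so no graph is recorded.
with torch.no_grad():
    head_ng = torch.randn(1, 3, 20, 20, 85) * scale
obj2 = head_ng[0, 0, :, :, 4].sigmoid().cpu().numpy()

print(obj1.shape, obj2.shape)  # (20, 20) (20, 20)
```

For pure visualization, wrapping the model call in `torch.no_grad()` is usually the cleaner choice, since it also saves the memory the graph would otherwise occupy.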

alaap001 commented 2 years ago

I tried both, and I get the same error.

qutyyds commented 2 years ago

Try `.detach().sigmoid().cpu().numpy()`; it works for me.

alaap001 commented 2 years ago

A quick update for anyone facing this issue: it was caused by the matplotlib version. I hit it with matplotlib 3.2.2; after updating to 3.5.2, the error was gone.
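The thread does not pin down why the matplotlib version matters. One plausible explanation, and it is only an assumption: if the model runs in half precision (for example after `model.half()` on a GPU), the heatmaps are `float16` arrays, and older matplotlib releases could not resample `float16` data in `imshow`. That would match the `Unsupported dtype` raised from `_image.resample` in the trace above. If upgrading is not an option, casting to `float32` before plotting should sidestep the problem regardless of version (in PyTorch, calling `.float()` before `.numpy()` has the same effect). `to_plottable` below is a hypothetical helper, not part of the notebook.

```python
import numpy as np


def to_plottable(heatmap):
    """Return the heatmap as float32 if it is float16, otherwise unchanged.

    Assumption: the 'Unsupported dtype' error comes from float16 (half-precision)
    input, which older matplotlib versions could not resample.
    """
    arr = np.asarray(heatmap)
    if arr.dtype == np.float16:
        arr = arr.astype(np.float32)
    return arr


half_map = np.random.default_rng(0).random((20, 20)).astype(np.float16)
safe_map = to_plottable(half_map)
print(safe_map.dtype)  # float32
# Then plot as before, e.g. plt.imshow(safe_map)
```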

qutyyds commented 2 years ago

I want to visualize which areas of the image contribute most to the class prediction. Is that possible, for example with something like https://github.com/jacobgil/pytorch-grad-cam (see https://github.com/jacobgil/pytorch-grad-cam/blob/master/tutorials/EigenCAM%20for%20YOLO5.ipynb)?
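EigenCAM, used in the linked YOLOv5 tutorial, needs no class gradients: it takes the activations of one layer and projects each spatial position onto their first principal component. Below is a NumPy sketch of that core step. It is not the pytorch-grad-cam implementation, the function name `eigen_cam` is hypothetical, and the sign-flip heuristic is an assumption added here because SVD signs are arbitrary. Note that the resulting map is class-agnostic: it shows where the layer responds strongly, not evidence for one specific class. Class-specific maps need a gradient-based method such as Grad-CAM.

```python
import numpy as np


def eigen_cam(activations, eps=1e-8):
    """Class-agnostic saliency from one layer's activations, shape (C, H, W).

    Sketch of the EigenCAM idea: project each spatial position onto the first
    principal component of the channel activations.
    """
    c, h, w = activations.shape
    flat = activations.reshape(c, h * w).T        # (H*W, C): one row per position
    flat = flat - flat.mean(axis=0)               # center each channel
    _, _, vt = np.linalg.svd(flat, full_matrices=False)
    cam = flat @ vt[0]                            # projection on the first component
    # SVD sign is arbitrary; this heuristic (an assumption) flips the map so it
    # correlates positively with the channel-mean activation.
    mean_map = activations.mean(axis=0).reshape(-1)
    if np.dot(cam, mean_map - mean_map.mean()) < 0:
        cam = -cam
    cam = np.maximum(cam.reshape(h, w), 0)        # keep positive responses only
    return cam / (cam.max() + eps)                # scale to [0, 1]


# Synthetic activations: weak noise plus a 3x3 blob that every channel responds to.
rng = np.random.default_rng(0)
acts = rng.normal(scale=0.1, size=(16, 10, 10))
acts[:, 2:5, 6:9] += 1.0 + rng.random((16, 1, 1))
cam = eigen_cam(acts)
print(np.unravel_index(cam.argmax(), cam.shape))  # a position inside the blob
```

With a real model, the activations could be captured with a forward hook (`module.register_forward_hook`) on a backbone or neck layer, and the map then resized to the input image and overlaid on it.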

SergioG-M commented 1 year ago

Can anyone explain a bit how this works? What exactly is in outputs[1]? And what is each of its elements (outputs[1][0], outputs[1][1], etc.)?

I guess that for each outputs[1][x] there is one plot per RGB channel, but I don't understand what they represent. Moreover, outputs[1] seems to have 4 elements in the notebook, but my custom model only has 3.
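As far as can be read from yolov7's `models/yolo.py` (worth checking against your own checkout), the detection head in eval mode returns a tuple. Its first item is the concatenated, decoded predictions, and its second item is the list of raw per-level outputs, so `outputs[1]` is that list. It has one entry per detection level: 3 for P3 to P5 configs such as `yolov7.yaml`, and 4 for the P6 configs (e.g. `yolov7-w6.yaml`). Each entry presumably has shape (batch, anchors, grid_y, grid_x, 5 + num_classes) with 3 anchors per level, so the three plots per entry would be one objectness map per anchor rather than RGB channels. 4 levels × 3 anchors = 12, which matches the `<Figure size 1152x864 with 12 Axes>` line in the stack trace above, and a custom model trained from a P5 config would have 3 entries. The helper below is hypothetical and uses zero-filled arrays in place of real outputs.

```python
import numpy as np


def describe_head_outputs(raw_outputs, img_size):
    """Summarize raw detection-head outputs, one array per detection level.

    Assumed per-level layout: (batch, anchors, grid_y, grid_x, 5 + num_classes),
    where the 5 leading channels are box (x, y, w, h) and objectness.
    """
    rows = []
    for level, out in enumerate(raw_outputs):
        _, num_anchors, grid_y, grid_x, channels = out.shape
        rows.append({
            "level": level,
            "stride": img_size // grid_x,   # input pixels per grid cell
            "anchors": num_anchors,         # one objectness heatmap per anchor
            "grid": (grid_y, grid_x),
            "num_classes": channels - 5,
        })
    return rows


# Synthetic 4-level (P3 to P6) head on a 640x640 input with 80 classes.
fake_outputs = [np.zeros((1, 3, 640 // s, 640 // s, 85), dtype=np.float32)
                for s in (8, 16, 32, 64)]
summary = describe_head_outputs(fake_outputs, img_size=640)
for row in summary:
    print(row)
print("heatmaps to plot:", sum(row["anchors"] for row in summary))  # heatmaps to plot: 12
```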