pytorch / captum

Model interpretability and understanding for PyTorch
https://captum.ai
BSD 3-Clause "New" or "Revised" License

Visualization feature is not working #1165

Status: Open · opened by eszaher 1 year ago

eszaher commented 1 year ago

🐛 Bug

I am following the "Model Interpretation for Pretrained ResNet Model" tutorial. When I try to create the visualization, I receive an error.

To Reproduce

Steps to reproduce the behavior:

  1. Run the tutorial at https://captum.ai/tutorials/Resnet_TorchVision_Interpret#Model-Interpretation-for-Pretrained-ResNet-Model up to the visualization cell, which fails with the following traceback:
          6 _ = viz.visualize_image_attr(np.transpose(attributions_ig.squeeze().cpu().detach().numpy(), (1,2,0)),
          7                              np.transpose(transformed_img.squeeze().cpu().detach().numpy(), (1,2,0)),
          8                              method='heat_map',
          9                              cmap=default_cmap,
         10                              show_colorbar=True,
         11                              sign='positive',
         12                              outlier_perc=1)

    File /scratch/anaconda3/envs/XAI_BM/lib/python3.10/site-packages/captum/attr/_utils/visualization.py:250, in visualize_image_attr(attr, original_image, method, sign, plt_fig_axis, outlier_perc, cmap, alpha_overlay, show_colorbar, title, fig_size, use_pyplot)
        248 plt_axis.set_yticklabels([])
        249 plt_axis.set_xticklabels([])
    --> 250 plt_axis.grid(b=False)
        252 heat_map = None
        253 # Show original image

    File /scratch/anaconda3/envs/XAI_BM/lib/python3.10/site-packages/matplotlib/axes/_base.py:3194, in _AxesBase.grid(self, visible, which, axis, **kwargs)
       3192 _api.check_in_list(['x', 'y', 'both'], axis=axis)
       3193 if axis in ['x', 'both']:
    -> 3194     self.xaxis.grid(visible, which=which, **kwargs)
       3195 if axis in ['y', 'both']:
       3196     self.yaxis.grid(visible, which=which, **kwargs)

    File /scratch/anaconda3/envs/XAI_BM/lib/python3.10/site-packages/matplotlib/axis.py:1660, in Axis.grid(self, visible, which, **kwargs)
       1657 if which in ['major', 'both']:
       1658     gridkw['gridOn'] = (not self._major_tick_kw['gridOn']
       1659                         if visible is None else visible)
    -> 1660     self.set_tick_params(which='major', **gridkw)
       1661 self.stale = True

    File /scratch/anaconda3/envs/XAI_BM/lib/python3.10/site-packages/matplotlib/axis.py:932, in Axis.set_tick_params(self, which, reset, **kwargs)
        919 """
        920 Set appearance parameters for ticks, ticklabels, and gridlines.
        921 (...)
        929 gridlines.
        930 """
        931 _api.check_in_list(['major', 'minor', 'both'], which=which)
    --> 932 kwtrans = self._translate_tick_params(kwargs)
        934 # the kwargs are stored in self._major/minor_tick_kw so that any
        935 # future new ticks will automatically get them
        936 if reset:

    File /scratch/anaconda3/envs/XAI_BM/lib/python3.10/site-packages/matplotlib/axis.py:1076, in Axis._translate_tick_params(kw, reverse)
       1074 for key in kw_:
       1075     if key not in allowed_keys:
    -> 1076         raise ValueError(
       1077             "keyword %s is not recognized; valid keywords are %s"
       1078             % (key, allowed_keys))
       1079 kwtrans.update(kw_)
       1080 return kwtrans

    ValueError: keyword grid_b is not recognized; valid keywords are ['size', 'width', 'color', 'tickdir', 'pad', 'labelsize', 'labelcolor', 'zorder', 'gridOn', 'tick1On', 'tick2On', 'label1On', 'label2On', 'length', 'direction', 'left', 'bottom', 'right', 'top', 'labelleft', 'labelbottom', 'labelright', 'labeltop', 'labelrotation', 'grid_agg_filter', 'grid_alpha', 'grid_animated', 'grid_antialiased', 'grid_clip_box', 'grid_clip_on', 'grid_clip_path', 'grid_color', 'grid_dash_capstyle', 'grid_dash_joinstyle', 'grid_dashes', 'grid_data', 'grid_drawstyle', 'grid_figure', 'grid_fillstyle', 'grid_gapcolor', 'grid_gid', 'grid_in_layout', 'grid_label', 'grid_linestyle', 'grid_linewidth', 'grid_marker', 'grid_markeredgecolor', 'grid_markeredgewidth', 'grid_markerfacecolor', 'grid_markerfacecoloralt', 'grid_markersize', 'grid_markevery', 'grid_mouseover', 'grid_path_effects', 'grid_picker', 'grid_pickradius', 'grid_rasterized', 'grid_sketch_params', 'grid_snap', 'grid_solid_capstyle', 'grid_solid_joinstyle', 'grid_transform', 'grid_url', 'grid_visible', 'grid_xdata', 'grid_ydata', 'grid_zorder', 'grid_aa', 'grid_c', 'grid_ds', 'grid_ls', 'grid_lw', 'grid_mec', 'grid_mew', 'grid_mfc', 'grid_mfcalt', 'grid_ms']
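The traceback bottoms out in matplotlib rather than in Captum: the keyword b is forwarded as grid_b, which set_tick_params rejects. A minimal sketch that isolates the failing call without Captum or the tutorial, assuming only that matplotlib is installed; the helper name old_grid_keyword_rejected is hypothetical. On matplotlib versions that still accept b (possibly with a deprecation warning) it returns False.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend so no display is required
import matplotlib.pyplot as plt


def old_grid_keyword_rejected() -> bool:
    """Return True if the installed matplotlib rejects ax.grid(b=False)."""
    fig, ax = plt.subplots()
    try:
        # Same keyword call as line 250 of the old captum visualization.py
        ax.grid(b=False)
    except (TypeError, ValueError) as exc:
        # Only the start of the message; the full keyword list is very long
        print(f"matplotlib {matplotlib.__version__}: {type(exc).__name__}: {str(exc)[:60]}")
        return True
    finally:
        plt.close(fig)  # runs before either return statement
    print(f"matplotlib {matplotlib.__version__}: b= is still accepted")
    return False


if __name__ == "__main__":
    old_grid_keyword_rejected()
```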

Expected behavior

The attribution heat map should be displayed without raising an error.

Environment




 - PyTorch: '2.1.0.dev20230711+cu121'
 - OS: Linux
 - Captum from source: pip install git+https://github.com/pytorch/captum.git
 - Build command you used (if compiling from source):
 - Python 3.10.9
 - CUDA version: 12.1
 - GPU: A100
znacer commented 1 year ago

Hi, I managed to reproduce this error. It can be fixed by changing line 250 of captum/attr/_utils/visualization.py from

    plt_axis.grid(b=False)

to

    plt_axis.grid(visible=False)

After this change, it worked on my side.
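A minimal sketch of the corrected call in isolation, assuming matplotlib 3.5 or newer, where the first parameter of Axes.grid is named visible. It also shows a positional form, ax.grid(False), which sidesteps the keyword name entirely because the first positional parameter is the on/off flag under both the old (b) and new (visible) names. Whether Captum's own fix uses either form is not stated in this thread; the helper gridlines_hidden is hypothetical.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend so no display is required
import matplotlib.pyplot as plt


def gridlines_hidden(ax) -> bool:
    """Return True if none of the major x/y gridlines on `ax` is visible."""
    lines = list(ax.get_xgridlines()) + list(ax.get_ygridlines())
    return all(not line.get_visible() for line in lines)


if __name__ == "__main__":
    fig, ax = plt.subplots()

    ax.grid(visible=True)   # turn the grid on first so the change is observable
    ax.grid(visible=False)  # corrected keyword form (matplotlib >= 3.5)
    print("keyword form hides grid:", gridlines_hidden(ax))

    ax.grid(True)
    ax.grid(False)          # positional form: first parameter is the on/off flag
    print("positional form hides grid:", gridlines_hidden(ax))

    plt.close(fig)
```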

The b argument of matplotlib.axes.Axes.grid was renamed to visible in matplotlib 3.5, and the old name seems to have been removed in later releases. If you cannot or don't want to modify Captum files, downgrading to matplotlib 3.4.3 might work (not tested).
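If editing the installed Captum file is not an option, another workaround (not suggested in this thread, and not part of Captum or matplotlib) is to wrap matplotlib.axes.Axes.grid at runtime so that a legacy b= keyword is forwarded as visible=. A minimal sketch, assuming matplotlib 3.5 or newer; install_grid_b_shim is a hypothetical name, and it would be called once before viz.visualize_image_attr(...).

```python
import functools

import matplotlib

matplotlib.use("Agg")  # headless backend for this example
import matplotlib.pyplot as plt
from matplotlib.axes import Axes


def install_grid_b_shim() -> None:
    """Hypothetical patch: map the legacy `b=` keyword of Axes.grid to `visible=`.

    Assumes matplotlib >= 3.5, where the parameter is named `visible`.
    Calling it more than once is harmless: an already-patched method is kept.
    """
    original = Axes.grid
    if getattr(original, "_accepts_legacy_b", False):
        return  # already patched

    @functools.wraps(original)
    def grid(self, *args, **kwargs):
        if "b" in kwargs:
            legacy = kwargs.pop("b")
            # An explicit visible= wins if both keywords are given.
            kwargs.setdefault("visible", legacy)
        return original(self, *args, **kwargs)

    grid._accepts_legacy_b = True  # marker used for the idempotency check
    Axes.grid = grid


if __name__ == "__main__":
    install_grid_b_shim()
    fig, ax = plt.subplots()
    ax.grid(b=False)  # accepted again after patching
    plt.close(fig)
```

Patching a library class affects every Axes in the process, so this is best treated as a temporary measure until an updated Captum is installed.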

znacer commented 1 year ago

The issue has already been fixed in #1118. You can install Captum directly from source to get the fix:
pip install git+https://github.com/pytorch/captum.git
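To confirm which Captum and matplotlib versions an environment actually uses after reinstalling, one option is to query the installed distributions. A minimal sketch using only the standard library's importlib.metadata (Python 3.8+); installed_versions is a hypothetical helper that reports None for packages that are not installed.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Optional


def installed_versions(*packages: str) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if absent."""
    result: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    for name, ver in installed_versions("captum", "matplotlib", "torch").items():
        print(f"{name}: {ver or 'not installed'}")
```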

kamibrumi commented 1 year ago

Thanks, that solved my issue over here, too!