MisaOgura / flashtorch

Visualization toolkit for neural networks in PyTorch! Demo: https://youtu.be/18Iw4qYqfPo
MIT License

use of custom model throws an error #34

Closed: SarfarazHabib closed this issue 4 years ago.

SarfarazHabib commented 4 years ago

Hi, thanks for your work. I am trying to use the "Activation maximization" notebook with my own custom model. The difference is that my model takes an input of shape (56, 56, 24). The gradient ascent function accepts my model without an error. I apply my own transformation to a randomly generated input, and the resulting input shape is torch.Size([1, 24, 56, 56]).
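For anyone reproducing this setup, here is a minimal sketch of one way to turn a channels-last array of shape (56, 56, 24) into the batched, channels-first shape `torch.Size([1, 24, 56, 56])` described above. It uses NumPy only. The helper name `to_nchw` and the uniform random input are illustrative assumptions, not the poster's actual transformation.

```python
import numpy as np


def to_nchw(hwc: np.ndarray) -> np.ndarray:
    """Convert an (H, W, C) array to a float32 (1, C, H, W) batch.

    Hypothetical helper: it only reorders axes and adds a batch
    dimension. Any normalization is left to the caller.
    """
    if hwc.ndim != 3:
        raise ValueError(f"expected (H, W, C), got shape {hwc.shape}")
    chw = np.transpose(hwc, (2, 0, 1))  # (H, W, C) -> (C, H, W)
    return chw[np.newaxis].astype(np.float32)  # add batch axis -> (1, C, H, W)


# Randomly generated 56x56 input with 24 channels, as in the issue.
rng = np.random.default_rng(0)
random_input = rng.uniform(0.0, 1.0, size=(56, 56, 24))
batch = to_nchw(random_input)
print(batch.shape)  # (1, 24, 56, 56)
```

Passing the result to `torch.from_numpy(batch)` would give a tensor with the same shape.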

When I call the `visualize` function with any intermediate layer and filter, it throws the error below:


```
TypeError                                 Traceback (most recent call last)
----> 1 g_ascent.visualize(layer2_0_conv2, title='layer1_0_conv2')

~/.local/lib/python3.6/site-packages/flashtorch/activmax/gradient_ascent.py in visualize(self, layer, filter_idxs, lr, num_iter, num_subplots, figsize, title, return_output)
    210                                 num_iter,
    211                                 len(filter_idxs),
--> 212                                 title=title)
    213
    214         if return_output:

~/.local/lib/python3.6/site-packages/flashtorch/activmax/gradient_ascent.py in _visualize_filters(self, layer, filter_idxs, num_iter, num_subplots, title)
    347                 standardize_and_clip(output[-1],
    348                                      saturation=0.15,
--> 349                                      brightness=0.7)))
    350
    351         plt.subplots_adjust(wspace=0, hspace=0); # noqa

~/.local/lib/python3.6/site-packages/matplotlib/__init__.py in inner(ax, data, *args, **kwargs)
   1563     def inner(ax, *args, data=None, **kwargs):
   1564         if data is None:
-> 1565             return func(ax, *map(sanitize_sequence, args), **kwargs)
   1566
   1567         bound = new_sig.bind(ax, *args, **kwargs)

~/.local/lib/python3.6/site-packages/matplotlib/cbook/deprecation.py in wrapper(*args, **kwargs)
    356                 f"%(removal)s. If any parameter follows {name!r}, they "
    357                 f"should be pass as keyword, not positionally.")
--> 358         return func(*args, **kwargs)
    359
    360     return wrapper

~/.local/lib/python3.6/site-packages/matplotlib/cbook/deprecation.py in wrapper(*args, **kwargs)
    356                 f"%(removal)s. If any parameter follows {name!r}, they "
    357                 f"should be pass as keyword, not positionally.")
--> 358         return func(*args, **kwargs)
    359
    360     return wrapper

~/.local/lib/python3.6/site-packages/matplotlib/axes/_axes.py in imshow(self, X, cmap, norm, aspect, interpolation, alpha, vmin, vmax, origin, extent, shape, filternorm, filterrad, imlim, resample, url, **kwargs)
   5624                               resample=resample, **kwargs)
   5625
-> 5626         im.set_data(X)
   5627         im.set_alpha(alpha)
   5628         if im.get_clip_path() is None:

~/.local/lib/python3.6/site-packages/matplotlib/image.py in set_data(self, A)
    697                 or self._A.ndim == 3 and self._A.shape[-1] in [3, 4]):
    698             raise TypeError("Invalid shape {} for image data"
--> 699                             .format(self._A.shape))
    700
    701         if self._A.ndim == 3:

TypeError: Invalid shape (56, 56, 24) for image data
```

Can anyone help me with this, please?
MisaOgura commented 4 years ago

Hi @SarfarazHabib,

Thanks for your question.

The error occurs because the `visualize` method currently passes the output of gradient ascent straight to matplotlib, which expects an RGB image (or a single-channel grayscale one). Your output has 24 channels, so matplotlib rejects it with `Invalid shape (56, 56, 24)`.
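To make the failure concrete: matplotlib's `imshow` accepts a 2-D array (drawn with a colormap) or a 3-D array whose last axis has 3 or 4 entries (RGB or RGBA); the `self._A.shape[-1] in [3, 4]` check is visible in the last frame of the traceback. The sketch below mirrors that rule in plain Python. `is_plottable_shape` is a hypothetical helper for illustration, not a matplotlib function.

```python
def is_plottable_shape(shape: tuple) -> bool:
    """Mirror of the shape rule imshow enforces (hypothetical helper).

    Accepts (H, W) for single-channel data, or (H, W, 3) / (H, W, 4)
    for RGB / RGBA images.
    """
    return len(shape) == 2 or (len(shape) == 3 and shape[-1] in (3, 4))


# The input in the issue is (1, C, H, W); dropping the batch axis and
# moving channels last gives the shape reported in the TypeError.
c, h, w = 24, 56, 56
channels_last = (h, w, c)

print(is_plottable_shape(channels_last))  # False: "Invalid shape (56, 56, 24)"
print(is_plottable_shape((h, w, 3)))      # True: an RGB image
print(is_plottable_shape((h, w)))         # True: a single channel
```

The last case is why plotting the channels one at a time, each as its own 2-D image, sidesteps the error.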

Making the API generalise to arbitrary custom inputs, not just RGB or grayscale images, is on my to-do list.

In the meantime, as a quick workaround, I suggest the following:

  1. Delete this line so that the conditional becomes just `if isinstance(module, nn.modules.conv.Conv2d):`
  2. Use the `optimize` API rather than `visualize`; it returns the raw gradients accumulated over the iterations as a list.
  3. Plot the gradients/channels you are interested in.
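A minimal sketch of steps 2 and 3, assuming `output` is the list returned by `optimize` and that its last element has the same shape as your input, (1, 24, 56, 56). The helpers `channels_to_images` and `plot_channels` are hypothetical names, and the `optimize` call in the usage comment is an assumption to check against flashtorch's docs. Each channel is min-max scaled to [0, 1] on its own (a display choice made here, not something flashtorch does) and drawn as a grayscale image, so matplotlib never receives a 24-channel array.

```python
import numpy as np


def channels_to_images(x: np.ndarray) -> list:
    """Split a (1, C, H, W) or (C, H, W) array into C images of shape (H, W).

    Each channel is min-max scaled to [0, 1] independently so that faint
    channels stay visible. A constant channel maps to all zeros.
    """
    if x.ndim == 4:
        if x.shape[0] != 1:
            raise ValueError("expected a batch of size 1")
        x = x[0]  # drop batch axis -> (C, H, W)
    if x.ndim != 3:
        raise ValueError(f"expected (C, H, W), got shape {x.shape}")

    images = []
    for channel in x:
        channel = channel.astype(np.float64)
        lo, hi = channel.min(), channel.max()
        span = hi - lo
        images.append((channel - lo) / span if span > 0 else np.zeros_like(channel))
    return images


def plot_channels(images: list, ncols: int = 8, title: str = ""):
    """Draw each 2-D image in its own grayscale subplot."""
    import matplotlib.pyplot as plt  # imported lazily; only needed for plotting

    nrows = max(1, -(-len(images) // ncols))  # ceiling division
    fig, axes = plt.subplots(nrows, ncols, figsize=(ncols * 1.5, nrows * 1.5),
                             squeeze=False)
    for ax in axes.flat:
        ax.axis("off")  # hide axes, including unused trailing subplots
    for ax, img in zip(axes.flat, images):
        ax.imshow(img, cmap="gray", vmin=0.0, vmax=1.0)
    fig.suptitle(title)
    plt.subplots_adjust(wspace=0, hspace=0)
    return fig


# Usage with flashtorch (assumed call signature; not run here):
# output = g_ascent.optimize(layer2_0_conv2, filter_idx, input_=your_input)
# images = channels_to_images(output[-1].detach().cpu().numpy())
# plot_channels(images, title="layer2_0_conv2")
```

Scaling per channel makes every channel use the full gray range, at the cost of hiding differences in magnitude between channels; a single global min/max would preserve those instead.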

This is untested, but I hope it helps. Let me know how it goes.

MisaOgura commented 4 years ago

@SarfarazHabib, please feel free to reopen this issue, or open a new one if needed.