raghakot / keras-vis

Neural network visualization toolkit for keras
MIT License

support for conv3d #52

Closed dearkafka closed 7 years ago

dearkafka commented 7 years ago

Hello, thanks for a great package. I found that the current version does not support 3D images (conv3d). That is to be expected, but it would be great if you could add this feature.

raghakot commented 7 years ago

It does. N-dimensional inputs should in fact work. Do you have a gist that shows otherwise?

dearkafka commented 7 years ago

Yes. I'm trying the basic example:

# The name of the layer we want to visualize
# (see model definition in vggnet.py)
layer_name = 'conv1'
layer_idx = [idx for idx, layer in enumerate(model.layers) if layer.name == layer_name][0]

# Visualize all filters in this layer.
filters = np.arange(get_num_filters(model.layers[layer_idx]))

# Generate input image for each filter. Here `text` field is used to overlay `filter_value` on top of the image.
vis_images = []
for idx in filters:
    img = visualize_activation(model, layer_idx, filter_indices=idx)
    img = utils.draw_text(img, str(idx))
    vis_images.append(img)

# Generate stitched image palette with 8 cols.
stitched = utils.stitch_images(vis_images, cols=8)    

and it throws the following error:

AssertionError                            Traceback (most recent call last)
<ipython-input-19-20fc5da0da03> in <module>()
     10 vis_images = []
     11 for idx in filters:
---> 12     img = visualize_activation(model, layer_idx, filter_indices=idx)
     13     img = utils.draw_text(img, str(idx))
     14     vis_images.append(img)

/usr/local/lib/python3.5/dist-packages/vis/visualization.py in visualize_activation(model, layer_idx, filter_indices, seed_img, text, act_max_weight, lp_norm_weight, tv_weight, **optimizer_params)
    108     ]
--> 110     opt = Optimizer(model.input, losses, norm_grads=False)
    111     img = opt.minimize(**optimizer_params)[0]
    112     if text:

/usr/local/lib/python3.5/dist-packages/vis/optimizer.py in __init__(self, img_input, losses, wrt, norm_grads)
     33             # Perf optimization. Don't build loss function with 0 weight.
     34             if weight != 0:
---> 35                 loss_fn = weight * loss.build_loss()
     36                 overall_loss = loss_fn if overall_loss is None else overall_loss + loss_fn
     37                 self.loss_names.append(loss.name)

/usr/local/lib/python3.5/dist-packages/vis/regularizers.py in build_loss(self)
     49         \left ( x(h+1, w, c) - x(h, w, c) \right )^{2} \right )^{\frac{\beta}{2}}$$
     50         """
---> 51         assert 4 == K.ndim(self.img)
     52         a = K.square(self.img[utils.slicer[:, :, 1:, :-1]] - self.img[utils.slicer[:, :, :-1, :-1]])
     53         b = K.square(self.img[utils.slicer[:, :, :-1, 1:]] - self.img[utils.slicer[:, :, :-1, :-1]])


Correct me if I'm wrong, but the problem seems to be with the dimensions of the tensors.
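For readers wondering why a 4D-only assertion blocks conv3d inputs, here is a minimal NumPy sketch of a total-variation penalty that works for any number of spatial axes. It is not the keras-vis implementation: the function name `total_variation` is made up for this illustration, and it assumes a channels-last layout of `(samples, *spatial, channels)`.

```python
import numpy as np


def total_variation(img, beta=2.0):
    """Sum of (squared neighbour differences) ** (beta / 2) over all spatial axes.

    Assumes `img` has shape (samples, *spatial, channels); this layout is an
    assumption of the sketch, not something taken from keras-vis.
    """
    img = np.asarray(img, dtype=np.float64)
    if img.ndim < 3:
        raise ValueError("expected at least one spatial axis")
    n_spatial = img.ndim - 2

    # Drop the last element along every spatial axis so shifted views line up.
    base_slice = [slice(None)] + [slice(None, -1)] * n_spatial + [slice(None)]
    base = img[tuple(base_slice)]

    squared = np.zeros_like(base)
    for axis in range(1, img.ndim - 1):
        shifted_slice = list(base_slice)
        shifted_slice[axis] = slice(1, None)  # neighbour one step along this axis
        squared += np.square(img[tuple(shifted_slice)] - base)

    return float(np.sum(np.power(squared, beta / 2.0)))
```

The same function accepts a 4D batch of 2D images or a 5D batch of 3D volumes, because the loop runs over however many spatial axes are present instead of hard-coding two.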

raghakot commented 7 years ago

Ah, I see what's going on. The Python 3 pip package is not up to date; that assert statement is from very old code. Try installing from source instead.
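When a stale pip release is suspected, it can help to check which copy of the package Python actually imports and what version pip has recorded for it. The sketch below uses only the standard library (`importlib.metadata` needs Python 3.8 or newer). The helper name `describe_install` is made up for this illustration, and the distribution name `keras-vis` and module name `vis` are assumptions based on the paths in the traceback above.

```python
import importlib.util
from importlib import metadata


def describe_install(dist_name, module_name):
    """Return the recorded version of a distribution and the file its module loads from.

    Either value is None when it cannot be determined.
    """
    try:
        version = metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        version = None

    # find_spec returns None when a top-level module cannot be found.
    spec = importlib.util.find_spec(module_name)
    location = spec.origin if spec is not None else None
    return {"version": version, "location": location}


if __name__ == "__main__":
    # Assumed names: distribution "keras-vis", top-level module "vis".
    print(describe_install("keras-vis", "vis"))
```

If the reported location still points into `dist-packages` after installing from source, the old copy is probably shadowing the new one.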

dearkafka commented 7 years ago

Thank you. However, when running the same example (with visualize_activation changed to visualize_class_activation), I got:

TypeError                                 Traceback (most recent call last)
<ipython-input-9-39a9eb8ab422> in <module>()
     11 for idx in filters:
     12     img = visualize_class_activation(model, layer_idx, filter_indices=idx)
---> 13     img = utils.draw_text(img, str(idx))
     14     vis_images.append(img)

/usr/local/lib/python3.5/dist-packages/vis/utils/utils.py in draw_text(img, text, position, font, font_size, color)
    229     # Don't mutate original image
--> 230     img = Image.fromarray(img)
    231     draw = ImageDraw.Draw(img)
    232     draw.text(position, text, fill=color, font=font)

/usr/local/lib/python3.5/dist-packages/PIL/Image.py in fromarray(obj, mode)
   2292         except KeyError:
   2293             # print(typekey)
-> 2294             raise TypeError("Cannot handle this data type")
   2295     else:
   2296         rawmode = mode

TypeError: Cannot handle this data type

raghakot commented 7 years ago

You can probably comment out the utils.draw_text(img, str(idx)) call that overlays text on the image. Is the input a 2D image?
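Rather than commenting the call out entirely, one option is to skip the text overlay only when the array is not an ordinary 2D image. The check below is a deliberately conservative sketch: it accepts only `uint8` grayscale, RGB, and RGBA arrays, which is a subset of what PIL can handle. The helper name `looks_like_2d_image` is made up for this illustration.

```python
import numpy as np


def looks_like_2d_image(arr):
    """Conservatively decide whether `arr` is a uint8 grayscale, RGB, or RGBA image."""
    arr = np.asarray(arr)
    if arr.dtype != np.uint8:
        return False
    if arr.ndim == 2:
        return True  # (height, width) grayscale
    # (height, width, 3) RGB or (height, width, 4) RGBA
    return arr.ndim == 3 and arr.shape[-1] in (3, 4)
```

In the loop from the example, this would mean calling `utils.draw_text(img, str(idx))` only when `looks_like_2d_image(img)` is true, and appending the unlabelled array otherwise.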

dearkafka commented 7 years ago

Thank you, you are right. I also fixed stitched images to handle 3D -> 2D, and it works great!
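The thread does not show how the 3D -> 2D stitching was done. As one possible approach, and not necessarily the fix described above, a volume can be flattened by laying its depth slices out on a 2D grid. The function name `tile_volume_slices` is made up for this illustration, and it assumes the first axis is depth.

```python
import numpy as np


def tile_volume_slices(volume, cols=8):
    """Lay out the slices of a (depth, height, width) array on a 2D grid.

    Unused cells in the last row are left as zeros.
    """
    volume = np.asarray(volume)
    if volume.ndim != 3:
        raise ValueError("expected a (depth, height, width) array")
    depth, height, width = volume.shape
    rows = -(-depth // cols)  # ceiling division
    grid = np.zeros((rows * height, cols * width), dtype=volume.dtype)
    for i in range(depth):
        r, c = divmod(i, cols)
        grid[r * height:(r + 1) * height, c * width:(c + 1) * width] = volume[i]
    return grid
```

The result is an ordinary 2D array, so it can be passed on to 2D-only tools such as an image viewer or a text overlay.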

raghakot commented 7 years ago

Can you PR the new and improved stitched images :D?