pexmar opened this issue 7 years ago
So the issue is that gradients are not propagated through the `Embedding` layer, which uses the `tf.gather` op: its inputs are integer word indices, so there is no gradient with respect to them.
This means we need to compute gradients with respect to the embedding layer's output instead of the model input. I have updated the API to accept an optional `wrt_tensor` argument (https://github.com/raghakot/keras-vis/commit/0c57850db2ba95905f4ff7b38a685bc6d0d38087). If you pass `wrt_tensor=model.layers[1].output` to `visualize_saliency`, you should get a heatmap.
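To see why moving the gradient target helps, here is a minimal NumPy sketch. It is not keras-vis code: `embed` only mimics the lookup that `tf.gather` performs, and `score` is a hypothetical linear head standing in for everything after the embedding layer. The token ids are integers, so nothing can be differentiated with respect to them, but the looked-up vectors are ordinary floats, and finite differences give a gradient of shape (sequence length, embedding size).

```python
import numpy as np


def embed(indices, table):
    # Embedding lookup, similar in spirit to tf.gather(table, indices):
    # integer token ids in, float vectors out. No gradient exists w.r.t. the ids.
    return table[indices]


def score(emb, weights):
    # Hypothetical stand-in for the rest of the model: a linear score.
    return float(np.sum(emb * weights))


def grad_wrt_embedding(emb, weights, eps=1e-6):
    # Central finite differences with respect to every entry of the embedding output.
    grads = np.zeros_like(emb)
    for pos in np.ndindex(emb.shape):
        bump = np.zeros_like(emb)
        bump[pos] = eps
        grads[pos] = (score(emb + bump, weights) - score(emb - bump, weights)) / (2 * eps)
    return grads


rng = np.random.default_rng(0)
vocab_size, seq_len, emb_size = 50, 8, 6  # toy sizes; the issue uses 400 tokens and emb_size=60
table = rng.normal(size=(vocab_size, emb_size))
weights = rng.normal(size=(seq_len, emb_size))
token_ids = rng.integers(0, vocab_size, size=seq_len)

emb = embed(token_ids, table)
grads = grad_wrt_embedding(emb, weights)
print(grads.shape)  # (8, 6)
```

For this linear head the gradient equals `weights`, which makes the result easy to check; in the real model the same kind of per-token, per-dimension gradient is what you get once `wrt_tensor` points at the embedding output.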
Note that I recently changed the various saliency visualizations to return raw gradients instead of a jet-colormapped heatmap. So this "heatmap" should have shape (400,). You can convert it into a proper heatmap using:
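```python
import numpy as np
import matplotlib.cm as cm
# grads: raw gradients of shape (400,), expected in [0, 1]; drop alpha, scale RGB to 0-255
hmap = np.uint8(cm.jet(grads)[..., :3] * 255)
```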
`hmap` will have shape (400, 3), with the 3 channels holding the RGB values of the heatmap.
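`cm.jet` maps floats in [0, 1] to colors, so raw gradients may need rescaling before the conversion above. A minimal min-max normalization sketch, assuming the gradients arrive as a NumPy array; `normalize` is a hypothetical helper here, not necessarily what keras-vis uses internally:

```python
import numpy as np


def normalize(grads, eps=1e-12):
    # Min-max scale to [0, 1]; eps avoids division by zero for a constant array.
    grads = np.asarray(grads, dtype=np.float64)
    lo, hi = grads.min(), grads.max()
    return (grads - lo) / (hi - lo + eps)


raw = np.array([-2.0, 0.0, 2.0, 6.0])
print(normalize(raw))  # approximately [0. 0.25 0.5 1.]
```

The normalized array can then be passed to `cm.jet` as in the snippet above.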
Traditionally, saliency takes the max value across all channels (3 in the case of images). Here, however, it takes the max across all 60 embedding dimensions (`emb_size`), which may not be a good idea; I could imagine `np.mean` working better in this case. Alternatively, you can plot the (400, 60) gradients directly as a 2D heatmap. To do that, you can literally copy and paste this code (https://github.com/raghakot/keras-vis/blob/master/vis/visualization/saliency.py#L79) and comment out the `np.max` part.
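The three options (max, mean, or no reduction) can be compared on a gradient array of shape (seq_len, emb_size). This is a NumPy sketch, not the keras-vis source; it assumes absolute values are taken before reducing, as saliency maps commonly do, and `token_saliency` is a hypothetical name.

```python
import numpy as np


def token_saliency(grads, reduce="max"):
    # grads: gradients w.r.t. the embedding output, shape (seq_len, emb_size).
    magnitudes = np.abs(grads)
    if reduce == "max":
        return magnitudes.max(axis=-1)   # one value per token, like the image case
    if reduce == "mean":
        return magnitudes.mean(axis=-1)  # averages over all embedding dimensions
    if reduce == "none":
        return magnitudes                # full (seq_len, emb_size) map to plot as a 2D image
    raise ValueError(f"unknown reduce mode: {reduce!r}")


grads = np.array([[0.1, -0.9, 0.0],
                  [0.3, 0.3, -0.3]])
print(token_saliency(grads, "max"))   # [0.9 0.3]
print(token_saliency(grads, "mean"))  # approximately [0.3333 0.3]
```

The max highlights tokens where any single dimension reacts strongly, while the mean favors tokens whose whole embedding matters; with 60 dimensions instead of 3 color channels, the two can rank tokens quite differently.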
Let me know how it goes, and if it seems to make sense at all.
Hi,
Up front: I am using the most recent versions of keras-vis, keras, theano, and tensorflow.
I really like the idea of keras-vis, but while testing it on an NLP task I ran into an error:
When I try to run

```python
visualize_saliency(model, layer_idx, filter_indices=0, seed_input=txt_instances.x[0])
```

on my trained model, I get a `theano.gradient.DisconnectedInputError` in this line. Do you have an idea why? The exact error message is the following:
I already tried making the embedding layer trainable (as suggested on Slack), but that did not solve the problem. I prepared a Gist; you will need the English word2vec file provided by Google (https://code.google.com/archive/p/word2vec/):
https://gist.github.com/pexmar/8900cb65f4970bd911ebc81206c9b131
Thanks in advance for your help and thanks for this great library ;-)