ankurtaly / Integrated-Gradients

Attributing predictions made by the Inception network using the Integrated Gradients method

Can you provide a text example, like sentiment classification? #15

Open lemon234071 opened 5 years ago

lemon234071 commented 5 years ago

Specifically, how should one deal with word embeddings?

ankurtaly commented 5 years ago

Sure, we will soon add a text example.

The key idea is to attribute from the output to the embedding tensor, since the network is only differentiable down to the embedding tensor (the token-to-embedding lookup itself is not differentiable).

This yields an attribution for each dimension of each token embedding. For each token, we then sum the attributions across all embedding dimensions to obtain a token-level attribution.

Here is some example code:

def ig_text(inp, label, t_embedding, t_label, t_grad, baseline=None, steps=20):
  # Args:
  # - inp: Input tokens (or token ids) whose prediction (for the provided label) must be explained
  # - label: Prediction label
  # - t_embedding: Embedding tensor
  # - t_label: Placeholder tensor specifying the prediction label for which gradients must be computed
  # - t_grad: Tensor computing gradients of the prediction w.r.t. the embedding tensor
  # Assumes `sess` (the TF session), `t_inp` (the input placeholder), and
  # `y_probs` (the prediction tensor) are available in the enclosing scope.
  if baseline is None:
    baseline = 0*inp  # all-zero token ids serve as the default baseline
  embs = sess.run(t_embedding, {t_inp: [inp, baseline]})  # <batch, num_tokens, emb_dims>
  inp_emb = embs[0, :, :]
  baseline_emb = embs[1, :, :]
  # Interpolate embeddings along the straight line from the baseline to the input.
  scaled_embs = [baseline_emb + (float(i)/steps)*(inp_emb-baseline_emb) for i in range(0, steps+1)]
  feed = {t_embedding: scaled_embs, t_label: label}
  grads, scores = sess.run([t_grad, y_probs], feed_dict=feed)  # shapes: <steps+1, *inp_emb.shape>, <steps+1, num_labels>
  # Riemann approximation of the path integral: average the gradients and
  # scale elementwise by the input-baseline difference.
  ig = (inp_emb-baseline_emb)*np.average(grads[1:,:,:], axis=0)  # shape: <inp_emb.shape>
  token_ig = np.sum(ig, axis=-1)  # shape: <num_tokens>
  return token_ig
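To make the recipe concrete without a TF session, here is a minimal self-contained NumPy sketch of the same computation on a toy, hand-written model (the model, its weights, and all names here are illustrative, not part of the repo): the "network" scores a sentence as the dot product of its mean token embedding with a fixed weight vector, and we check the completeness property (token attributions sum to the score difference between input and baseline).

```python
import numpy as np

rng = np.random.default_rng(0)
num_tokens, emb_dim, steps = 4, 8, 50

# Toy linear model: score(emb) = mean-pooled embedding dotted with w.
w = rng.normal(size=emb_dim)

def score(emb):
    # emb: <num_tokens, emb_dim>
    return emb.mean(axis=0) @ w

def grad(emb):
    # Gradient of score w.r.t. each embedding entry (constant for this model).
    return np.tile(w / num_tokens, (num_tokens, 1))

inp_emb = rng.normal(size=(num_tokens, emb_dim))
baseline_emb = np.zeros_like(inp_emb)  # all-zero embedding baseline

# Gradients at interpolated points between baseline and input.
alphas = np.linspace(0.0, 1.0, steps + 1)
grads = np.stack([grad(baseline_emb + a * (inp_emb - baseline_emb))
                  for a in alphas])

# Integrated gradients per embedding dimension, then summed per token.
ig = (inp_emb - baseline_emb) * grads[1:].mean(axis=0)  # <num_tokens, emb_dim>
token_ig = ig.sum(axis=-1)                              # <num_tokens>

# Completeness: attributions sum to score(input) - score(baseline).
print(np.allclose(token_ig.sum(), score(inp_emb) - score(baseline_emb)))  # True
```

Because this toy model is linear, the Riemann sum is exact and completeness holds to machine precision; for a real network you would increase `steps` until the residual is acceptably small.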
lemon234071 commented 5 years ago

Thanks! Thank you very much for the quick reply! It's very helpful. I just didn't know how to obtain a token-level attribution from the embedding attributions and couldn't figure out the principle. Best wishes.