Thank you @davidefiocco! Glad that you find it useful. From the error it looks like the inputs that you are trying to attribute to aren't used in the forward pass.
Actually, in `custom_forward`: why are you creating a new tensor with `torch.tensor([torch.softmax(preds, dim = 1)[0][1]], requires_grad = True)`? I don't think that is necessary:

```python
torch.softmax(preds, dim = 1)[0][1]].unsqueeze(0)
```

should do it too.
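(For concreteness, a minimal sketch of the full `custom_forward` under this suggestion, assuming the model returns a tuple whose first element holds the logits:)

```python
def custom_forward(inputs):
    # first element of the output tuple is assumed to be the logits
    preds = model(inputs)[0]
    # probability of class 1 for the first example, kept as a 1-D tensor
    return torch.softmax(preds, dim=1)[0][1].unsqueeze(0)
```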
Hi @NarineK, thanks for the helpful reply!
Indeed, `return torch.softmax(preds, dim = 1)[0][1].unsqueeze(0)` solved it! (You have an extra `]` in your snippet, though!)
Here are a few more changes that I tried, starting from the SQuAD tutorial and adapting it to the binary task.
I kept the `additional_forward_args` (I am not sure about the last element of the tuple though...):

```python
attributions, delta = lig.attribute(inputs=input_ids,
                                    baselines=ref_input_ids,
                                    additional_forward_args=(token_type_ids, position_ids, attention_mask, 0),  # revise this
                                    return_convergence_delta=True)
```
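(For context, `predict` here is assumed to be a thin wrapper that forwards the extra encoder inputs to the model, mirroring the call further below; a sketch, not necessarily the exact notebook code:)

```python
def predict(input_ids, token_type_ids=None, position_ids=None, attention_mask=None):
    # pass the extra encoder inputs straight through to the model
    return model(input_ids,
                 token_type_ids=token_type_ids,
                 position_ids=position_ids,
                 attention_mask=attention_mask)
```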
I then just have one attribution sum:

```python
attributions_sum = summarize_attributions(attributions)
```
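(`summarize_attributions` is the helper from the SQuAD tutorial; if I recall correctly, it sums attributions over the embedding dimension and normalizes:)

```python
def summarize_attributions(attributions):
    # collapse the embedding dimension and normalize to unit norm
    attributions = attributions.sum(dim=-1).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    return attributions
```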
```python
score = predict(input_ids,
                token_type_ids=token_type_ids,
                position_ids=position_ids,
                attention_mask=attention_mask)

score_vis = viz.VisualizationDataRecord(attributions_sum,
                                        torch.max(torch.softmax(score[0], dim=0)),
                                        torch.argmax(score[0]),  # revise this, not sure about it
                                        torch.argmax(score[0]),  # revise this, not sure about it
                                        text,
                                        attributions_sum.sum(),
                                        all_tokens,
                                        delta)

print('\033[1m', 'Visualization For Score', '\033[0m')
viz.visualize_text([score_vis])
```
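(On the two `# revise this` lines: for a binary classifier where `score[0]` holds logits of shape `(1, 2)`, one option is to take the softmax along the class dimension instead; a sketch, assuming that output shape:)

```python
probs = torch.softmax(score[0], dim=1)        # shape (1, 2) for a batch of one
pred_prob, pred_class = torch.max(probs, dim=1)
# pred_class would fill the two predicted-class slots above,
# pred_prob the predicted-probability slot
```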
This allowed me to display some output, but I am not fully convinced that all of the above is OK (the interpretation is a bit tricky to digest, as I have finetuned BERT on the GLUE CoLA task), so if anybody has feedback it would be much appreciated!
PS: You can find my current notebook at this gist: https://gist.github.com/davidefiocco/3e1a0ed030792230a33c726c61f6b3a5
@davidefiocco, the `0` or `1` additional forward arg indices were specific to the SQuAD model, because that model returns a tuple of two tensors: one with the prediction probabilities for the start index, the other with the prediction probabilities for the end index.
In the binary classification case, my forward function looks something like this:

```python
def custom_forward(*inputs):
    out = model(*inputs)[0]
    return out
```
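(With a forward like this, `out` holds the logits for both classes, so the class to attribute to can be selected with `attribute`'s `target` argument; a sketch, assuming logits of shape `(batch, 2)`:)

```python
lig = LayerIntegratedGradients(custom_forward, model.bert.embeddings)
attributions, delta = lig.attribute(inputs=input_ids,
                                    baselines=ref_input_ids,
                                    target=1,  # attribute to the positive class
                                    return_convergence_delta=True)
```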
Thanks again @NarineK for your kind replies :)
I edited the notebook at https://gist.github.com/davidefiocco/3e1a0ed030792230a33c726c61f6b3a5 to use a `custom_forward` more similar to yours:

```python
def custom_forward(inputs):
    out = model(inputs)[0][0]
    return out
```
so my call to `lig.attribute` also gets simplified:

```python
attributions, delta = lig.attribute(inputs=input_ids,
                                    baselines=ref_input_ids,
                                    return_convergence_delta=True)
```
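(One quick way to check that the simplified forward behaves as expected is to call it directly and inspect the output:)

```python
# see what custom_forward actually returns for one example
with torch.no_grad():
    print(custom_forward(input_ids))
```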
Not sure this is ready for PR(ime) time or that it works correctly, but I am glad to share it here in case it's helpful for somebody else. Let me know if you have additional feedback!
@davidefiocco, looks good to me! In terms of `custom_forward`, you can play with it and see what it is actually returning. I used it for binary classification, but it can differ from model to model.
Thank you, I will close this issue then. Should any follow-up arise, I will open a new issue (I am not fully convinced my solution works well, and I may want to explore other interpretation schemes beyond `LayerIntegratedGradients`). Thanks!
Hi @NarineK and captum team, thanks for all the great work on interpretability with PyTorch.
Like others here (see https://github.com/pytorch/captum/issues/150 and https://github.com/pytorch/captum/issues/249), I am trying to interpret a BERT classifier finetuned on a binary classification task, using the `transformers` library from HuggingFace. So far I have not had great success at doing this, starting from the SQuAD example https://github.com/pytorch/captum/blob/master/tutorials/Bert_SQUAD_Interpret.ipynb
So far, I left almost everything else untouched and redefined `construct_input_ref_pair`, which I call with

```python
input_ids, ref_input_ids, sep_id = construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id)
```
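(A sketch of what such a redefinition could look like for a single-sentence task, assuming `tokenizer` and `device` are defined as in the tutorial; not necessarily the exact code in my notebook:)

```python
def construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id):
    text_ids = tokenizer.encode(text, add_special_tokens=False)
    # input: [CLS] text [SEP]; reference: [CLS] [ref]*len(text) [SEP]
    input_ids = [cls_token_id] + text_ids + [sep_token_id]
    ref_input_ids = [cls_token_id] + [ref_token_id] * len(text_ids) + [sep_token_id]
    return (torch.tensor([input_ids], device=device),
            torch.tensor([ref_input_ids], device=device),
            len(text_ids))
```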
I also redefined a custom forward method (returning `torch.tensor([torch.softmax(preds, dim = 1)[0][1]], requires_grad = True)`), which I use in

```python
lig = LayerIntegratedGradients(custom_forward, model.bert.embeddings)
```
When calling `lig.attribute` (as in the tutorial), I get an error. Can you help me debug the above? I guess I am messing something up with the `custom_forward` method, and maybe also with `construct_input_ref_pair`... or more. I am happy to post a working solution once done with this!