pytorch / captum

Model interpretability and understanding for PyTorch
https://captum.ai
BSD 3-Clause "New" or "Revised" License

How to interpret BERT for SequenceClassification? #303

Closed: davidefiocco closed this issue 4 years ago

davidefiocco commented 4 years ago

Hi @NarineK and captum team, thanks for all the great work on interpretability with PyTorch.

Like others here (see https://github.com/pytorch/captum/issues/150, https://github.com/pytorch/captum/issues/249), I am trying to interpret a BERT classifier fine-tuned on a binary classification task, using the transformers library from HuggingFace. Specifically, I have

model = BertForSequenceClassification.from_pretrained('finetuned-bert-base-cased')

I am struggling to do this, starting from the SQuAD example https://github.com/pytorch/captum/blob/master/tutorials/Bert_SQUAD_Interpret.ipynb

So far, I have left almost everything else untouched and redefined

def construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id):

    text_ids = tokenizer.encode(text, add_special_tokens=False)
    # construct input token ids
    input_ids = [cls_token_id] + text_ids + [sep_token_id]
    # construct reference token ids 
    ref_input_ids = [cls_token_id] + [ref_token_id] * len(text_ids) + [sep_token_id]

    return torch.tensor([input_ids], device=device), torch.tensor([ref_input_ids], device=device), len(text_ids)

which I call with input_ids, ref_input_ids, sep_id = construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id) and a custom forward method that reads

def custom_forward(inputs, token_type_ids=None, position_ids=None, attention_mask=None, position=0):
    outputs = predict(inputs, token_type_ids=token_type_ids, position_ids=position_ids, attention_mask=attention_mask)
    preds = outputs[0]
    # preds looks like
    # tensor([[-1.9723,  2.2183]], grad_fn=<AddmmBackward>)
    return torch.tensor([torch.softmax(preds, dim = 1)[0][1]], requires_grad = True)

which I use in lig = LayerIntegratedGradients(custom_forward, model.bert.embeddings).
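To see what construct_input_ref_pair produces, here is a framework-free sketch that builds the same two id lists from already-tokenized ids. The numbers are hypothetical: 101/102 stand for [CLS]/[SEP] and 0 for [PAD] as the reference token (check your own tokenizer's ids), and the text ids are made up. build_input_ref_ids and differing_positions are illustrative names, not Captum or transformers functions. Because the attributions are taken at model.bert.embeddings, whose output at each position depends only on that position's token, position and segment, positions where input and reference are identical (here [CLS] and [SEP], with the model in eval mode) receive zero attribution.

```python
def build_input_ref_ids(text_ids, ref_token_id, sep_token_id, cls_token_id):
    """Plain-list version of construct_input_ref_pair (no tokenizer, no torch)."""
    input_ids = [cls_token_id] + list(text_ids) + [sep_token_id]
    # Same length as input_ids; only the text positions use the reference token.
    ref_input_ids = [cls_token_id] + [ref_token_id] * len(text_ids) + [sep_token_id]
    return input_ids, ref_input_ids, len(text_ids)


def differing_positions(input_ids, ref_input_ids):
    """Positions where input and reference differ: the only ones that can get
    non-zero attribution at a per-token layer such as the embeddings."""
    return [i for i, (a, b) in enumerate(zip(input_ids, ref_input_ids)) if a != b]


# Hypothetical ids: 101/102 = [CLS]/[SEP], 0 = [PAD]; the text ids are made up.
inp, ref, n = build_input_ref_ids([1188, 1110, 170], ref_token_id=0,
                                  sep_token_id=102, cls_token_id=101)
print(inp)                            # [101, 1188, 1110, 170, 102]
print(ref)                            # [101, 0, 0, 0, 102]
print(differing_positions(inp, ref))  # [1, 2, 3]
```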

When calling lig.attribute (as in the tutorial), I get

RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.

Can you help me debug the above? I guess I am messing something up with the custom_forward method, and maybe also construct_input_ref_pair... or more.

I am happy to post a working solution once done with this!

NarineK commented 4 years ago

Thank you @davidefiocco! Glad that you find it useful. From the error it looks like the inputs that you are trying to attribute to aren't used in the forward pass.

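Actually, in custom_forward, why are you creating a new tensor with torch.tensor([torch.softmax(preds, dim = 1)[0][1]], requires_grad = True)? I don't think it is necessary, and it is likely what causes the error: torch.tensor copies the value into a new leaf tensor with no autograd history, so no gradient can flow back to the embeddings. Something like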

torch.softmax(preds, dim = 1)[0][1]].unsqueeze(0)

should do it too.

davidefiocco commented 4 years ago

Hi @NarineK, thanks for the helpful reply!

Indeed, return torch.softmax(preds, dim = 1)[0][1].unsqueeze(0) solved it! (You have an extra ] in your snippet, though!)

Here are a few more changes I made to adapt the SQuAD tutorial to the binary task:

attributions, delta = lig.attribute(inputs=input_ids,
                                  baselines=ref_input_ids,
                                  additional_forward_args=(token_type_ids, position_ids, attention_mask, 0), # revise this
                                  return_convergence_delta=True)
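With return_convergence_delta=True, Captum also returns how far the summed attributions are from F(input) - F(baseline). Integrated gradients should match that difference (the completeness property), up to the error of approximating the path integral with a finite number of steps. The pure-Python sketch below shows this on a toy function with a hand-written gradient and a midpoint Riemann sum; it illustrates the technique, not Captum's implementation, which uses its own quadrature rule (Gauss-Legendre by default). f, grad_f and the input values are illustrative assumptions.

```python
def f(x1, x2):
    # Toy scalar "model" standing in for the network output.
    return x1 * x2 ** 2


def grad_f(x1, x2):
    # Hand-written partial derivatives (df/dx1, df/dx2).
    return (x2 ** 2, 2 * x1 * x2)


def integrated_gradients(x, baseline, n_steps=50):
    """Midpoint Riemann-sum approximation of integrated gradients for f."""
    avg_grad = [0.0] * len(x)
    for k in range(n_steps):
        alpha = (k + 0.5) / n_steps  # midpoint of the k-th sub-interval of [0, 1]
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        for i, g in enumerate(grad_f(*point)):
            avg_grad[i] += g / n_steps
    # Attribution for input i: (x_i - baseline_i) times the average gradient on the path.
    return [(xi - b) * g for xi, b, g in zip(x, baseline, avg_grad)]


def convergence_delta(x, baseline, n_steps=50):
    # Sum of attributions minus the change in output (ideally close to 0).
    attributions = integrated_gradients(x, baseline, n_steps)
    return sum(attributions) - (f(*x) - f(*baseline))


x, baseline = (3.0, 2.0), (0.0, 0.0)
print(integrated_gradients(x, baseline))    # close to [4.0, 8.0]
print(convergence_delta(x, baseline, 50))   # about -0.0012
print(convergence_delta(x, baseline, 500))  # about -1.2e-05
```

For this toy function the error shrinks with the square of the step count. On a real model the numbers are not predictable like this, but a delta that is large relative to the output difference is a hint to raise n_steps or to check that the forward function returns one score per example.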

I then compute a single attribution summary:

attributions_sum = summarize_attributions(attributions)
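summarize_attributions is the helper from the SQuAD tutorial: it sums the layer attributions over the hidden (embedding) dimension and divides by the L2 norm of the result, giving one signed score per token. A framework-free sketch of the same arithmetic on nested lists; summarize_token_attributions and the numbers are illustrative, not part of Captum.

```python
import math


def summarize_token_attributions(attributions):
    """Collapse per-token attribution vectors (seq_len x hidden_size) into one score per token."""
    per_token = [sum(vec) for vec in attributions]   # sum over the hidden dimension
    norm = math.sqrt(sum(s * s for s in per_token))  # L2 norm across tokens
    return [s / norm for s in per_token] if norm > 0 else per_token


scores = summarize_token_attributions([[0.0, 0.0], [1.0, 2.0], [-4.0, 0.0], [0.0, 0.0]])
print(scores)  # [0.0, 0.6, -0.8, 0.0]
```

Since the scores are normalized per sentence, their signs and relative sizes are comparable within one input, but their absolute magnitudes are not comparable across inputs.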

score = predict(input_ids, token_type_ids=token_type_ids,
                position_ids=position_ids,
                attention_mask=attention_mask)
score_vis = viz.VisualizationDataRecord(
                        attributions_sum,                           # per-token attributions
                        torch.max(torch.softmax(score[0], dim=1)),  # predicted probability: softmax over the labels
                        torch.argmax(score[0]),                     # predicted class
                        torch.argmax(score[0]),                     # true class: placeholder, revise this
                        text,                                       # attribution class: revise this
                        attributions_sum.sum(),                     # attribution score
                        all_tokens,                                 # tokens to display
                        delta)                                      # convergence delta

print('\033[1m', 'Visualization For Score', '\033[0m')
viz.visualize_text([score_vis])

This allowed me to display some output:

[Screenshot of the viz.visualize_text output]

However, I am not fully convinced that all of the above is correct (the interpretation is a bit tricky to digest, as I fine-tuned BERT on the GLUE CoLA task), so any feedback is much appreciated!

PS: You can find my current notebook in this gist: https://gist.github.com/davidefiocco/3e1a0ed030792230a33c726c61f6b3a5
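On the probability passed to VisualizationDataRecord: for this classifier, score[0] holds the logits with shape (batch_size, num_labels), as the preds printout in the first post shows, so the softmax has to run over the label axis (dim=1). Over the batch axis (dim=0) with a single sentence, every entry becomes 1.0 and the displayed probability is always 1. (In the SQuAD tutorial the indexed scores are one-dimensional, which is why dim=0 works there.) A pure-Python check on the logits quoted in the issue, with a hand-written softmax standing in for torch.softmax:

```python
import math


def softmax(values):
    # Numerically stable softmax over a flat list.
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


logits = [[-1.9723, 2.2183]]  # shape (batch=1, num_labels=2), values from the issue

# dim=1: normalize each row across the labels.
over_labels = [softmax(row) for row in logits]

# dim=0: normalize each column across the batch (a single row here).
columns = list(zip(*logits))
over_batch_cols = [softmax(list(col)) for col in columns]
over_batch = [list(row) for row in zip(*over_batch_cols)]

print(over_labels)  # approx [[0.0149, 0.9851]]
print(over_batch)   # [[1.0, 1.0]]
```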

NarineK commented 4 years ago

@davidefiocco, the 0 or 1 additional forward arg indices were specific to the SQuAD model, because that model returns a tuple of 2 tensors: one with the prediction scores for the start index and the other with the prediction scores for the end index. In the binary classification case, my forward function looks something like this:

def custom_forward(*inputs):
    out = model(*inputs)[0]
    return out
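A forward function like this returns the logits for the whole batch, shape (batch_size, num_labels), and the class to explain is then chosen with the target argument of lig.attribute (for example target=1). Captum selects that column for every example in the batch, including the interpolated copies that integrated gradients runs through the model. A pure-Python sketch of that selection; select_target is a hypothetical helper, not a Captum function, and the logits are made up.

```python
def select_target(batch_logits, target):
    """Pick column `target` from every row of a (batch_size x num_labels) output."""
    return [row[target] for row in batch_logits]


# Hypothetical logits for three interpolation steps of one sentence.
batch_logits = [[0.1, -0.2], [-0.5, 0.9], [-1.9, 2.2]]
print(select_target(batch_logits, 1))  # [-0.2, 0.9, 2.2]
# Indexing the first row instead keeps only the first step's two logits.
print(batch_logits[0])                  # [0.1, -0.2]
```

Keeping the batch dimension in the forward output matters because integrated gradients evaluates many interpolation steps in one batch; a forward that returns a single row discards all but one of them.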
davidefiocco commented 4 years ago

Thanks again @NarineK for your kind replies :)

I edited the notebook in https://gist.github.com/davidefiocco/3e1a0ed030792230a33c726c61f6b3a5 so that it uses a custom_forward closer to yours:

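def custom_forward(inputs):
    # Return the logits for the whole batch, shape (batch_size, num_labels).
    # LayerIntegratedGradients runs all interpolation steps as one batch, so
    # indexing a single row (e.g. [0][0]) would keep only the first step.
    out = model(inputs)[0]
    return out

so my call to lig.attribute also gets simplified (target picks the logit to explain; 1 is the "acceptable" label in CoLA):

attributions, delta = lig.attribute(inputs=input_ids,
                                  baselines=ref_input_ids,
                                  target=1,
                                  return_convergence_delta=True)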

Not sure this is ready for PR(ime) time and is working correctly, but I am glad to share it here in case that's helpful for somebody else. Let me know if you have additional feedback!

NarineK commented 4 years ago

@davidefiocco, looks good to me! In terms of custom_forward, you can play with it and see what it actually returns. I used it for binary classification, but it can differ from model to model.

davidefiocco commented 4 years ago

Thank you, I will close this issue then. Should any follow-up arise, I will open a new issue (I am not fully convinced my solution works well, and I may want to explore other interpretation schemes beyond LayerIntegratedGradients). Thanks!