pytorch / captum

Model interpretability and understanding for PyTorch
https://captum.ai
BSD 3-Clause "New" or "Revised" License

Captum for semantic text similarity using sentence transformers #1102

Open MARUD84 opened 1 year ago

MARUD84 commented 1 year ago

❓ Questions and Help

Thank you for the great library and your efforts towards explainable AI.

I am writing to find out how the library can be used to better understand a cosine similarity score obtained when computing semantic textual similarity with a pre-trained sentence transformer.

aobo-y commented 1 year ago

Hi, could you elaborate more, like any issues you encountered? Any pseudo code to show us what you are trying to achieve?

If you are trying to understand which tokens in your text contribute to the cosine similarity, you can choose an attribution algorithm and do something like the following:

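from captum.attr import FeatureAblation

# Pseudocode: `sentence_transformer` and `cosine_similarity` are placeholders
# for your own encoder and similarity function.
def forward_func(sentence_input_1, sentence_input_2):
    sentence_vector_1 = sentence_transformer(sentence_input_1)
    sentence_vector_2 = sentence_transformer(sentence_input_2)
    return cosine_similarity(sentence_vector_1, sentence_vector_2)

# Ablate each input feature in turn and measure the change in similarity.
fa = FeatureAblation(forward_func)
attributions = fa.attribute((input_1, input_2))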
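
To make the ablation idea concrete without loading a model, here is a minimal, dependency-free sketch. It is not Captum's implementation: the two-dimensional vectors in `TOY_VECTORS`, the mean pooling, and the helper names are all assumptions made up for illustration. Each token of the first sentence is removed in turn, and the resulting drop in cosine similarity is taken as that token's contribution.

```python
import math

# Hypothetical 2-d token vectors, invented for illustration only.
TOY_VECTORS = {
    "fox": (1.0, 0.0),
    "dog": (0.0, 1.0),
}

def mean_pool(tokens):
    """Average the toy vectors of the given tokens (needs at least one token)."""
    vectors = [TOY_VECTORS[t] for t in tokens]
    return tuple(sum(dim) / len(vectors) for dim in zip(*vectors))

def cosine_similarity(a, b):
    """Cosine similarity of two non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def ablate_tokens(sentence_1, sentence_2):
    """Return (token, similarity drop when the token is removed from sentence_1).

    Assumes sentence_1 has at least two tokens so something remains after removal.
    """
    target = mean_pool(sentence_2)
    full_score = cosine_similarity(mean_pool(sentence_1), target)
    contributions = []
    for i, token in enumerate(sentence_1):
        remaining = sentence_1[:i] + sentence_1[i + 1:]
        ablated_score = cosine_similarity(mean_pool(remaining), target)
        contributions.append((token, full_score - ablated_score))
    return contributions

contributions = ablate_tokens(["fox", "dog"], ["fox"])
```

In this toy case, removing `fox` lowers the similarity (a positive contribution), while removing `dog` raises it (a negative contribution).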

MARUD84 commented 1 year ago

@aobo-y Thank you so much for taking the time to get back to me on this. I am trying to find out which tokens contribute to a given similarity score between two sentences. It would be good to know that token x in sentence 1 and token y in sentence 2 are contributing to the high or low similarity score. I hope this clarifies the problem. If you have any additional insights on how I might solve this using Captum, I would be very grateful for your help. I tried the above approach but am encountering the following error: TypeError: len() of a 0-d tensor

aobo-y commented 1 year ago

Yup, that's the use case I imagined. The above pseudocode gives a valid high-level flow for crafting your forward_func/model.

The error you encountered seems to be just an issue with your inputs. If you cannot solve it, you can share your code and error stack trace so we can help debug.

MARUD84 commented 1 year ago

Hello @aobo-y, thank you for taking the time to reply. This is the code I am using:

from captum.attr import visualization as viz
from captum.attr import LayerIntegratedGradients
from transformers import AutoModel, AutoTokenizer
import torch

pretrained_name = "sentence-transformers/all-MiniLM-L6-v2"
tokenizer = AutoTokenizer.from_pretrained(pretrained_name)
model = AutoModel.from_pretrained(pretrained_name)

def max_pool_embeddings_with_attention_mask(embeddings, attention_mask):
    input_mask_expanded = (
        attention_mask
        .unsqueeze(-1)
        .expand(embeddings.size())
        .float()
    )
    # Mask out of place so the model output tensor is not modified in place
    embeddings = embeddings.masked_fill(input_mask_expanded == 0, -1e9)
    return torch.max(embeddings, 1)[0]

def mean_pool_embeddings_with_attention_mask(embeddings, attention_mask):
    input_mask_expanded = (
        attention_mask
        .unsqueeze(-1)
        .expand(embeddings.size())
        .float()
    )
    return (torch.sum(embeddings * input_mask_expanded, dim=1)
            / torch.clamp(input_mask_expanded.sum(dim=1), min=1e-9))

def pool_embeddings_with_attention_mask(embeddings, attention_mask, mode="mean"):
    if mode == "cls":
        CLS_TOKEN_POSITION = 0
        return embeddings[:, CLS_TOKEN_POSITION]
        # ^ NOTE: If the [CLS] token is not the first token of the sequence, 
        #         then this obviously doesn't make any sense
    if mode == "max":
        return max_pool_embeddings_with_attention_mask(
            embeddings, attention_mask
        )
    if mode == "mean":
        return mean_pool_embeddings_with_attention_mask(
            embeddings, attention_mask
        )
    raise ValueError(
        f"Pooling mode '{mode}' is not supported."
    )

def construct_baseline_and_input(response, query):
    response_input_ids, query_input_ids = tokenizer(
        [response, query],
        padding=True,
        add_special_tokens=True
    ).input_ids
    def replace_non_special_token_with_padding(token_id):
        if token_id in tokenizer.all_special_ids:
            return token_id
        return tokenizer.pad_token_id
    baseline_input_ids = list(
        map(replace_non_special_token_with_padding, response_input_ids)
    )
    return (
        torch.tensor([baseline_input_ids]),
        torch.tensor([response_input_ids]),
        torch.tensor([query_input_ids])
    )

def construct_attention_mask(input_ids):
    return (torch.where(input_ids == tokenizer.pad_token_id, 0, 1))

def predict(
    responses,
    queries,
    pooling_mode="mean",
    attention_mask_for_responses=None,
    attention_mask_for_queries=None,
):   
    output_for_responses = model(
        input_ids=responses, attention_mask=attention_mask_for_responses
    )
    output_for_queries = model(
        input_ids=queries, attention_mask=attention_mask_for_queries
    )
    response_embeddings = pool_embeddings_with_attention_mask(
        embeddings=output_for_responses.last_hidden_state,
        attention_mask=attention_mask_for_responses,
        mode=pooling_mode
    )
    query_embeddings = pool_embeddings_with_attention_mask(
        embeddings=output_for_queries.last_hidden_state,
        attention_mask=attention_mask_for_queries,
        mode=pooling_mode
    )
    cosine_similarities = torch.nn.functional.cosine_similarity(
        response_embeddings, query_embeddings, dim=1
    )
    return cosine_similarities

sample_query = "The fox jumped over the sleepy dog"
sample_response = "The dog ran away from the fox"
sample_label = 1

baseline_input_ids, response_input_ids, query_input_ids  = \
    construct_baseline_and_input(sample_response, sample_query)

response_indices = response_input_ids[0].detach().tolist()
response_tokens = tokenizer.convert_ids_to_tokens(response_indices)

attention_mask_for_responses = construct_attention_mask(response_input_ids)
attention_mask_for_queries = construct_attention_mask(query_input_ids)

def measure_cosine_sim(response_input_ids, query_input_ids):
    cosine_similarities = predict(
        response_input_ids,
        query_input_ids,
        pooling_mode="mean",
        attention_mask_for_responses=attention_mask_for_responses,
        attention_mask_for_queries=attention_mask_for_queries
    )
    return cosine_similarities[0].item()

cos_sim_response = measure_cosine_sim(response_input_ids, query_input_ids)
print("Cosine similarity:", cos_sim_response)

lig = LayerIntegratedGradients(predict, model.embeddings, multiply_by_inputs=False)

pooling_mode = "mean"
attributions, delta = lig.attribute(
    inputs=response_input_ids,
    baselines=baseline_input_ids,
    additional_forward_args=(
        query_input_ids,
        pooling_mode,
        attention_mask_for_responses,
        attention_mask_for_queries
    ),
    n_steps=50,
    internal_batch_size=2,
    return_convergence_delta=True
)

delta, attributions.sum(dim=-1).squeeze(0)

def summarize_attributions(attributions):
    attributions = attributions.sum(dim=-1).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    return attributions

summarized_attributions = summarize_attributions(attributions)

predicted_similarity_score = cos_sim_response
predicted_label = int(predicted_similarity_score > 0.5)

vis = viz.VisualizationDataRecord(
    word_attributions=summarized_attributions,
    pred_prob=predicted_similarity_score,
    pred_class=predicted_label,
    true_class=sample_label,
    attr_class=sample_label,
    attr_score=summarized_attributions.sum(),       
    raw_input_ids=response_tokens,
    convergence_score=delta
)

viz.visualize_text([vis])

However, I am not getting the desired output. I would like both sentences listed, with an indication of whether each token contributes positively or negatively to the sentence-level cosine similarity. Do you have any practical suggestions on how I might achieve this? I would be grateful for any insights. Thank you in advance.
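
For reference, the sign of each summarized score is what carries the direction of a token's contribution. The following is a minimal sketch of the `summarize_attributions` step from the code above, rewritten with plain lists instead of tensors. The helper name `summarize` and the numbers in `toy_attributions` are made up for illustration and are not real model output.

```python
import math

def summarize(attributions):
    """Sum each token's attributions over the embedding dimension, then L2-normalize."""
    per_token = [sum(row) for row in attributions]
    norm = math.sqrt(sum(score * score for score in per_token))
    if norm == 0.0:
        return per_token  # all-zero attributions: nothing to scale
    return [score / norm for score in per_token]

# Made-up attributions for 3 tokens with 2 embedding dimensions each.
toy_attributions = [
    [2.0, 1.0],    # sums to  3.0
    [-1.0, -3.0],  # sums to -4.0
    [0.5, -0.5],   # sums to  0.0
]
scores = summarize(toy_attributions)
```

Under this reading, a positive score suggests the token pushes the similarity up relative to the baseline, a negative score suggests it pushes the similarity down, and a score near zero suggests little influence.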

aobo-y commented 1 year ago

@MARUD84 The code looks great to me. What is the exact issue you have? Is it the error you mentioned "TypeError: len() of a 0-d tensor"? If so, you may want to attach the error trace.

to know token x in sentence 1 and token y in sentence 2 are contributing to the high/low similarity score

Also, since you mentioned you need attributions for both sentences, you can pass both sentences as inputs to Captum:

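# NOTE: with a tuple of inputs, baselines must be a tuple of the same length.
# `query_baseline_input_ids` is a hypothetical name for a query baseline
# built the same way as `baseline_input_ids`.
attributions, delta = lig.attribute(
    inputs=(response_input_ids, query_input_ids),
    baselines=(baseline_input_ids, query_baseline_input_ids),
    additional_forward_args=(
        pooling_mode,
        attention_mask_for_responses,
        attention_mask_for_queries
    ),
    n_steps=50,
    internal_batch_size=2,
    return_convergence_delta=True
)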
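
When `inputs` is a tuple, Captum expects `baselines` to be a tuple of the same length, with one baseline per input. A minimal sketch of building one padding baseline per sentence, mirroring `construct_baseline_and_input` from the earlier code, might look like this in plain Python. The token IDs and the `make_padding_baseline` helper are illustrative assumptions, not real tokenizer output.

```python
# Illustrative token IDs only (not real tokenizer output).
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0
SPECIAL_IDS = {CLS_ID, SEP_ID, PAD_ID}

def make_padding_baseline(input_ids):
    """Keep special tokens in place and replace every other token with padding."""
    return [tok if tok in SPECIAL_IDS else PAD_ID for tok in input_ids]

response_ids = [CLS_ID, 1996, 3899, SEP_ID, PAD_ID]
query_ids = [CLS_ID, 1996, 4419, 5598, SEP_ID]

# One baseline per input, in the same order as the inputs tuple.
inputs = (response_ids, query_ids)
baselines = tuple(make_padding_baseline(ids) for ids in inputs)
```

Each baseline keeps the length of its input, so the two line up position by position once converted to tensors.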

Lastly, I think this tutorial may cover everything you need: https://captum.ai/tutorials/Bert_SQUAD_Interpret

MARUD84 commented 1 year ago

Hello @aobo-y Thanks once again for your response.

This is the result I get with the above changes. I would actually like to have both sentences on the same line and see which tokens contribute positively and which negatively to the cosine similarity score.
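
One possible way to get both sentences on a single line is to format the tokens and scores by hand. The sketch below is plain Python, not a Captum API: `mark_tokens`, `format_pair`, the `(+)`/`(-)` markers, the `threshold` parameter, and the scores are all hypothetical and chosen only for illustration.

```python
def mark_tokens(tokens, scores, threshold=0.05):
    """Tag each token with (+) or (-) when its score magnitude exceeds the threshold."""
    marked = []
    for token, score in zip(tokens, scores):
        if score > threshold:
            marked.append(f"{token}(+)")
        elif score < -threshold:
            marked.append(f"{token}(-)")
        else:
            marked.append(token)
    return " ".join(marked)

def format_pair(tokens_1, scores_1, tokens_2, scores_2):
    """Render both sentences on a single line, separated by ' || '."""
    return f"{mark_tokens(tokens_1, scores_1)} || {mark_tokens(tokens_2, scores_2)}"

# Made-up tokens and scores, standing in for summarized attributions.
line = format_pair(
    ["the", "dog", "ran"], [0.01, 0.60, -0.30],
    ["the", "fox", "jumped"], [0.00, 0.45, -0.20],
)
print(line)
```

In practice the token lists would come from `tokenizer.convert_ids_to_tokens` and the scores from the summarized attributions for each sentence.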

I had a look at the tutorial but could not find the answer there. If you have any further insights, I would be very grateful for your help.

[Image: Screenshot 2023-02-13 at 10 21 53]