inseq-team / inseq

Interpretability for sequence generation models 🐛 🔍
https://inseq.org
Apache License 2.0

Utilizing Inseq #232

Closed SVC04 closed 1 year ago

SVC04 commented 1 year ago

How can I use Inseq for text generation tasks with long input texts?

Hello, thank you for providing such a great library for interpreting generation tasks. I am trying to use Inseq to explain the output of a text summarization model from Hugging Face Transformers. Because the article to summarize is very long, the saliency heatmap would be too long to read and computationally expensive to produce. Do you have any suggestions on how I can use Inseq for summarization interpretability?

gsarti commented 1 year ago

Hi @seemavishal,

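Thanks for reaching out! Do you have a specific use-case in mind for saliency heatmaps for the summarization task? I guess that you might want to use this to verify which parts of the inputs are attended by the model during summarization. In this case, I might suggest:

  • Using the ContiguousSpanAggregator (spans) class to aggregate source tokens at sentence level, to obtain a per-source-sentence importance score for every generated token that is more easily interpretable.
  • Preferring faster attribution methods like saliency and attention to make the attribution quicker. While these methods might not always faithfully reflect model importance, they can be useful approximations if efficiency is required.

Here is an example applied to headline generation: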

import inseq

model = "Michau/t5-base-en-generate-headline"

article = (
    "Very early yesterday morning, the United States President Donald Trump reported he and his wife First Lady Melania Trump tested positive for COVID-19. "
    "Officials said the Trumps' 14-year-old son Barron tested negative as did First Family and Senior Advisors Jared Kushner and Ivanka Trump. "
    "Trump took to social media, posting at 12:54 am local time (0454 UTC) on Twitter, \"Tonight, [Melania] and I tested positive for COVID-19. "
    "We will begin our quarantine and recovery process immediately. We will get through this TOGETHER!\" Yesterday afternoon Marine One landed on "
    "the White House's South Lawn flying Trump to Walter Reed National Military Medical Center (WRNMMC) in Bethesda, Maryland. "
    "Reports said both were showing \"mild symptoms\". Senior administration officials were tested as people were informed of the positive test. "
    "Senior advisor Hope Hicks had tested positive on Thursday. Presidential physician Sean Conley issued a statement saying Trump has been given zinc, "
    "vitamin D, Pepcid and a daily Aspirin. Conley also gave a single dose of the experimental polyclonal antibodies drug from Regeneron Pharmaceuticals. "
    "According to official statements, Trump, now operating from the WRNMMC, is to continue performing his duties as president during a 14-day quarantine. "
    "In the event of Trump becoming incapacitated, Vice President Mike Pence could take over the duties of president via the 25th Amendment of the US Constitution. "
    "The Pence family all tested negative as of yesterday and there were no changes regarding Pence's campaign events"
)

model = inseq.load_model(model, "attention")

out = model.attribute(article)

# Use periods as end-of-sentence indices for aggregation
ends = [i + 1 for i, t in enumerate(out[0].source) if t.token == "."] + [len(out[0].source) - 1]
starts = [0] + [i + 1 for i, t in enumerate(out[0].source) if t.token == "."]
source_spans = list(zip(starts, ends))

out.aggregate("spans", source_spans=source_spans).show()

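[image: image] https://user-images.githubusercontent.com/16674069/281102423-c5a84435-145f-4f04-ac5e-d395d0e30572.png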

Alternatively, you can also aggregate the generated tokens in the target side to obtain an aggregated per-source-sentence importance score for the whole generation:

# Added after the code above
target_spans=[(0, len(out[0].target) - 1)]
out.aggregate("spans", source_spans=source_spans, target_spans=target_spans).show()

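[image: image] https://user-images.githubusercontent.com/16674069/281103807-57135355-2808-46e9-901d-3c03e2fd7fbc.png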

Hope it helps!
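To make the span arithmetic in the snippet above easier to follow, here is a minimal sketch that runs the same logic on a toy token list. It assumes plain strings instead of Inseq token objects (so `tok` stands in for `t.token`), and the helper name `sentence_spans` is hypothetical, not part of Inseq.

```python
def sentence_spans(tokens, is_boundary=lambda tok: tok == "."):
    """Build (start, end) spans, one per sentence, from a flat token list.

    Mirrors the starts/ends logic of the snippet above: every boundary token
    closes a span, and the final span stops before the last token
    (assumed here to be the end-of-sequence marker).
    """
    boundaries = [i + 1 for i, tok in enumerate(tokens) if is_boundary(tok)]
    starts = [0] + boundaries
    ends = boundaries + [len(tokens) - 1]
    return list(zip(starts, ends))


# Toy tokens (hypothetical): two sentences ending in ".", one without, then EOS
tokens = ["▁A", "▁b", ".", "▁C", "▁d", ".", "▁E", "▁f", "</s>"]
spans = sentence_spans(tokens)
print(spans)  # [(0, 3), (3, 6), (6, 8)]
```

Note that if the text itself ends with a period, the last span produced this way is empty, so it may be worth dropping spans where start equals end before aggregating.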

SVC04 commented 1 year ago

Dear @gsarti, thanks a lot for your answer. The code you shared above throws an error when executing the line below:

ends = [i + 1 for i, t in enumerate(out[0].source) if t.token == "."] + [len(out[0].source) - 1]

The error message is as follows:

TypeError: 'FeatureAttributionOutput' object is not subscriptable

SVC04 commented 1 year ago

Dear Inseq-Team

Thank you so much for responding to my query. I am looking at summarizing long PDF documents and then seeing which parts of the input document the model has given importance to.

I tried executing the script shared above. However, after executing the line that defines the ends variable, it gives the following error:

TypeError: 'FeatureAttributionOutput' object is not subscriptable

Looking forward to your response.


gsarti commented 1 year ago

Hi @seemavishal, you should install the latest version of inseq using pip install git+https://github.com/inseq-team/inseq.git@main. The error you report is probably due to an older version of the library. Hope it helps!
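Before and after reinstalling, it can help to confirm which version is actually active in the environment. A minimal sketch using only the standard library; the helper name `installed_version` is hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(package):
    """Return the installed version string of a package, or None if absent."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


# Prints None if inseq is not installed in the current environment
print("inseq:", installed_version("inseq"))
```

In a notebook, remember to restart the kernel after reinstalling, otherwise the previously imported version stays loaded.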

SVC04 commented 1 year ago

Dear @gsarti, thanks a lot for your response. It worked in Colab. However, when I moved my code to Google Cloud to create the visualization on a large dataset, importing Inseq throws an error.

Error message: module 'jax' has no attribute 'Array'

gsarti commented 1 year ago

It could be due to the jaxtyping dependency when jax is pre-installed (which is the case on GCP, afaik). Maybe @carschno can you have a look into this? @seemavishal could you provide us with the code you ran and the error trace, please?

SVC04 commented 1 year ago

@gsarti, just after using these two lines of code below on GCP it throws the error.

pip install git+https://github.com/inseq-team/inseq.git@main
import inseq

The error trace is below.


AttributeError                            Traceback (most recent call last)
Cell In [1], line 1
----> 1 import inseq

File /usr/local/lib/python3.9/dist-packages/inseq/__init__.py:3
      1 """Interpretability for Sequence Generation Models 🔍."""
----> 3 from .attr import list_feature_attribution_methods, list_step_functions, register_step_function
      4 from .data import (
      5     FeatureAttributionOutput,
      6     list_aggregation_functions,
   (...)
      9     show_attributions,
     10 )
     11 from .models import AttributionModel, list_supported_frameworks, load_model, register_model_config

File /usr/local/lib/python3.9/dist-packages/inseq/attr/__init__.py:1
----> 1 from .feat import FeatureAttribution, extract_args, list_feature_attribution_methods
      2 from .step_functions import (
      3     STEP_SCORES_MAP,
      4     StepFunctionArgs,
      5     list_step_functions,
      6     register_step_function,
      7 )
      9 __all__ = [
     10     "FeatureAttribution",
     11     "list_feature_attribution_methods",
   (...)
     16     "StepFunctionArgs",
     17 ]

File /usr/local/lib/python3.9/dist-packages/inseq/attr/feat/__init__.py:1
----> 1 from .attribution_utils import extract_args, join_token_ids
      2 from .feature_attribution import FeatureAttribution, list_feature_attribution_methods
      3 from .gradient_attribution import (
      4     DeepLiftAttribution,
      5     DiscretizedIntegratedGradientsAttribution,
   (...)
     14     SequentialIntegratedGradientsAttribution,
     15 )

File /usr/local/lib/python3.9/dist-packages/inseq/attr/feat/attribution_utils.py:7
      3 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
      5 import torch
----> 7 from ...utils import extract_signature_args, get_aligned_idx
      8 from ...utils.typing import (
      9     OneOrMoreAttributionSequences,
     10     OneOrMoreIdSequences,
   (...)
     14     TokenWithId,
     15 )
     16 from ..step_functions import get_step_scores_args

File /usr/local/lib/python3.9/dist-packages/inseq/utils/__init__.py:1
----> 1 from .alignment_utils import get_adjusted_alignments, get_aligned_idx
      2 from .argparse import InseqArgumentParser
      3 from .cache import INSEQ_ARTIFACTS_CACHE, INSEQ_HOME_CACHE, cache_results

File /usr/local/lib/python3.9/dist-packages/inseq/utils/alignment_utils.py:12
      9 import torch
     10 from transformers import AutoModel, AutoTokenizer, PreTrainedModel, PreTrainedTokenizerBase
---> 12 from .misc import clean_tokens
     14 logger = logging.getLogger(__name__)
     16 ALIGN_MODEL_ID = "sentence-transformers/LaBSE"

File /usr/local/lib/python3.9/dist-packages/inseq/utils/misc.py:21
     18 from torch import Tensor
     20 from .errors import LengthMismatchError
---> 21 from .typing import TextInput, TokenWithId
     23 logger = logging.getLogger(__name__)
     26 @contextmanager
     27 def optional(condition, context_manager, alternative_fn=None, **alternative_fn_kwargs):

File /usr/local/lib/python3.9/dist-packages/inseq/utils/typing.py:5
      2 from typing import Optional, Sequence, Tuple, Union
      4 import torch
----> 5 from jaxtyping import Float, Float32, Int64
      6 from transformers import PreTrainedModel
      8 TextInput = Union[str, Sequence[str]]

File /usr/local/lib/python3.9/dist-packages/jaxtyping/__init__.py:25
     22 import warnings
     24 # First import some things as normal
---> 25 from ._array_types import (
     26     AbstractArray as AbstractArray,
     27     AbstractDtype as AbstractDtype,
     28     get_array_name_format as get_array_name_format,
     29     has_jax,
     30     set_array_name_format as set_array_name_format,
     31 )
     32 from ._decorator import jaxtyped as jaxtyped
     33 from ._import_hook import install_import_hook as install_import_hook

File /usr/local/lib/python3.9/dist-packages/jaxtyping/_array_types.py:680
    678 PRNGKeyArray = Union[Key[jax.Array, ""], UInt32[jax.Array, "2"]]
    679 Scalar = Shaped[jax.Array, ""]
--> 680 ScalarLike = Shaped[jax.typing.ArrayLike, ""]

AttributeError: module 'jax' has no attribute 'typing'

gsarti commented 1 year ago

Have you tried to pip install -U jax and see if the problem persists?

SVC04 commented 1 year ago

@gsarti I tried pip install -U jax. It then suggested updating jaxlib, so I ran pip install -U jaxlib. Now the import works fine. Thanks a lot for your prompt response and for providing this amazing open-source library.
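For anyone hitting the same import error, a quick way to see whether the pre-installed jax is too old is to check for the attributes named in the error messages above (`Array` and `typing`). This is a minimal sketch; the helper name `missing_attributes` is hypothetical, and the attribute list is taken only from the errors reported in this thread.

```python
import importlib


def missing_attributes(module_name, names):
    """Return the names a module lacks, or None if it cannot be imported."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return None
    return [name for name in names if not hasattr(module, name)]


# None means jax is not installed; an empty list means both attributes exist
print(missing_attributes("jax", ["Array", "typing"]))
```

A non-empty list suggests the installed jax predates what jaxtyping expects, in which case upgrading jax and jaxlib as described above is the likely fix.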

gsarti commented 1 year ago

Thanks for following up on that! Best of luck with your use-cases!

gsarti commented 12 months ago

Can you try using the same model to summarize using only the transformers library? Do you also get a kernel crash? The issue might be due to the text not fitting the GPU RAM of your machine, rather than an issue with the library

SVC04 commented 12 months ago

Yes the library works fine. It's the issue with GPU RAM.

Thank you 😊
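If the full document does not fit in GPU memory, one possible workaround (not something proposed in this thread) is to attribute shorter chunks separately. The sketch below groups sentences into chunks under a word budget. It assumes a naive split on ". " is good enough, and the name `chunk_sentences` and the `max_words` budget are hypothetical. Keep in mind that the model then sees each chunk in isolation, so the scores are not equivalent to attributing the whole document at once.

```python
def chunk_sentences(text, max_words=100):
    """Group sentences into chunks of at most max_words words each.

    A single sentence longer than max_words still becomes its own chunk.
    """
    sentences = [s.strip() for s in text.split(". ") if s.strip()]
    chunks, current, count = [], [], 0
    for sentence in sentences:
        n_words = len(sentence.split())
        if current and count + n_words > max_words:
            # Splitting on ". " removed the period, so restore it here
            chunks.append(". ".join(current) + ".")
            current, count = [], 0
        current.append(sentence)
        count += n_words
    if current:
        # The final sentence keeps whatever punctuation the text ended with
        chunks.append(". ".join(current))
    return chunks


text = "One two three. Four five six. Seven eight nine."
chunks = chunk_sentences(text, max_words=6)
print(chunks)  # ['One two three. Four five six.', 'Seven eight nine.']
```

Each chunk could then be passed to model.attribute separately, with the per-chunk results inspected side by side.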


SVC04 commented 9 months ago

@gsarti

Using the example code you provided above, I am no longer able to reproduce the results. The code suddenly stopped working.

import inseq

model = "Michau/t5-base-en-generate-headline"

article = (
    "Very early yesterday morning, the United States President Donald Trump reported he and his wife First Lady Melania Trump tested positive for COVID-19. "
    "Officials said the Trumps' 14-year-old son Barron tested negative as did First Family and Senior Advisors Jared Kushner and Ivanka Trump. "
    "Trump took to social media, posting at 12:54 am local time (0454 UTC) on Twitter, \"Tonight, [Melania] and I tested positive for COVID-19. "
    "We will begin our quarantine and recovery process immediately. We will get through this TOGETHER!\" Yesterday afternoon Marine One landed on "
    "the White House's South Lawn flying Trump to Walter Reed National Military Medical Center (WRNMMC) in Bethesda, Maryland. "
    "Reports said both were showing \"mild symptoms\". Senior administration officials were tested as people were informed of the positive test. "
    "Senior advisor Hope Hicks had tested positive on Thursday. Presidential physician Sean Conley issued a statement saying Trump has been given zinc, "
    "vitamin D, Pepcid and a daily Aspirin. Conley also gave a single dose of the experimental polyclonal antibodies drug from Regeneron Pharmaceuticals. "
    "According to official statements, Trump, now operating from the WRNMMC, is to continue performing his duties as president during a 14-day quarantine. "
    "In the event of Trump becoming incapacitated, Vice President Mike Pence could take over the duties of president via the 25th Amendment of the US Constitution. "
    "The Pence family all tested negative as of yesterday and there were no changes regarding Pence's campaign events"
)

model = inseq.load_model(model, "attention")

out = model.attribute(article)

It throws the error below.

Attributing with attention...: 100%|██████████| 20/20 [00:51<00:00,  2.73s/it]
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-13-d11f043a48e5> in <cell line: 1>()
----> 1 out = model.attribute(article)

5 frames
/usr/local/lib/python3.10/dist-packages/inseq/models/attribution_model.py in attribute(self, input_texts, generated_texts, method, override_default_attribution, attr_pos_start, attr_pos_end, show_progress, pretty_progress, output_step_attributions, attribute_target, step_scores, include_eos_baseline, attributed_fn, device, batch_size, generate_from_target_prefix, generation_args, **kwargs)
    443             logger.info("Batched attribution currently not supported for LIME. Using batch size of 1.")
    444             batch_size = 1
--> 445         attribution_outputs = attribution_method.prepare_and_attribute(
    446             input_texts,
    447             generated_texts,

/usr/local/lib/python3.10/dist-packages/inseq/attr/attribution_decorators.py in batched_wrapper(self, batch_size, *args, **kwargs)
     70 
     71         if batch_size is None:
---> 72             out = f(self, *args, **kwargs)
     73             return out if isinstance(out, list) else [out]
     74         batched_args = [get_batched(batch_size, arg) for arg in args]

/usr/local/lib/python3.10/dist-packages/inseq/attr/feat/feature_attribution.py in prepare_and_attribute(self, sources, targets, attr_pos_start, attr_pos_end, show_progress, pretty_progress, output_step_attributions, attribute_target, step_scores, include_eos_baseline, attributed_fn, attribution_args, attributed_fn_args, step_scores_args)
    235         # of AttributionModel.attribute.
    236         attributed_fn = self.attribution_model.get_attributed_fn(attributed_fn)
--> 237         attribution_output = self.attribute(
    238             batch,
    239             attributed_fn=attributed_fn,

/usr/local/lib/python3.10/dist-packages/inseq/attr/feat/feature_attribution.py in attribute(self, batch, attributed_fn, attr_pos_start, attr_pos_end, show_progress, pretty_progress, output_step_attributions, attribute_target, step_scores, attribution_args, attributed_fn_args, step_scores_args)
    473         batch.detach().to("cpu")
    474         out = FeatureAttributionOutput(
--> 475             sequence_attributions=FeatureAttributionSequenceOutput.from_step_attributions(
    476                 attributions=attribution_outputs,
    477                 tokenized_target_sentences=target_tokens_with_ids,

/usr/local/lib/python3.10/dist-packages/inseq/data/attribution.py in from_step_attributions(cls, attributions, tokenized_target_sentences, pad_id, has_bos_token, attr_pos_end)
    274                     out_seq_scores = [attr.sequence_scores[seq_score_name][i, ...] for i in range(num_sequences)]
    275                 else:
--> 276                     out_seq_scores = get_sequences_from_batched_steps(
    277                         [att.sequence_scores[seq_score_name] for att in attributions], padding_dims=[2], stack_dim=3
    278                     )

/usr/local/lib/python3.10/dist-packages/inseq/utils/torch_utils.py in get_sequences_from_batched_steps(bsteps, padding_dims, stack_dim)
    219                 padded_bstep = F.pad(bstep, pad=pad_shape, mode="constant", value=float("nan"))
    220                 bsteps[bstep_idx] = padded_bstep
--> 221     sequences = torch.stack(bsteps, dim=stack_dim).split(1, dim=0)
    222     return [seq.squeeze(0) for seq in sequences]
    223 

RuntimeError: stack expects each tensor to be equal size, but got [1, 1, 12, 12] at entry 0 and [1, 2, 12, 12] at entry 1

gsarti commented 9 months ago

Hi @SVC04, a bug was introduced in the PR #245 merged earlier today. It should be fixed now if you reinstall from main!

SVC04 commented 9 months ago

Hi @seemavishal,

Thanks for reaching out! Do you have a specific use-case in mind for saliency heatmaps for the summarization task? I guess that you might want to use this to verify which parts of the inputs are attended by the model during summarization. In this case, I might suggest:

  • Using the ContiguousSpanAggregator (spans) class to aggregate source tokens at sentence level, to obtain a per-source-sentence importance score for every generated token that is more easily interpretable.
  • Prefer faster attribution methods like saliency and attention to make the attribution quicker. While these methods might not always faithfully reflect model importance, they can be useful approximations if efficiency is required.


Hello Gabriele

A small question on this.

When I change the model to something like BART, and a sentence in the input article ends not with a bare period but with a quotation mark followed by the period (".), the sentence boundary is not detected. I tried adding a space after the quotation mark and adding a token like [end of quote], but the boundary is still not detected, and instead of a score for a single sentence I get one attribution score covering two sentences.

Any suggestions on how I can handle such cases to generate per-sentence scores?

gsarti commented 9 months ago

Hi @SVC04,

Have you tried simply replacing the condition if t.token == "." in the code above with if t.token.endswith(".")? I just tried it on the example above and it works for me!

Hope it helps!
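The difference between the two conditions can be seen on a toy token list where the tokenizer has merged the closing quote and the period into one token. The tokens below are hypothetical and stand in for `t.token` values; the actual tokenization depends on the model.

```python
# Hypothetical subword tokens for: He said "ok". She left.
tokens = ["▁He", "▁said", '▁"', "ok", '".', "▁She", "▁left", "."]

# Exact match only finds tokens that are a bare period
exact = [i for i, tok in enumerate(tokens) if tok == "."]

# Suffix match also finds tokens where the period is glued to other characters
suffix = [i for i, tok in enumerate(tokens) if tok.endswith(".")]

print(exact)   # [7]
print(suffix)  # [4, 7]
```

One caveat: a suffix match also fires on any other token ending in a period, such as an abbreviation tokenized as "Dr.", so the resulting spans are worth a quick visual check on new inputs.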

SVC04 commented 9 months ago

That works perfectly. I tried all the complex regex alternatives and missed using this simple solution.

Thank you.

SVC04 commented 2 months ago

@gsarti I am not able to reproduce this. It throws the following error:

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.
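As the error message itself suggests, setting CUDA_LAUNCH_BLOCKING=1 makes CUDA calls synchronous so that the stack trace points at the operation that actually failed. A minimal sketch of doing this from Python; the variable has to be set before CUDA is initialized, so in a notebook this belongs in the first cell after a kernel restart.

```python
import os

# Set before importing torch or inseq, i.e. before any CUDA call is made
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# After this point, import the libraries and rerun the failing call, e.g.:
#   import inseq
#   model = inseq.load_model("Michau/t5-base-en-generate-headline", "attention")
#   out = model.attribute(article)
print(os.environ["CUDA_LAUNCH_BLOCKING"])
```

Rerunning the same call on CPU is another common way to get a readable error, since device-side asserts often hide an ordinary indexing problem. The attribute signature in the traceback earlier in this thread lists a device argument, though whether passing "cpu" there is enough for this case is an assumption.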