inseq-team / inseq

Interpretability for sequence generation models 🐛 🔍
https://inseq.org
Apache License 2.0
334 stars 36 forks

CUDA out of memory issue on long context #279

Open acDante opened 3 weeks ago

acDante commented 3 weeks ago

Question

Hi @gsarti, I find that the attribute() function causes a CUDA out-of-memory error when the input length exceeds about 2500 tokens. I used Llama-2-7b / Mistral-7b models to compute attributions and chose attention as the attribution method. Other gradient-based or perturbation-based attribution methods may consume even more GPU memory and also run out of memory.

I tested the following code snippet on two A100-80GB GPUs. At some point during the attribution process, GPU memory consumption becomes extremely high (much higher than just running generate()). Increasing the number of GPUs does not resolve the issue, since memory consumption on the first GPU always exceeds 80GB.

Is it possible to get attribution from long context with inseq?
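For scale, the attention tensor stacked at each attribution step grows quadratically with sequence length. A rough back-of-the-envelope estimate (assuming Mistral-7B's 32 layers and 32 attention heads, bf16 scores at 2 bytes per element, and a sequence length of roughly 5550 tokens) lands close to the 58.56 GiB allocation reported in the stack trace:

```python
# Rough size of the per-step attention stack (torch.stack over per-layer
# attentions). Assumed: 32 layers x 32 heads for Mistral-7B, bf16 = 2 bytes;
# seq_len ~5550 approximates the 5450-token document plus prompt tokens.
layers, heads, bytes_per_elem = 32, 32, 2
seq_len = 5_550

total_bytes = layers * heads * seq_len * seq_len * bytes_per_elem
print(f"~{total_bytes / 2**30:.0f} GiB")  # on the order of the 58.56 GiB in the trace
```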

import torch
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login
from datasets import load_dataset
import inseq

test_data = load_dataset("xsum", split="test")

model_name = "mistralai/Mistral-7B-Instruct-v0.2"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

attr_type = "attention"

doc = test_data[797]['document']   # 5450 tokens
input_prompt = f"Summarise the document below: {doc}"
messages = [{
    "role": "user", 
    "content": input_prompt
}]

inseq_model = inseq.load_model(model, attr_type)
out = inseq_model.attribute(input_prompt, generation_args={"do_sample": False, "max_new_tokens": 100})

Additional context

Stack trace

---------------------------------------------------------------------------
OutOfMemoryError                          Traceback (most recent call last)
Cell In[1], line 36
     34 # input_prompt = '''A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: A 34-year-old man was arrested in connection with an outstanding warrant and is expected to appear at Glasgow Sherriff Court on Monday. A 15-year-old male was arrested for offensive behaviour and resisting arrest and a 16-year-old male was arrested for offensive behaviour. Three men were arrested outside the stadium in connection with assault. The men, aged 29, 28 and 27, and all from Glasgow, are expected to appear at Aberdeen Sherriff Court on Monday. Police said the two teenagers will be reported to the relevant authorities. Match Commander Supt Innes Walker said: "The vast majority of fans from both football clubs followed the advice given and conducted themselves appropriately. "The policing operation was assisted by specialist resources including the horses, the dog unit and roads policing and we appreciate the support of the overwhelming majority of fans and members of the public in allowing the Friday night game to be enjoyed and pass safely." Celtic won the match 3-1\nSummarize the provided document. The summary should be extremely short. ASSISTANT:'''
     35 inseq_model = inseq.load_model(model, attr_type)
---> 36 out = inseq_model.attribute(input_prompt, generation_args={"do_sample": False, "max_new_tokens": 100})

File /mnt/ceph_rbd/miniconda3/envs/inseq/lib/python3.10/site-packages/inseq/models/attribution_model.py:471, in AttributionModel.attribute(self, input_texts, generated_texts, method, override_default_attribution, attr_pos_start, attr_pos_end, show_progress, pretty_progress, output_step_attributions, attribute_target, step_scores, include_eos_baseline, attributed_fn, device, batch_size, generate_from_target_prefix, skip_special_tokens, generation_args, **kwargs)
    469     logger.warning("Batched attribution currently not supported for LIME. Using batch size of 1.")
    470     batch_size = 1
--> 471 attribution_outputs = attribution_method.prepare_and_attribute(
    472     input_texts,
    473     generated_texts,
    474     batch_size=batch_size,
    475     attr_pos_start=attr_pos_start,
    476     attr_pos_end=attr_pos_end,
    477     show_progress=show_progress,
    478     pretty_progress=pretty_progress,
    479     output_step_attributions=output_step_attributions,
    480     attribute_target=attribute_target,
    481     step_scores=step_scores,
    482     include_eos_baseline=include_eos_baseline,
    483     skip_special_tokens=skip_special_tokens,
    484     attributed_fn=attributed_fn,
    485     attribution_args=attribution_args,
    486     attributed_fn_args=attributed_fn_args,
    487     step_scores_args=step_scores_args,
    488 )
    489 attribution_output = merge_attributions(attribution_outputs)
    490 attribution_output.info["input_texts"] = input_texts

File /mnt/ceph_rbd/miniconda3/envs/inseq/lib/python3.10/site-packages/inseq/attr/attribution_decorators.py:72, in batched.<locals>.batched_wrapper(self, batch_size, *args, **kwargs)
     69         raise TypeError(f"Unsupported type {type(seq)} for batched attribution computation.")
     71 if batch_size is None:
---> 72     out = f(self, *args, **kwargs)
     73     return out if isinstance(out, list) else [out]
     74 batched_args = [get_batched(batch_size, arg) for arg in args]

File /mnt/ceph_rbd/miniconda3/envs/inseq/lib/python3.10/site-packages/inseq/attr/feat/feature_attribution.py:243, in FeatureAttribution.prepare_and_attribute(self, sources, targets, attr_pos_start, attr_pos_end, show_progress, pretty_progress, output_step_attributions, attribute_target, step_scores, include_eos_baseline, skip_special_tokens, attributed_fn, attribution_args, attributed_fn_args, step_scores_args)
    239 # If prepare_and_attribute was called from AttributionModel.attribute,
    240 # attributed_fn is already a Callable. Keep here to allow for usage independently
    241 # of AttributionModel.attribute.
    242 attributed_fn = self.attribution_model.get_attributed_fn(attributed_fn)
--> 243 attribution_output = self.attribute(
    244     batch,
    245     attributed_fn=attributed_fn,
    246     attr_pos_start=attr_pos_start,
    247     attr_pos_end=attr_pos_end,
    248     show_progress=show_progress,
    249     pretty_progress=pretty_progress,
    250     output_step_attributions=output_step_attributions,
    251     attribute_target=attribute_target,
    252     step_scores=step_scores,
    253     skip_special_tokens=skip_special_tokens,
    254     attribution_args=attribution_args,
    255     attributed_fn_args=attributed_fn_args,
    256     step_scores_args=step_scores_args,
    257 )
    258 # Same here, repeated from AttributionModel.attribute
    259 # to allow independent usage
    260 attribution_output.info["include_eos_baseline"] = include_eos_baseline

File /mnt/ceph_rbd/miniconda3/envs/inseq/lib/python3.10/site-packages/inseq/attr/feat/feature_attribution.py:475, in FeatureAttribution.attribute(self, batch, attributed_fn, attr_pos_start, attr_pos_end, show_progress, pretty_progress, output_step_attributions, attribute_target, step_scores, skip_special_tokens, attribution_args, attributed_fn_args, step_scores_args)
    473     continue
    474 tgt_ids, tgt_mask = batch.get_step_target(step, with_attention=True)
--> 475 step_output = self.filtered_attribute_step(
    476     batch[:step],
    477     target_ids=tgt_ids.unsqueeze(1),
    478     attributed_fn=attributed_fn,
    479     target_attention_mask=tgt_mask.unsqueeze(1),
    480     attribute_target=attribute_target,
    481     step_scores=step_scores,
    482     attribution_args=attribution_args,
    483     attributed_fn_args=attributed_fn_args,
    484     step_scores_args=step_scores_args,
    485 )
    486 # Add batch information to output
    487 step_output = self.attribution_model.formatter.enrich_step_output(
    488     self.attribution_model,
    489     step_output,
   (...)
    494     contrast_targets_alignments=contrast_targets_alignments,
    495 )

File /mnt/ceph_rbd/miniconda3/envs/inseq/lib/python3.10/site-packages/inseq/attr/feat/feature_attribution.py:622, in FeatureAttribution.filtered_attribute_step(self, batch, target_ids, attributed_fn, target_attention_mask, attribute_target, step_scores, attribution_args, attributed_fn_args, step_scores_args)
    615     output = self.attribution_model.get_forward_output(
    616         batch,
    617         use_embeddings=self.forward_batch_embeds,
    618         output_attentions=self.use_attention_weights,
    619         output_hidden_states=self.use_hidden_states,
    620     )
    621 if self.use_attention_weights:
--> 622     attentions_dict = self.attribution_model.get_attentions_dict(output)
    623     attribution_args = {**attribution_args, **attentions_dict}
    624 if self.use_hidden_states:

File /mnt/ceph_rbd/miniconda3/envs/inseq/lib/python3.10/site-packages/inseq/models/huggingface_model.py:510, in HuggingfaceDecoderOnlyModel.get_attentions_dict(output)
    507 if output.attentions is None:
    508     raise ValueError("Model does not support attribution relying on attention outputs.")
    509 return {
--> 510     "decoder_self_attentions": torch.stack(output.attentions, dim=1),
    511 }

OutOfMemoryError: CUDA out of memory. Tried to allocate 58.56 GiB. GPU


gsarti commented 2 weeks ago

Hi @acDante,

Thanks for reporting this. Working with device_map for big models when resources are tight might require a bit of fiddling; you might want to experiment with manually designing one to minimize occupancy on the first device, e.g.:

from collections import OrderedDict

device_map = OrderedDict([('model.embed_tokens', 0),
             ('model.layers.0', 0),
             ('model.layers.1', 0),
             ('model.layers.2', 0),
             ('model.layers.3', 1),
             ('model.layers.4', 1),
             ('model.layers.5', 1),
             ('model.layers.6', 1),
             ('model.layers.7', 1),
             ('model.layers.8', 1),
             ('model.layers.9', 1),
             ('model.layers.10', 1),
             ('model.layers.11', 1),
             ('model.layers.12', 1),
             ('model.layers.13', 1),
             ('model.layers.14', 1),
             ('model.layers.15', 1),
             ('model.layers.16', 1),
             ('model.layers.17', 1),
             ('model.layers.18', 1),
             ('model.layers.19', 1),
             ('model.layers.20', 1),
             ('model.layers.21', 1),
             ('model.layers.22', 1),
             ('model.layers.23', 1),
             ('model.layers.24', 1),
             ('model.layers.25', 1),
             ('model.layers.26', 1),
             ('model.layers.27', 1),
             ('model.layers.28', 1),
             ('model.layers.29', 1),
             ('model.layers.30', 1),
             ('model.layers.31', 1),
             ('model.norm', 1),
             ('lm_head', 1)])
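The same split can also be built programmatically instead of spelling out every layer. A small sketch reproducing the map above (assuming the 32 decoder layers of a Mistral-7B-style model); the resulting dict can then be passed as device_map=device_map to AutoModelForCausalLM.from_pretrained:

```python
from collections import OrderedDict

# Reproduce the hand-written map: embeddings and layers 0-2 on GPU 0,
# everything else (layers 3-31, the final norm, and the LM head) on GPU 1.
device_map = OrderedDict(
    [("model.embed_tokens", 0)]
    + [(f"model.layers.{i}", 0 if i < 3 else 1) for i in range(32)]
    + [("model.norm", 1), ("lm_head", 1)]
)
```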

When you say that generate runs without problem, are you specifying output_attentions=True as a kwarg in that case? I suspect that storing all attention scores across all layers will take up a large amount of GPU memory in that case, too.

dxlong2000 commented 1 week ago

Hi @gsarti, any tips like the above for using sequential_integrated_gradients? I am on a single NVIDIA L40S 46GB, where I can run inference for Mistral successfully but inseq runs out of memory for my case. Could you please suggest any modifications? Many thanks!