inseq-team / inseq

Interpretability for sequence generation models 🐛 🔍
https://inseq.org
Apache License 2.0

Expected all tensors to be on the same device multi-gpu #276

Open saxenarohit opened 2 months ago

saxenarohit commented 2 months ago

Hi,

It seems this bug is still there.

Here's the testing code:

from transformers import AutoModelForCausalLM, AutoTokenizer
import inseq
import torch

model_name = "lmsys/vicuna-7b-v1.1"
attr_type = "saliency"

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto",
    cache_dir="/data/huggingface_cache/",
)
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="/data/huggingface_cache/")
model.config.pad_token_id = tokenizer.eos_token_id
tokenizer.pad_token = tokenizer.eos_token
model.eval()

input_text = '''A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: A 34-year-old man was arrested in connection with an outstanding warrant and is expected to appear at Glasgow Sherriff Court on Monday. A 15-year-old male was arrested for offensive behaviour and resisting arrest and a 16-year-old male was arrested for offensive behaviour. Three men were arrested outside the stadium in connection with assault. The men, aged 29, 28 and 27, and all from Glasgow, are expected to appear at Aberdeen Sherriff Court on Monday. Police said the two teenagers will be reported to the relevant authorities. Match Commander Supt Innes Walker said: "The vast majority of fans from both football clubs followed the advice given and conducted themselves appropriately. "The policing operation was assisted by specialist resources including the horses, the dog unit and roads policing and we appreciate the support of the overwhelming majority of fans and members of the public in allowing the Friday night game to be enjoyed and pass safely." Celtic won the match 3-1\nSummarize the provided document. The summary should be extremely short. ASSISTANT:'''

inseq_model = inseq.load_model(model, attr_type, tokenizer=tokenizer)
input_tokens = tokenizer.encode_plus(input_text, return_tensors="pt").to(model.device)
inseq_output_text = inseq_model.generate(input_tokens, do_sample=False, max_new_tokens=100, skip_special_tokens=True)
out = inseq_model.attribute(input_texts=input_text, generated_texts=inseq_output_text[0])
out.show()

Error:

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1!

I also tried another way and ran into the same error.


import inseq

model_name = "lmsys/vicuna-7b-v1.1"
attr_type="input_x_gradient"
input_prompt = '''A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: A 34-year-old man was arrested in connection with an outstanding warrant and is expected to appear at Glasgow Sherriff Court on Monday. A 15-year-old male was arrested for offensive behaviour and resisting arrest and a 16-year-old male was arrested for offensive behaviour. Three men were arrested outside the stadium in connection with assault. The men, aged 29, 28 and 27, and all from Glasgow, are expected to appear at Aberdeen Sherriff Court on Monday. Police said the two teenagers will be reported to the relevant authorities. Match Commander Supt Innes Walker said: "The vast majority of fans from both football clubs followed the advice given and conducted themselves appropriately. "The policing operation was assisted by specialist resources including the horses, the dog unit and roads policing and we appreciate the support of the overwhelming majority of fans and members of the public in allowing the Friday night game to be enjoyed and pass safely." Celtic won the match 3-1\nSummarize the provided document. The summary should be extremely short. ASSISTANT:'''
inseq_model = inseq.load_model(model, attr_type)  # reuses the `model` loaded in the first snippet
out = inseq_model.attribute(input_prompt, generation_args={"do_sample": False, "max_new_tokens": 100})

Please let me know how to fix this. Thanks.
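
In case it matters: I assume keeping the whole model on a single GPU would avoid the error, since nothing gets sharded across devices, but I specifically need the multi-GPU device_map="auto" setup. For reference, that single-device fallback (assuming the fp16 weights fit on one card) would look like:

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map={"": 0},  # place every module on cuda:0 instead of sharding across GPUs
    cache_dir="/data/huggingface_cache/",
)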

gsarti commented 2 months ago

Hi @saxenarohit, thanks for reporting this! Could you provide:

- the versions of torch, transformers and inseq you are using;
- the full stack trace of the error?

Thanks!

saxenarohit commented 2 months ago

Thanks! Versions:

torch 2.2.0a0+81ea7a4
transformers 4.41.0
inseq 0.6.0

Stack trace:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[1], [line 19](vscode-notebook-cell:?execution_count=1&line=19)
     [17](vscode-notebook-cell:?execution_count=1&line=17) inseq_model = inseq.load_model(model, attr_type, tokenizer=tokenizer)
     [18](vscode-notebook-cell:?execution_count=1&line=18) input_tokens = tokenizer.encode_plus(input_text,return_tensors="pt").to(model.device)
---> [19](vscode-notebook-cell:?execution_count=1&line=19) inseq_output_text = inseq_model.generate(input_tokens,do_sample=False,max_new_tokens=100,skip_special_tokens=True)
     [20](vscode-notebook-cell:?execution_count=1&line=20) out = inseq_model.attribute(input_texts=input_text, generated_texts = inseq_output_text[0])
     [21](vscode-notebook-cell:?execution_count=1&line=21) out.show()

File [/usr/local/lib/python3.10/dist-packages/inseq/models/model_decorators.py:12](/usr/local/lib/python3.10/dist-packages/inseq/models/model_decorators.py:12), in unhooked.<locals>.attribution_free_wrapper(self, *args, **kwargs)
     [10](/usr/local/lib/python3.10/dist-packages/inseq/models/model_decorators.py:10)     was_hooked = True
     [11](/usr/local/lib/python3.10/dist-packages/inseq/models/model_decorators.py:11)     self.attribution_method.unhook()
---> [12](/usr/local/lib/python3.10/dist-packages/inseq/models/model_decorators.py:12) out = f(self, *args, **kwargs)
     [13](/usr/local/lib/python3.10/dist-packages/inseq/models/model_decorators.py:13) if was_hooked:
     [14](/usr/local/lib/python3.10/dist-packages/inseq/models/model_decorators.py:14)     self.attribution_method.hook()

File [/usr/local/lib/python3.10/dist-packages/inseq/attr/attribution_decorators.py:72](/usr/local/lib/python3.10/dist-packages/inseq/attr/attribution_decorators.py:72), in batched.<locals>.batched_wrapper(self, batch_size, *args, **kwargs)
     [69](/usr/local/lib/python3.10/dist-packages/inseq/attr/attribution_decorators.py:69)         raise TypeError(f"Unsupported type {type(seq)} for batched attribution computation.")
     [71](/usr/local/lib/python3.10/dist-packages/inseq/attr/attribution_decorators.py:71) if batch_size is None:
---> [72](/usr/local/lib/python3.10/dist-packages/inseq/attr/attribution_decorators.py:72)     out = f(self, *args, **kwargs)
     [73](/usr/local/lib/python3.10/dist-packages/inseq/attr/attribution_decorators.py:73)     return out if isinstance(out, list) else [out]
     [74](/usr/local/lib/python3.10/dist-packages/inseq/attr/attribution_decorators.py:74) batched_args = [get_batched(batch_size, arg) for arg in args]

File [/usr/local/lib/python3.10/dist-packages/inseq/models/huggingface_model.py:221](/usr/local/lib/python3.10/dist-packages/inseq/models/huggingface_model.py:221), in HuggingfaceModel.generate(self, inputs, return_generation_output, skip_special_tokens, output_generated_only, **kwargs)
    [219](/usr/local/lib/python3.10/dist-packages/inseq/models/huggingface_model.py:219)     inputs = self.encode(inputs)
    [220](/usr/local/lib/python3.10/dist-packages/inseq/models/huggingface_model.py:220) inputs = inputs.to(self.device)
--> [221](/usr/local/lib/python3.10/dist-packages/inseq/models/huggingface_model.py:221) generation_out = self.model.generate(
    [222](/usr/local/lib/python3.10/dist-packages/inseq/models/huggingface_model.py:222)     inputs=inputs.input_ids,
    [223](/usr/local/lib/python3.10/dist-packages/inseq/models/huggingface_model.py:223)     return_dict_in_generate=True,
    [224](/usr/local/lib/python3.10/dist-packages/inseq/models/huggingface_model.py:224)     **kwargs,
    [225](/usr/local/lib/python3.10/dist-packages/inseq/models/huggingface_model.py:225) )
    [226](/usr/local/lib/python3.10/dist-packages/inseq/models/huggingface_model.py:226) sequences = generation_out.sequences
    [227](/usr/local/lib/python3.10/dist-packages/inseq/models/huggingface_model.py:227) if output_generated_only and not self.is_encoder_decoder:

File [/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py:115](/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py:115), in context_decorator.<locals>.decorate_context(*args, **kwargs)
    [112](/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py:112) @functools.wraps(func)
    [113](/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py:113) def decorate_context(*args, **kwargs):
    [114](/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py:114)     with ctx_factory():
--> [115](/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py:115)         return func(*args, **kwargs)

File [/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1736](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1736), in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)
   [1728](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1728)     input_ids, model_kwargs = self._expand_inputs_for_generation(
   [1729](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1729)         input_ids=input_ids,
   [1730](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1730)         expand_size=generation_config.num_return_sequences,
   [1731](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1731)         is_encoder_decoder=self.config.is_encoder_decoder,
   [1732](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1732)         **model_kwargs,
   [1733](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1733)     )
   [1735](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1735)     # 13. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
-> [1736](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1736)     result = self._sample(
   [1737](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1737)         input_ids,
   [1738](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1738)         logits_processor=prepared_logits_processor,
   [1739](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1739)         logits_warper=prepared_logits_warper,
   [1740](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1740)         stopping_criteria=prepared_stopping_criteria,
   [1741](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1741)         generation_config=generation_config,
   [1742](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1742)         synced_gpus=synced_gpus,
   [1743](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1743)         streamer=streamer,
   [1744](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1744)         **model_kwargs,
   [1745](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1745)     )
   [1747](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1747) elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
   [1748](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1748)     # 11. prepare logits warper
   [1749](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1749)     prepared_logits_warper = (
   [1750](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1750)         self._get_logits_warper(generation_config) if generation_config.do_sample else None
   [1751](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1751)     )

File [/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:2375](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:2375), in GenerationMixin._sample(self, input_ids, logits_processor, stopping_criteria, generation_config, synced_gpus, streamer, logits_warper, **model_kwargs)
   [2372](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:2372) model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
   [2374](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:2374) # forward pass to get next token
-> [2375](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:2375) outputs = self(
   [2376](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:2376)     **model_inputs,
   [2377](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:2377)     return_dict=True,
   [2378](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:2378)     output_attentions=output_attentions,
   [2379](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:2379)     output_hidden_states=output_hidden_states,
   [2380](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:2380) )
   [2382](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:2382) if synced_gpus and this_peer_finished:
   [2383](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:2383)     continue  # don't waste resources running the code we don't need

File [/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1510](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1510), in Module._wrapped_call_impl(self, *args, **kwargs)
   [1508](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1508)     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   [1509](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1509) else:
-> [1510](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1510)     return self._call_impl(*args, **kwargs)

File [/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1519](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1519), in Module._call_impl(self, *args, **kwargs)
   [1514](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1514) # If we don't have any hooks, we want to skip the rest of the logic in
   [1515](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1515) # this function, and just call forward.
   [1516](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1516) if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   [1517](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1517)         or _global_backward_pre_hooks or _global_backward_hooks
   [1518](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518)         or _global_forward_hooks or _global_forward_pre_hooks):
-> [1519](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1519)     return forward_call(*args, **kwargs)
   [1521](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1521) try:
   [1522](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1522)     result = None

File [/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:166](/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:166), in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    [164](/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:164)         output = module._old_forward(*args, **kwargs)
    [165](/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:165) else:
--> [166](/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:166)     output = module._old_forward(*args, **kwargs)
    [167](/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:167) return module._hf_hook.post_forward(module, output)

File [/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1164](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1164), in LlamaForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
   [1161](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1161) return_dict = return_dict if return_dict is not None else self.config.use_return_dict
   [1163](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1163) # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> [1164](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1164) outputs = self.model(
   [1165](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1165)     input_ids=input_ids,
   [1166](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1166)     attention_mask=attention_mask,
   [1167](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1167)     position_ids=position_ids,
   [1168](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1168)     past_key_values=past_key_values,
   [1169](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1169)     inputs_embeds=inputs_embeds,
   [1170](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1170)     use_cache=use_cache,
   [1171](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1171)     output_attentions=output_attentions,
   [1172](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1172)     output_hidden_states=output_hidden_states,
   [1173](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1173)     return_dict=return_dict,
   [1174](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1174)     cache_position=cache_position,
   [1175](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1175) )
   [1177](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1177) hidden_states = outputs[0]
   [1178](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1178) if self.config.pretraining_tp > 1:

File [/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1510](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1510), in Module._wrapped_call_impl(self, *args, **kwargs)
   [1508](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1508)     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   [1509](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1509) else:
-> [1510](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1510)     return self._call_impl(*args, **kwargs)

File [/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1519](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1519), in Module._call_impl(self, *args, **kwargs)
   [1514](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1514) # If we don't have any hooks, we want to skip the rest of the logic in
   [1515](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1515) # this function, and just call forward.
   [1516](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1516) if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   [1517](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1517)         or _global_backward_pre_hooks or _global_backward_hooks
   [1518](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518)         or _global_forward_hooks or _global_forward_pre_hooks):
-> [1519](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1519)     return forward_call(*args, **kwargs)
   [1521](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1521) try:
   [1522](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1522)     result = None

File [/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:968](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:968), in LlamaModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
    [957](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:957)     layer_outputs = self._gradient_checkpointing_func(
    [958](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:958)         decoder_layer.__call__,
    [959](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:959)         hidden_states,
   (...)
    [965](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:965)         cache_position,
    [966](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:966)     )
    [967](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:967) else:
--> [968](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:968)     layer_outputs = decoder_layer(
    [969](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:969)         hidden_states,
    [970](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:970)         attention_mask=causal_mask,
    [971](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:971)         position_ids=position_ids,
    [972](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:972)         past_key_value=past_key_values,
    [973](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:973)         output_attentions=output_attentions,
    [974](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:974)         use_cache=use_cache,
    [975](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:975)         cache_position=cache_position,
    [976](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:976)     )
    [978](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:978) hidden_states = layer_outputs[0]
    [980](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:980) if use_cache:

File [/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1510](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1510), in Module._wrapped_call_impl(self, *args, **kwargs)
   [1508](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1508)     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   [1509](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1509) else:
-> [1510](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1510)     return self._call_impl(*args, **kwargs)

File [/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1519](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1519), in Module._call_impl(self, *args, **kwargs)
   [1514](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1514) # If we don't have any hooks, we want to skip the rest of the logic in
   [1515](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1515) # this function, and just call forward.
   [1516](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1516) if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   [1517](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1517)         or _global_backward_pre_hooks or _global_backward_hooks
   [1518](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518)         or _global_forward_hooks or _global_forward_pre_hooks):
-> [1519](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1519)     return forward_call(*args, **kwargs)
   [1521](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1521) try:
   [1522](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1522)     result = None

File [/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:166](/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:166), in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    [164](/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:164)         output = module._old_forward(*args, **kwargs)
    [165](/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:165) else:
--> [166](/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:166)     output = module._old_forward(*args, **kwargs)
    [167](/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:167) return module._hf_hook.post_forward(module, output)

File [/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:710](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:710), in LlamaDecoderLayer.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position)
    [694](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:694) """
    [695](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:695) Args:
    [696](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:696)     hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
   (...)
    [706](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:706)     past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
    [707](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:707) """
    [708](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:708) residual = hidden_states
--> [710](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:710) hidden_states = self.input_layernorm(hidden_states)
    [712](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:712) # Self Attention
    [713](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:713) hidden_states, self_attn_weights, present_key_value = self.self_attn(
    [714](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:714)     hidden_states=hidden_states,
    [715](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:715)     attention_mask=attention_mask,
   (...)
    [720](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:720)     cache_position=cache_position,
    [721](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:721) )

File [/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1510](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1510), in Module._wrapped_call_impl(self, *args, **kwargs)
   [1508](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1508)     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   [1509](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1509) else:
-> [1510](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1510)     return self._call_impl(*args, **kwargs)

File [/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1519](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1519), in Module._call_impl(self, *args, **kwargs)
   [1514](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1514) # If we don't have any hooks, we want to skip the rest of the logic in
   [1515](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1515) # this function, and just call forward.
   [1516](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1516) if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   [1517](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1517)         or _global_backward_pre_hooks or _global_backward_hooks
   [1518](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518)         or _global_forward_hooks or _global_forward_pre_hooks):
-> [1519](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1519)     return forward_call(*args, **kwargs)
   [1521](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1521) try:
   [1522](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1522)     result = None

File [/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:166](/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:166), in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    [164](/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:164)         output = module._old_forward(*args, **kwargs)
    [165](/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:165) else:
--> [166](/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:166)     output = module._old_forward(*args, **kwargs)
    [167](/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:167) return module._hf_hook.post_forward(module, output)

File [/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:89](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:89), in LlamaRMSNorm.forward(self, hidden_states)
     [87](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:87) variance = hidden_states.pow(2).mean(-1, keepdim=True)
     [88](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:88) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
---> [89](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:89) return self.weight * hidden_states.to(input_dtype)

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1!

gsarti commented 1 month ago

Hi @saxenarohit, sorry for the delay! I see you are passing the pre-encoded ids directly to generate. Could you try the second example again:


import inseq

model_name = "lmsys/vicuna-7b-v1.1"
attr_type="input_x_gradient"
input_prompt = '''A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: A 34-year-old man was arrested in connection with an outstanding warrant and is expected to appear at Glasgow Sherriff Court on Monday. A 15-year-old male was arrested for offensive behaviour and resisting arrest and a 16-year-old male was arrested for offensive behaviour. Three men were arrested outside the stadium in connection with assault. The men, aged 29, 28 and 27, and all from Glasgow, are expected to appear at Aberdeen Sherriff Court on Monday. Police said the two teenagers will be reported to the relevant authorities. Match Commander Supt Innes Walker said: "The vast majority of fans from both football clubs followed the advice given and conducted themselves appropriately. "The policing operation was assisted by specialist resources including the horses, the dog unit and roads policing and we appreciate the support of the overwhelming majority of fans and members of the public in allowing the Friday night game to be enjoyed and pass safely." Celtic won the match 3-1\nSummarize the provided document. The summary should be extremely short. ASSISTANT:'''
inseq_model = inseq.load_model(model, attr_type)  # reuses the `model` loaded in the first snippet
out = inseq_model.attribute(input_prompt, generation_args={"do_sample": False, "max_new_tokens": 100})

and report the stack trace for that one? Also, could you confirm whether installing from main with pip install git+https://github.com/inseq-team/inseq.git@main solves the issue? Thank you in advance!
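
For reference, here is a minimal sketch (untested on my side) of your first example with the raw prompt passed to generate instead of the pre-encoded ids, so that tokenization and device placement are handled by Inseq internally; input_text is the prompt string from your first snippet:

inseq_model = inseq.load_model(model, "saliency", tokenizer=tokenizer)
# Pass the raw text rather than the tokenizer.encode_plus(...) output, so Inseq
# encodes it and moves it to the expected device itself.
generated = inseq_model.generate(
    input_text,
    do_sample=False,
    max_new_tokens=100,
    skip_special_tokens=True,
)
out = inseq_model.attribute(input_texts=input_text, generated_texts=generated[0])
out.show()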

saxenarohit commented 1 month ago

Hi, thanks! Here is the stack trace for the second example:

RuntimeError                              Traceback (most recent call last)
Cell In[3], [line 16](notebook-cell:?execution_count=3&line=16)
     [14](notebook-cell:?execution_count=3&line=14) input_prompt = '''A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: A 34-year-old man was arrested in connection with an outstanding warrant and is expected to appear at Glasgow Sherriff Court on Monday. A 15-year-old male was arrested for offensive behaviour and resisting arrest and a 16-year-old male was arrested for offensive behaviour. Three men were arrested outside the stadium in connection with assault. The men, aged 29, 28 and 27, and all from Glasgow, are expected to appear at Aberdeen Sherriff Court on Monday. Police said the two teenagers will be reported to the relevant authorities. Match Commander Supt Innes Walker said: "The vast majority of fans from both football clubs followed the advice given and conducted themselves appropriately. "The policing operation was assisted by specialist resources including the horses, the dog unit and roads policing and we appreciate the support of the overwhelming majority of fans and members of the public in allowing the Friday night game to be enjoyed and pass safely." Celtic won the match 3-1\nSummarize the provided document. The summary should be extremely short. ASSISTANT:'''
     [15](notebook-cell:?execution_count=3&line=15) inseq_model = inseq.load_model(model, attr_type)
---> [16](notebook-cell:?execution_count=3&line=16) out = inseq_model.attribute(input_prompt, generation_args={"do_sample": False, "max_new_tokens": 100})

File [/usr/local/lib/python3.10/dist-packages/inseq/models/attribution_model.py:424](/usr/local/lib/python3.10/dist-packages/inseq/models/attribution_model.py:424), in AttributionModel.attribute(self, input_texts, generated_texts, method, override_default_attribution, attr_pos_start, attr_pos_end, show_progress, pretty_progress, output_step_attributions, attribute_target, step_scores, include_eos_baseline, attributed_fn, device, batch_size, generate_from_target_prefix, generation_args, **kwargs)
    [422](/usr/local/lib/python3.10/dist-packages/inseq/models/attribution_model.py:422)         decoder_input = self.encode(generated_texts, as_targets=True)
    [423](/usr/local/lib/python3.10/dist-packages/inseq/models/attribution_model.py:423)         generation_args["decoder_input_ids"] = decoder_input.input_ids
--> [424](/usr/local/lib/python3.10/dist-packages/inseq/models/attribution_model.py:424)     generated_texts = self.generate(
    [425](/usr/local/lib/python3.10/dist-packages/inseq/models/attribution_model.py:425)         encoded_input, return_generation_output=False, batch_size=batch_size, **generation_args
    [426](/usr/local/lib/python3.10/dist-packages/inseq/models/attribution_model.py:426)     )
    [427](/usr/local/lib/python3.10/dist-packages/inseq/models/attribution_model.py:427) elif generation_args:
    [428](/usr/local/lib/python3.10/dist-packages/inseq/models/attribution_model.py:428)     logger.warning(
    [429](/usr/local/lib/python3.10/dist-packages/inseq/models/attribution_model.py:429)         f"Generation arguments {generation_args} are provided, but will be ignored (constrained decoding)."
    [430](/usr/local/lib/python3.10/dist-packages/inseq/models/attribution_model.py:430)     )

File [/usr/local/lib/python3.10/dist-packages/inseq/models/model_decorators.py:12](/usr/local/lib/python3.10/dist-packages/inseq/models/model_decorators.py:12), in unhooked.<locals>.attribution_free_wrapper(self, *args, **kwargs)
     [10](/usr/local/lib/python3.10/dist-packages/inseq/models/model_decorators.py:10)     was_hooked = True
     [11](/usr/local/lib/python3.10/dist-packages/inseq/models/model_decorators.py:11)     self.attribution_method.unhook()
---> [12](/usr/local/lib/python3.10/dist-packages/inseq/models/model_decorators.py:12) out = f(self, *args, **kwargs)
     [13](/usr/local/lib/python3.10/dist-packages/inseq/models/model_decorators.py:13) if was_hooked:
     [14](/usr/local/lib/python3.10/dist-packages/inseq/models/model_decorators.py:14)     self.attribution_method.hook()

File [/usr/local/lib/python3.10/dist-packages/inseq/attr/attribution_decorators.py:72](/usr/local/lib/python3.10/dist-packages/inseq/attr/attribution_decorators.py:72), in batched.<locals>.batched_wrapper(self, batch_size, *args, **kwargs)
     [69](/usr/local/lib/python3.10/dist-packages/inseq/attr/attribution_decorators.py:69)         raise TypeError(f"Unsupported type {type(seq)} for batched attribution computation.")
     [71](/usr/local/lib/python3.10/dist-packages/inseq/attr/attribution_decorators.py:71) if batch_size is None:
---> [72](/usr/local/lib/python3.10/dist-packages/inseq/attr/attribution_decorators.py:72)     out = f(self, *args, **kwargs)
     [73](/usr/local/lib/python3.10/dist-packages/inseq/attr/attribution_decorators.py:73)     return out if isinstance(out, list) else [out]
     [74](/usr/local/lib/python3.10/dist-packages/inseq/attr/attribution_decorators.py:74) batched_args = [get_batched(batch_size, arg) for arg in args]

File [/usr/local/lib/python3.10/dist-packages/inseq/models/huggingface_model.py:221](/usr/local/lib/python3.10/dist-packages/inseq/models/huggingface_model.py:221), in HuggingfaceModel.generate(self, inputs, return_generation_output, skip_special_tokens, output_generated_only, **kwargs)
    [219](/usr/local/lib/python3.10/dist-packages/inseq/models/huggingface_model.py:219)     inputs = self.encode(inputs)
    [220](/usr/local/lib/python3.10/dist-packages/inseq/models/huggingface_model.py:220) inputs = inputs.to(self.device)
--> [221](/usr/local/lib/python3.10/dist-packages/inseq/models/huggingface_model.py:221) generation_out = self.model.generate(
    [222](/usr/local/lib/python3.10/dist-packages/inseq/models/huggingface_model.py:222)     inputs=inputs.input_ids,
    [223](/usr/local/lib/python3.10/dist-packages/inseq/models/huggingface_model.py:223)     return_dict_in_generate=True,
    [224](/usr/local/lib/python3.10/dist-packages/inseq/models/huggingface_model.py:224)     **kwargs,
    [225](/usr/local/lib/python3.10/dist-packages/inseq/models/huggingface_model.py:225) )
    [226](/usr/local/lib/python3.10/dist-packages/inseq/models/huggingface_model.py:226) sequences = generation_out.sequences
    [227](/usr/local/lib/python3.10/dist-packages/inseq/models/huggingface_model.py:227) if output_generated_only and not self.is_encoder_decoder:

File [/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py:115](/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py:115), in context_decorator.<locals>.decorate_context(*args, **kwargs)
    [112](/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py:112) @functools.wraps(func)
    [113](/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py:113) def decorate_context(*args, **kwargs):
    [114](/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py:114)     with ctx_factory():
--> [115](/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py:115)         return func(*args, **kwargs)

File [/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1758](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1758), in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)
   [1750](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1750)     input_ids, model_kwargs = self._expand_inputs_for_generation(
   [1751](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1751)         input_ids=input_ids,
   [1752](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1752)         expand_size=generation_config.num_return_sequences,
   [1753](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1753)         is_encoder_decoder=self.config.is_encoder_decoder,
   [1754](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1754)         **model_kwargs,
   [1755](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1755)     )
   [1757](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1757)     # 13. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
-> [1758](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1758)     result = self._sample(
   [1759](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1759)         input_ids,
   [1760](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1760)         logits_processor=prepared_logits_processor,
   [1761](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1761)         logits_warper=prepared_logits_warper,
   [1762](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1762)         stopping_criteria=prepared_stopping_criteria,
   [1763](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1763)         generation_config=generation_config,
   [1764](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1764)         synced_gpus=synced_gpus,
   [1765](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1765)         streamer=streamer,
   [1766](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1766)         **model_kwargs,
   [1767](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1767)     )
   [1769](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1769) elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
   [1770](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1770)     # 11. prepare logits warper
   [1771](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1771)     prepared_logits_warper = (
   [1772](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1772)         self._get_logits_warper(generation_config) if generation_config.do_sample else None
   [1773](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1773)     )

File [/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:2397](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:2397), in GenerationMixin._sample(self, input_ids, logits_processor, stopping_criteria, generation_config, synced_gpus, streamer, logits_warper, **model_kwargs)
   [2394](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:2394) model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
   [2396](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:2396) # forward pass to get next token
-> [2397](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:2397) outputs = self(
   [2398](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:2398)     **model_inputs,
   [2399](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:2399)     return_dict=True,
   [2400](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:2400)     output_attentions=output_attentions,
   [2401](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:2401)     output_hidden_states=output_hidden_states,
   [2402](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:2402) )
   [2404](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:2404) if synced_gpus and this_peer_finished:
   [2405](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:2405)     continue  # don't waste resources running the code we don't need

File [/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1510](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1510), in Module._wrapped_call_impl(self, *args, **kwargs)
   [1508](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1508)     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   [1509](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1509) else:
-> [1510](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1510)     return self._call_impl(*args, **kwargs)

File [/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1519](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1519), in Module._call_impl(self, *args, **kwargs)
   [1514](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1514) # If we don't have any hooks, we want to skip the rest of the logic in
   [1515](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1515) # this function, and just call forward.
   [1516](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1516) if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   [1517](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1517)         or _global_backward_pre_hooks or _global_backward_hooks
   [1518](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518)         or _global_forward_hooks or _global_forward_pre_hooks):
-> [1519](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1519)     return forward_call(*args, **kwargs)
   [1521](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1521) try:
   [1522](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1522)     result = None

File [/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:166](/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:166), in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    [164](/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:164)         output = module._old_forward(*args, **kwargs)
    [165](/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:165) else:
--> [166](/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:166)     output = module._old_forward(*args, **kwargs)
    [167](/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:167) return module._hf_hook.post_forward(module, output)

File [/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1164](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1164), in LlamaForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
   [1161](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1161) return_dict = return_dict if return_dict is not None else self.config.use_return_dict
   [1163](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1163) # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> [1164](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1164) outputs = self.model(
   [1165](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1165)     input_ids=input_ids,
   [1166](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1166)     attention_mask=attention_mask,
   [1167](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1167)     position_ids=position_ids,
   [1168](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1168)     past_key_values=past_key_values,
   [1169](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1169)     inputs_embeds=inputs_embeds,
   [1170](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1170)     use_cache=use_cache,
   [1171](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1171)     output_attentions=output_attentions,
   [1172](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1172)     output_hidden_states=output_hidden_states,
   [1173](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1173)     return_dict=return_dict,
   [1174](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1174)     cache_position=cache_position,
   [1175](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1175) )
   [1177](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1177) hidden_states = outputs[0]
   [1178](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1178) if self.config.pretraining_tp > 1:

File [/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1510](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1510), in Module._wrapped_call_impl(self, *args, **kwargs)
   [1508](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1508)     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   [1509](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1509) else:
-> [1510](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1510)     return self._call_impl(*args, **kwargs)

File [/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1519](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1519), in Module._call_impl(self, *args, **kwargs)
   [1514](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1514) # If we don't have any hooks, we want to skip the rest of the logic in
   [1515](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1515) # this function, and just call forward.
   [1516](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1516) if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   [1517](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1517)         or _global_backward_pre_hooks or _global_backward_hooks
   [1518](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518)         or _global_forward_hooks or _global_forward_pre_hooks):
-> [1519](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1519)     return forward_call(*args, **kwargs)
   [1521](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1521) try:
   [1522](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1522)     result = None

File [/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:968](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:968), in LlamaModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
    [957](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:957)     layer_outputs = self._gradient_checkpointing_func(
    [958](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:958)         decoder_layer.__call__,
    [959](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:959)         hidden_states,
   (...)
    [965](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:965)         cache_position,
    [966](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:966)     )
    967 else:
--> 968     layer_outputs = decoder_layer(
    969         hidden_states,
    970         attention_mask=causal_mask,
    971         position_ids=position_ids,
    972         past_key_value=past_key_values,
    973         output_attentions=output_attentions,
    974         use_cache=use_cache,
    975         cache_position=cache_position,
    976     )
    978 hidden_states = layer_outputs[0]
    980 if use_cache:

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1510, in Module._wrapped_call_impl(self, *args, **kwargs)
   1508     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1509 else:
-> 1510     return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1519, in Module._call_impl(self, *args, **kwargs)
   1514 # If we don't have any hooks, we want to skip the rest of the logic in
   1515 # this function, and just call forward.
   1516 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1517         or _global_backward_pre_hooks or _global_backward_hooks
   1518         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1519     return forward_call(*args, **kwargs)
   1521 try:
   1522     result = None

File /usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    164         output = module._old_forward(*args, **kwargs)
    165 else:
--> 166     output = module._old_forward(*args, **kwargs)
    167 return module._hf_hook.post_forward(module, output)

File /usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:710, in LlamaDecoderLayer.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position)
    694 """
    695 Args:
    696     hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
    (...)
    706     past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
    707 """
    708 residual = hidden_states
--> 710 hidden_states = self.input_layernorm(hidden_states)
    712 # Self Attention
    713 hidden_states, self_attn_weights, present_key_value = self.self_attn(
    714     hidden_states=hidden_states,
    715     attention_mask=attention_mask,
    (...)
    720     cache_position=cache_position,
    721 )

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1510, in Module._wrapped_call_impl(self, *args, **kwargs)
   1508     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1509 else:
-> 1510     return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1519, in Module._call_impl(self, *args, **kwargs)
   1514 # If we don't have any hooks, we want to skip the rest of the logic in
   1515 # this function, and just call forward.
   1516 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1517         or _global_backward_pre_hooks or _global_backward_hooks
   1518         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1519     return forward_call(*args, **kwargs)
   1521 try:
   1522     result = None

File /usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    164         output = module._old_forward(*args, **kwargs)
    165 else:
--> 166     output = module._old_forward(*args, **kwargs)
    167 return module._hf_hook.post_forward(module, output)

File /usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:89, in LlamaRMSNorm.forward(self, hidden_states)
     87 variance = hidden_states.pow(2).mean(-1, keepdim=True)
     88 hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
---> 89 return self.weight * hidden_states.to(input_dtype)

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1!

Let me check in the meantime whether installing from source helps.

gsarti commented 1 month ago

From the looks of it, I'd say it's an issue with the way in which accelerate is used. Are you able to run model.generate in transformers with the same model without issues (i.e. loading the transformers AutoModelForCausalLM and running generate from it)?
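
For reference, the kind of check I mean is a plain transformers generation pass with the same device_map, without Inseq in the loop. A minimal sketch (the checkpoint name here is only an example; substitute the model that shows the error):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "mistralai/Mistral-7B-Instruct-v0.2"  # example checkpoint, swap in the failing one

# Shard the model across the available GPUs via accelerate
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Inputs go to the device of the first shard; accelerate's hooks move
# activations between shards during the forward pass.
inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)
out = model.generate(**inputs, do_sample=False, max_new_tokens=20)
print(tokenizer.decode(out[0], skip_special_tokens=True))

If this runs cleanly, the dispatch itself is fine and the problem is more likely in how the attribution pass handles the sharded model.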

acDante commented 1 month ago

Hi @gsarti, I have a similar issue with device_map="balanced_low_0". I can confirm that model.generate() works in this setting. Here is my code snippet for replicating the issue:

import torch
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login
from datasets import load_dataset
import inseq

# Load model and test data
model_name = "mistralai/Mistral-7B-Instruct-v0.2"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="balanced_low_0",
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

test_data = load_dataset("xsum", split="test")
doc = test_data[1]['document']
input_prompt = f"Summarise the document below: {doc}"
messages = [{
    "role": "user",
    "content": input_prompt
}]

prompt = tokenizer.apply_chat_template(messages,
                                       tokenize=False,
                                       add_generation_prompt=True)

attr_type = "attention"
inseq_model = inseq.load_model(model, attr_type)
out = inseq_model.attribute(input_prompt, generation_args={"do_sample": False, "max_new_tokens": 100})
out.show()

I got the following error message:

RuntimeError                              Traceback (most recent call last)
Cell In[2], line 34
     32 attr_type="attention"
     33 inseq_model = inseq.load_model(model, attr_type)
---> 34 out = inseq_model.attribute(input_prompt, generation_args={"do_sample": False, "max_new_tokens": 100})
     35 out.show()

File /mnt/ceph_rbd/miniconda3/envs/inseq/lib/python3.10/site-packages/inseq/models/attribution_model.py:438, in AttributionModel.attribute(self, input_texts, generated_texts, method, override_default_attribution, attr_pos_start, attr_pos_end, show_progress, pretty_progress, output_step_attributions, attribute_target, step_scores, include_eos_baseline, attributed_fn, device, batch_size, generate_from_target_prefix, skip_special_tokens, generation_args, **kwargs)
    434         decoder_input = self.encode(
    435             generated_texts, as_targets=True, add_special_tokens=not skip_special_tokens
    436         )
    437         generation_args["decoder_input_ids"] = decoder_input.input_ids
--> 438     generated_texts = self.generate(
    439         encoded_input, return_generation_output=False, batch_size=batch_size, **generation_args
    440     )
    441 elif generation_args:
    442     logger.warning(
    443         f"Generation arguments {generation_args} are provided, but will be ignored (constrained decoding)."
    444     )

File /mnt/ceph_rbd/miniconda3/envs/inseq/lib/python3.10/site-packages/inseq/models/model_decorators.py:12, in unhooked.<locals>.attribution_free_wrapper(self, *args, **kwargs)
     10     was_hooked = True
     11     self.attribution_method.unhook()
---> 12 out = f(self, *args, **kwargs)
     13 if was_hooked:
     14     self.attribution_method.hook()

File /mnt/ceph_rbd/miniconda3/envs/inseq/lib/python3.10/site-packages/inseq/attr/attribution_decorators.py:72, in batched.<locals>.batched_wrapper(self, batch_size, *args, **kwargs)
     69         raise TypeError(f"Unsupported type {type(seq)} for batched attribution computation.")
     71 if batch_size is None:
---> 72     out = f(self, *args, **kwargs)
     73     return out if isinstance(out, list) else [out]
     74 batched_args = [get_batched(batch_size, arg) for arg in args]

File /mnt/ceph_rbd/miniconda3/envs/inseq/lib/python3.10/site-packages/inseq/models/huggingface_model.py:233, in HuggingfaceModel.generate(self, inputs, return_generation_output, skip_special_tokens, output_generated_only, **kwargs)
    231     inputs = self.encode(inputs, add_special_tokens=not skip_special_tokens)
    232 inputs = inputs.to(self.device)
--> 233 generation_out = self.model.generate(
    234     inputs=inputs.input_ids,
    235     return_dict_in_generate=True,
    236     **kwargs,
    237 )
    238 sequences = generation_out.sequences
    239 if output_generated_only and not self.is_encoder_decoder:

File /mnt/ceph_rbd/miniconda3/envs/inseq/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File /mnt/ceph_rbd/miniconda3/envs/inseq/lib/python3.10/site-packages/transformers/generation/utils.py:1758, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)
   1750     input_ids, model_kwargs = self._expand_inputs_for_generation(
   1751         input_ids=input_ids,
   1752         expand_size=generation_config.num_return_sequences,
   1753         is_encoder_decoder=self.config.is_encoder_decoder,
   1754         **model_kwargs,
   1755     )
   1757     # 13. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
-> 1758     result = self._sample(
   1759         input_ids,
   1760         logits_processor=prepared_logits_processor,
   1761         logits_warper=prepared_logits_warper,
   1762         stopping_criteria=prepared_stopping_criteria,
   1763         generation_config=generation_config,
   1764         synced_gpus=synced_gpus,
   1765         streamer=streamer,
   1766         **model_kwargs,
   1767     )
   1769 elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
   1770     # 11. prepare logits warper
   1771     prepared_logits_warper = (
   1772         self._get_logits_warper(generation_config) if generation_config.do_sample else None
   1773     )

File /mnt/ceph_rbd/miniconda3/envs/inseq/lib/python3.10/site-packages/transformers/generation/utils.py:2397, in GenerationMixin._sample(self, input_ids, logits_processor, stopping_criteria, generation_config, synced_gpus, streamer, logits_warper, **model_kwargs)
   2394 model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
   2396 # forward pass to get next token
-> 2397 outputs = self(
   2398     **model_inputs,
   2399     return_dict=True,
   2400     output_attentions=output_attentions,
   2401     output_hidden_states=output_hidden_states,
   2402 )
   2404 if synced_gpus and this_peer_finished:
   2405     continue  # don't waste resources running the code we don't need

File /mnt/ceph_rbd/miniconda3/envs/inseq/lib/python3.10/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File /mnt/ceph_rbd/miniconda3/envs/inseq/lib/python3.10/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File /mnt/ceph_rbd/miniconda3/envs/inseq/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py:1139, in MistralForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)
   1136 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
   1138 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1139 outputs = self.model(
   1140     input_ids=input_ids,
   1141     attention_mask=attention_mask,
   1142     position_ids=position_ids,
   1143     past_key_values=past_key_values,
   1144     inputs_embeds=inputs_embeds,
   1145     use_cache=use_cache,
   1146     output_attentions=output_attentions,
   1147     output_hidden_states=output_hidden_states,
   1148     return_dict=return_dict,
   1149 )
   1151 hidden_states = outputs[0]
   1152 logits = self.lm_head(hidden_states)

File /mnt/ceph_rbd/miniconda3/envs/inseq/lib/python3.10/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File /mnt/ceph_rbd/miniconda3/envs/inseq/lib/python3.10/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File /mnt/ceph_rbd/miniconda3/envs/inseq/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py:968, in MistralModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict)
    965     position_ids = position_ids.view(-1, seq_length).long()
    967 if inputs_embeds is None:
--> 968     inputs_embeds = self.embed_tokens(input_ids)
    970 if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
    971     is_padding_right = attention_mask[:, -1].sum().item() != batch_size

File /mnt/ceph_rbd/miniconda3/envs/inseq/lib/python3.10/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File /mnt/ceph_rbd/miniconda3/envs/inseq/lib/python3.10/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File /mnt/ceph_rbd/miniconda3/envs/inseq/lib/python3.10/site-packages/torch/nn/modules/sparse.py:163, in Embedding.forward(self, input)
    162 def forward(self, input: Tensor) -> Tensor:
--> 163     return F.embedding(
    164         input, self.weight, self.padding_idx, self.max_norm,
    165         self.norm_type, self.scale_grad_by_freq, self.sparse)

File /mnt/ceph_rbd/miniconda3/envs/inseq/lib/python3.10/site-packages/torch/nn/functional.py:2264, in embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
   2258     # Note [embedding_renorm set_grad_enabled]
   2259     # XXX: equivalent to
   2260     # with torch.no_grad():
   2261     #   torch.embedding_renorm_
   2262     # remove once script supports set_grad_enabled
   2263     _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
-> 2264 return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)
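
A quick diagnostic, run in the same session as the snippet above, is to print the dispatch map that accelerate attaches to the model and the device Inseq moves the encoded inputs to (just a sketch; the exact placement depends on your GPUs):

# Where did "balanced_low_0" put each block?
print(model.hf_device_map)   # e.g. {'model.embed_tokens': 1, 'model.layers.0': 1, ...}

# Device Inseq uses when it calls inputs.to(self.device) before generate
# (see the huggingface_model.py frame in the traceback)
print(inseq_model.device)

With "balanced_low_0" the embedding layer usually lands on a GPU other than cuda:0, so input_ids moved to cuda:0 end up indexing an embedding weight on cuda:1, which would match the index_select error above.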
gsarti commented 1 month ago

Agreed for balanced_low_0, but I really doubt this is an issue with Inseq itself. With device_map="auto", the following code works for me on a 2x RTX A4000 16GB setup with pip install git+https://github.com/inseq-team/inseq.git@fix-multidevice installed:

import torch
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login
from datasets import load_dataset
import inseq
import logging

logging.basicConfig(level=logging.DEBUG)

# Load model and test data
model_name = "mistralai/Mistral-7B-Instruct-v0.2"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

test_data = load_dataset("xsum", split="test")
doc = test_data[1]['document']
input_prompt = f"Summarise the document below: {doc}"
messages = [{
    "role": "user",
    "content": input_prompt
}]

prompt = tokenizer.apply_chat_template(messages,
                                       tokenize=False,
                                       add_generation_prompt=True)

attr_type="attention"
inseq_model = inseq.load_model(model, attr_type)
out = inseq_model.attribute(prompt, generation_args={"do_sample": False, "max_new_tokens": 100}, skip_special_tokens=True)