Closed: saxenarohit closed this issue 3 months ago.
Hi @saxenarohit, thanks for reporting this! Could you provide your `inseq`, `torch`, and `transformers` versions? Thanks!
Thanks! Here are the versions:
torch 2.2.0a0+81ea7a4
transformers 4.41.0
inseq 0.6.0
Stack trace:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[1], [line 19](vscode-notebook-cell:?execution_count=1&line=19)
[17](vscode-notebook-cell:?execution_count=1&line=17) inseq_model = inseq.load_model(model, attr_type, tokenizer=tokenizer)
[18](vscode-notebook-cell:?execution_count=1&line=18) input_tokens = tokenizer.encode_plus(input_text,return_tensors="pt").to(model.device)
---> [19](vscode-notebook-cell:?execution_count=1&line=19) inseq_output_text = inseq_model.generate(input_tokens,do_sample=False,max_new_tokens=100,skip_special_tokens=True)
[20](vscode-notebook-cell:?execution_count=1&line=20) out = inseq_model.attribute(input_texts=input_text, generated_texts = inseq_output_text[0])
[21](vscode-notebook-cell:?execution_count=1&line=21) out.show()
File [/usr/local/lib/python3.10/dist-packages/inseq/models/model_decorators.py:12](/usr/local/lib/python3.10/dist-packages/inseq/models/model_decorators.py:12), in unhooked.<locals>.attribution_free_wrapper(self, *args, **kwargs)
[10](/usr/local/lib/python3.10/dist-packages/inseq/models/model_decorators.py:10) was_hooked = True
[11](/usr/local/lib/python3.10/dist-packages/inseq/models/model_decorators.py:11) self.attribution_method.unhook()
---> [12](/usr/local/lib/python3.10/dist-packages/inseq/models/model_decorators.py:12) out = f(self, *args, **kwargs)
[13](/usr/local/lib/python3.10/dist-packages/inseq/models/model_decorators.py:13) if was_hooked:
[14](/usr/local/lib/python3.10/dist-packages/inseq/models/model_decorators.py:14) self.attribution_method.hook()
File [/usr/local/lib/python3.10/dist-packages/inseq/attr/attribution_decorators.py:72](/usr/local/lib/python3.10/dist-packages/inseq/attr/attribution_decorators.py:72), in batched.<locals>.batched_wrapper(self, batch_size, *args, **kwargs)
[69](/usr/local/lib/python3.10/dist-packages/inseq/attr/attribution_decorators.py:69) raise TypeError(f"Unsupported type {type(seq)} for batched attribution computation.")
[71](/usr/local/lib/python3.10/dist-packages/inseq/attr/attribution_decorators.py:71) if batch_size is None:
---> [72](/usr/local/lib/python3.10/dist-packages/inseq/attr/attribution_decorators.py:72) out = f(self, *args, **kwargs)
[73](/usr/local/lib/python3.10/dist-packages/inseq/attr/attribution_decorators.py:73) return out if isinstance(out, list) else [out]
[74](/usr/local/lib/python3.10/dist-packages/inseq/attr/attribution_decorators.py:74) batched_args = [get_batched(batch_size, arg) for arg in args]
File [/usr/local/lib/python3.10/dist-packages/inseq/models/huggingface_model.py:221](/usr/local/lib/python3.10/dist-packages/inseq/models/huggingface_model.py:221), in HuggingfaceModel.generate(self, inputs, return_generation_output, skip_special_tokens, output_generated_only, **kwargs)
[219](/usr/local/lib/python3.10/dist-packages/inseq/models/huggingface_model.py:219) inputs = self.encode(inputs)
[220](/usr/local/lib/python3.10/dist-packages/inseq/models/huggingface_model.py:220) inputs = inputs.to(self.device)
--> [221](/usr/local/lib/python3.10/dist-packages/inseq/models/huggingface_model.py:221) generation_out = self.model.generate(
[222](/usr/local/lib/python3.10/dist-packages/inseq/models/huggingface_model.py:222) inputs=inputs.input_ids,
[223](/usr/local/lib/python3.10/dist-packages/inseq/models/huggingface_model.py:223) return_dict_in_generate=True,
[224](/usr/local/lib/python3.10/dist-packages/inseq/models/huggingface_model.py:224) **kwargs,
[225](/usr/local/lib/python3.10/dist-packages/inseq/models/huggingface_model.py:225) )
[226](/usr/local/lib/python3.10/dist-packages/inseq/models/huggingface_model.py:226) sequences = generation_out.sequences
[227](/usr/local/lib/python3.10/dist-packages/inseq/models/huggingface_model.py:227) if output_generated_only and not self.is_encoder_decoder:
File [/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py:115](/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py:115), in context_decorator.<locals>.decorate_context(*args, **kwargs)
[112](/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py:112) @functools.wraps(func)
[113](/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py:113) def decorate_context(*args, **kwargs):
[114](/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py:114) with ctx_factory():
--> [115](/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py:115) return func(*args, **kwargs)
File [/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1736](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1736), in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)
[1728](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1728) input_ids, model_kwargs = self._expand_inputs_for_generation(
[1729](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1729) input_ids=input_ids,
[1730](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1730) expand_size=generation_config.num_return_sequences,
[1731](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1731) is_encoder_decoder=self.config.is_encoder_decoder,
[1732](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1732) **model_kwargs,
[1733](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1733) )
[1735](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1735) # 13. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
-> [1736](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1736) result = self._sample(
[1737](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1737) input_ids,
[1738](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1738) logits_processor=prepared_logits_processor,
[1739](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1739) logits_warper=prepared_logits_warper,
[1740](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1740) stopping_criteria=prepared_stopping_criteria,
[1741](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1741) generation_config=generation_config,
[1742](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1742) synced_gpus=synced_gpus,
[1743](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1743) streamer=streamer,
[1744](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1744) **model_kwargs,
[1745](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1745) )
[1747](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1747) elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
[1748](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1748) # 11. prepare logits warper
[1749](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1749) prepared_logits_warper = (
[1750](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1750) self._get_logits_warper(generation_config) if generation_config.do_sample else None
[1751](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1751) )
File [/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:2375](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:2375), in GenerationMixin._sample(self, input_ids, logits_processor, stopping_criteria, generation_config, synced_gpus, streamer, logits_warper, **model_kwargs)
[2372](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:2372) model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
[2374](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:2374) # forward pass to get next token
-> [2375](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:2375) outputs = self(
[2376](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:2376) **model_inputs,
[2377](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:2377) return_dict=True,
[2378](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:2378) output_attentions=output_attentions,
[2379](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:2379) output_hidden_states=output_hidden_states,
[2380](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:2380) )
[2382](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:2382) if synced_gpus and this_peer_finished:
[2383](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:2383) continue # don't waste resources running the code we don't need
File [/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1510](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1510), in Module._wrapped_call_impl(self, *args, **kwargs)
[1508](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1508) return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
[1509](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1509) else:
-> [1510](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1510) return self._call_impl(*args, **kwargs)
File [/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1519](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1519), in Module._call_impl(self, *args, **kwargs)
[1514](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1514) # If we don't have any hooks, we want to skip the rest of the logic in
[1515](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1515) # this function, and just call forward.
[1516](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1516) if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
[1517](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1517) or _global_backward_pre_hooks or _global_backward_hooks
[1518](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518) or _global_forward_hooks or _global_forward_pre_hooks):
-> [1519](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1519) return forward_call(*args, **kwargs)
[1521](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1521) try:
[1522](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1522) result = None
File [/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:166](/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:166), in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
[164](/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:164) output = module._old_forward(*args, **kwargs)
[165](/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:165) else:
--> [166](/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:166) output = module._old_forward(*args, **kwargs)
[167](/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:167) return module._hf_hook.post_forward(module, output)
File [/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1164](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1164), in LlamaForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
[1161](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1161) return_dict = return_dict if return_dict is not None else self.config.use_return_dict
[1163](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1163) # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> [1164](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1164) outputs = self.model(
[1165](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1165) input_ids=input_ids,
[1166](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1166) attention_mask=attention_mask,
[1167](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1167) position_ids=position_ids,
[1168](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1168) past_key_values=past_key_values,
[1169](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1169) inputs_embeds=inputs_embeds,
[1170](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1170) use_cache=use_cache,
[1171](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1171) output_attentions=output_attentions,
[1172](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1172) output_hidden_states=output_hidden_states,
[1173](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1173) return_dict=return_dict,
[1174](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1174) cache_position=cache_position,
[1175](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1175) )
[1177](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1177) hidden_states = outputs[0]
[1178](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1178) if self.config.pretraining_tp > 1:
File [/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1510](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1510), in Module._wrapped_call_impl(self, *args, **kwargs)
[1508](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1508) return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
[1509](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1509) else:
-> [1510](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1510) return self._call_impl(*args, **kwargs)
File [/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1519](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1519), in Module._call_impl(self, *args, **kwargs)
[1514](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1514) # If we don't have any hooks, we want to skip the rest of the logic in
[1515](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1515) # this function, and just call forward.
[1516](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1516) if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
[1517](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1517) or _global_backward_pre_hooks or _global_backward_hooks
[1518](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518) or _global_forward_hooks or _global_forward_pre_hooks):
-> [1519](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1519) return forward_call(*args, **kwargs)
[1521](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1521) try:
[1522](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1522) result = None
File [/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:968](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:968), in LlamaModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
[957](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:957) layer_outputs = self._gradient_checkpointing_func(
[958](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:958) decoder_layer.__call__,
[959](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:959) hidden_states,
(...)
[965](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:965) cache_position,
[966](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:966) )
[967](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:967) else:
--> [968](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:968) layer_outputs = decoder_layer(
[969](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:969) hidden_states,
[970](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:970) attention_mask=causal_mask,
[971](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:971) position_ids=position_ids,
[972](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:972) past_key_value=past_key_values,
[973](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:973) output_attentions=output_attentions,
[974](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:974) use_cache=use_cache,
[975](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:975) cache_position=cache_position,
[976](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:976) )
[978](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:978) hidden_states = layer_outputs[0]
[980](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:980) if use_cache:
File [/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1510](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1510), in Module._wrapped_call_impl(self, *args, **kwargs)
[1508](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1508) return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
[1509](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1509) else:
-> [1510](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1510) return self._call_impl(*args, **kwargs)
File [/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1519](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1519), in Module._call_impl(self, *args, **kwargs)
[1514](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1514) # If we don't have any hooks, we want to skip the rest of the logic in
[1515](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1515) # this function, and just call forward.
[1516](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1516) if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
[1517](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1517) or _global_backward_pre_hooks or _global_backward_hooks
[1518](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518) or _global_forward_hooks or _global_forward_pre_hooks):
-> [1519](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1519) return forward_call(*args, **kwargs)
[1521](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1521) try:
[1522](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1522) result = None
File [/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:166](/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:166), in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
[164](/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:164) output = module._old_forward(*args, **kwargs)
[165](/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:165) else:
--> [166](/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:166) output = module._old_forward(*args, **kwargs)
[167](/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:167) return module._hf_hook.post_forward(module, output)
File [/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:710](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:710), in LlamaDecoderLayer.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position)
[694](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:694) """
[695](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:695) Args:
[696](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:696) hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
(...)
[706](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:706) past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
[707](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:707) """
[708](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:708) residual = hidden_states
--> [710](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:710) hidden_states = self.input_layernorm(hidden_states)
[712](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:712) # Self Attention
[713](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:713) hidden_states, self_attn_weights, present_key_value = self.self_attn(
[714](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:714) hidden_states=hidden_states,
[715](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:715) attention_mask=attention_mask,
(...)
[720](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:720) cache_position=cache_position,
[721](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:721) )
File [/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1510](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1510), in Module._wrapped_call_impl(self, *args, **kwargs)
[1508](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1508) return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
[1509](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1509) else:
-> [1510](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1510) return self._call_impl(*args, **kwargs)
File [/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1519](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1519), in Module._call_impl(self, *args, **kwargs)
[1514](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1514) # If we don't have any hooks, we want to skip the rest of the logic in
[1515](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1515) # this function, and just call forward.
[1516](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1516) if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
[1517](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1517) or _global_backward_pre_hooks or _global_backward_hooks
[1518](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518) or _global_forward_hooks or _global_forward_pre_hooks):
-> [1519](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1519) return forward_call(*args, **kwargs)
[1521](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1521) try:
[1522](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1522) result = None
File [/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:166](/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:166), in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
[164](/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:164) output = module._old_forward(*args, **kwargs)
[165](/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:165) else:
--> [166](/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:166) output = module._old_forward(*args, **kwargs)
[167](/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:167) return module._hf_hook.post_forward(module, output)
File [/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:89](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:89), in LlamaRMSNorm.forward(self, hidden_states)
[87](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:87) variance = hidden_states.pow(2).mean(-1, keepdim=True)
[88](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:88) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
---> [89](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:89) return self.weight * hidden_states.to(input_dtype)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1!
Hi @saxenarohit, sorry for the delay! I see you are passing the encoded ids directly to `generate`. Could you try the second example again:
import inseq
model_name = "lmsys/vicuna-7b-v1.1"
attr_type="input_x_gradient"
input_prompt = '''A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: A 34-year-old man was arrested in connection with an outstanding warrant and is expected to appear at Glasgow Sherriff Court on Monday. A 15-year-old male was arrested for offensive behaviour and resisting arrest and a 16-year-old male was arrested for offensive behaviour. Three men were arrested outside the stadium in connection with assault. The men, aged 29, 28 and 27, and all from Glasgow, are expected to appear at Aberdeen Sherriff Court on Monday. Police said the two teenagers will be reported to the relevant authorities. Match Commander Supt Innes Walker said: "The vast majority of fans from both football clubs followed the advice given and conducted themselves appropriately. "The policing operation was assisted by specialist resources including the horses, the dog unit and roads policing and we appreciate the support of the overwhelming majority of fans and members of the public in allowing the Friday night game to be enjoyed and pass safely." Celtic won the match 3-1\nSummarize the provided document. The summary should be extremely short. ASSISTANT:'''
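# `model` is assumed to be the AutoModelForCausalLM you already loaded with your device_map (as in your original snippet)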
inseq_model = inseq.load_model(model, attr_type)
out = inseq_model.attribute(input_prompt, generation_args={"do_sample": False, "max_new_tokens": 100})
and report the stack trace for that one? Also, could you check whether installing from main with `pip install git+https://github.com/inseq-team/inseq.git@main` solves the issue? Thank you in advance!
Hi, thanks. Here is the stack trace for the second example:
RuntimeError Traceback (most recent call last)
Cell In[3], [line 16](notebook-cell:?execution_count=3&line=16)
[14](notebook-cell:?execution_count=3&line=14) input_prompt = '''A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: A 34-year-old man was arrested in connection with an outstanding warrant and is expected to appear at Glasgow Sherriff Court on Monday. A 15-year-old male was arrested for offensive behaviour and resisting arrest and a 16-year-old male was arrested for offensive behaviour. Three men were arrested outside the stadium in connection with assault. The men, aged 29, 28 and 27, and all from Glasgow, are expected to appear at Aberdeen Sherriff Court on Monday. Police said the two teenagers will be reported to the relevant authorities. Match Commander Supt Innes Walker said: "The vast majority of fans from both football clubs followed the advice given and conducted themselves appropriately. "The policing operation was assisted by specialist resources including the horses, the dog unit and roads policing and we appreciate the support of the overwhelming majority of fans and members of the public in allowing the Friday night game to be enjoyed and pass safely." Celtic won the match 3-1\nSummarize the provided document. The summary should be extremely short. ASSISTANT:'''
[15](notebook-cell:?execution_count=3&line=15) inseq_model = inseq.load_model(model, attr_type)
---> [16](notebook-cell:?execution_count=3&line=16) out = inseq_model.attribute(input_prompt, generation_args={"do_sample": False, "max_new_tokens": 100})
File [/usr/local/lib/python3.10/dist-packages/inseq/models/attribution_model.py:424](/usr/local/lib/python3.10/dist-packages/inseq/models/attribution_model.py:424), in AttributionModel.attribute(self, input_texts, generated_texts, method, override_default_attribution, attr_pos_start, attr_pos_end, show_progress, pretty_progress, output_step_attributions, attribute_target, step_scores, include_eos_baseline, attributed_fn, device, batch_size, generate_from_target_prefix, generation_args, **kwargs)
[422](/usr/local/lib/python3.10/dist-packages/inseq/models/attribution_model.py:422) decoder_input = self.encode(generated_texts, as_targets=True)
[423](/usr/local/lib/python3.10/dist-packages/inseq/models/attribution_model.py:423) generation_args["decoder_input_ids"] = decoder_input.input_ids
--> [424](/usr/local/lib/python3.10/dist-packages/inseq/models/attribution_model.py:424) generated_texts = self.generate(
[425](/usr/local/lib/python3.10/dist-packages/inseq/models/attribution_model.py:425) encoded_input, return_generation_output=False, batch_size=batch_size, **generation_args
[426](/usr/local/lib/python3.10/dist-packages/inseq/models/attribution_model.py:426) )
[427](/usr/local/lib/python3.10/dist-packages/inseq/models/attribution_model.py:427) elif generation_args:
[428](/usr/local/lib/python3.10/dist-packages/inseq/models/attribution_model.py:428) logger.warning(
[429](/usr/local/lib/python3.10/dist-packages/inseq/models/attribution_model.py:429) f"Generation arguments {generation_args} are provided, but will be ignored (constrained decoding)."
[430](/usr/local/lib/python3.10/dist-packages/inseq/models/attribution_model.py:430) )
File [/usr/local/lib/python3.10/dist-packages/inseq/models/model_decorators.py:12](/usr/local/lib/python3.10/dist-packages/inseq/models/model_decorators.py:12), in unhooked.<locals>.attribution_free_wrapper(self, *args, **kwargs)
[10](/usr/local/lib/python3.10/dist-packages/inseq/models/model_decorators.py:10) was_hooked = True
[11](/usr/local/lib/python3.10/dist-packages/inseq/models/model_decorators.py:11) self.attribution_method.unhook()
---> [12](/usr/local/lib/python3.10/dist-packages/inseq/models/model_decorators.py:12) out = f(self, *args, **kwargs)
[13](/usr/local/lib/python3.10/dist-packages/inseq/models/model_decorators.py:13) if was_hooked:
[14](/usr/local/lib/python3.10/dist-packages/inseq/models/model_decorators.py:14) self.attribution_method.hook()
File [/usr/local/lib/python3.10/dist-packages/inseq/attr/attribution_decorators.py:72](/usr/local/lib/python3.10/dist-packages/inseq/attr/attribution_decorators.py:72), in batched.<locals>.batched_wrapper(self, batch_size, *args, **kwargs)
[69](/usr/local/lib/python3.10/dist-packages/inseq/attr/attribution_decorators.py:69) raise TypeError(f"Unsupported type {type(seq)} for batched attribution computation.")
[71](/usr/local/lib/python3.10/dist-packages/inseq/attr/attribution_decorators.py:71) if batch_size is None:
---> [72](/usr/local/lib/python3.10/dist-packages/inseq/attr/attribution_decorators.py:72) out = f(self, *args, **kwargs)
[73](/usr/local/lib/python3.10/dist-packages/inseq/attr/attribution_decorators.py:73) return out if isinstance(out, list) else [out]
[74](/usr/local/lib/python3.10/dist-packages/inseq/attr/attribution_decorators.py:74) batched_args = [get_batched(batch_size, arg) for arg in args]
File [/usr/local/lib/python3.10/dist-packages/inseq/models/huggingface_model.py:221](/usr/local/lib/python3.10/dist-packages/inseq/models/huggingface_model.py:221), in HuggingfaceModel.generate(self, inputs, return_generation_output, skip_special_tokens, output_generated_only, **kwargs)
[219](/usr/local/lib/python3.10/dist-packages/inseq/models/huggingface_model.py:219) inputs = self.encode(inputs)
[220](/usr/local/lib/python3.10/dist-packages/inseq/models/huggingface_model.py:220) inputs = inputs.to(self.device)
--> [221](/usr/local/lib/python3.10/dist-packages/inseq/models/huggingface_model.py:221) generation_out = self.model.generate(
[222](/usr/local/lib/python3.10/dist-packages/inseq/models/huggingface_model.py:222) inputs=inputs.input_ids,
[223](/usr/local/lib/python3.10/dist-packages/inseq/models/huggingface_model.py:223) return_dict_in_generate=True,
[224](/usr/local/lib/python3.10/dist-packages/inseq/models/huggingface_model.py:224) **kwargs,
[225](/usr/local/lib/python3.10/dist-packages/inseq/models/huggingface_model.py:225) )
[226](/usr/local/lib/python3.10/dist-packages/inseq/models/huggingface_model.py:226) sequences = generation_out.sequences
[227](/usr/local/lib/python3.10/dist-packages/inseq/models/huggingface_model.py:227) if output_generated_only and not self.is_encoder_decoder:
File [/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py:115](/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py:115), in context_decorator.<locals>.decorate_context(*args, **kwargs)
[112](/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py:112) @functools.wraps(func)
[113](/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py:113) def decorate_context(*args, **kwargs):
[114](/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py:114) with ctx_factory():
--> [115](/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py:115) return func(*args, **kwargs)
File [/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1758](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1758), in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)
[1750](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1750) input_ids, model_kwargs = self._expand_inputs_for_generation(
[1751](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1751) input_ids=input_ids,
[1752](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1752) expand_size=generation_config.num_return_sequences,
[1753](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1753) is_encoder_decoder=self.config.is_encoder_decoder,
[1754](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1754) **model_kwargs,
[1755](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1755) )
[1757](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1757) # 13. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
-> [1758](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1758) result = self._sample(
[1759](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1759) input_ids,
[1760](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1760) logits_processor=prepared_logits_processor,
[1761](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1761) logits_warper=prepared_logits_warper,
[1762](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1762) stopping_criteria=prepared_stopping_criteria,
[1763](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1763) generation_config=generation_config,
[1764](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1764) synced_gpus=synced_gpus,
[1765](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1765) streamer=streamer,
[1766](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1766) **model_kwargs,
[1767](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1767) )
[1769](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1769) elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
[1770](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1770) # 11. prepare logits warper
[1771](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1771) prepared_logits_warper = (
[1772](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1772) self._get_logits_warper(generation_config) if generation_config.do_sample else None
[1773](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1773) )
File [/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:2397](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:2397), in GenerationMixin._sample(self, input_ids, logits_processor, stopping_criteria, generation_config, synced_gpus, streamer, logits_warper, **model_kwargs)
[2394](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:2394) model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
[2396](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:2396) # forward pass to get next token
-> [2397](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:2397) outputs = self(
[2398](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:2398) **model_inputs,
[2399](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:2399) return_dict=True,
[2400](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:2400) output_attentions=output_attentions,
[2401](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:2401) output_hidden_states=output_hidden_states,
[2402](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:2402) )
[2404](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:2404) if synced_gpus and this_peer_finished:
[2405](/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:2405) continue # don't waste resources running the code we don't need
File [/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1510](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1510), in Module._wrapped_call_impl(self, *args, **kwargs)
[1508](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1508) return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
[1509](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1509) else:
-> [1510](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1510) return self._call_impl(*args, **kwargs)
File [/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1519](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1519), in Module._call_impl(self, *args, **kwargs)
[1514](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1514) # If we don't have any hooks, we want to skip the rest of the logic in
[1515](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1515) # this function, and just call forward.
[1516](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1516) if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
[1517](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1517) or _global_backward_pre_hooks or _global_backward_hooks
[1518](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518) or _global_forward_hooks or _global_forward_pre_hooks):
-> [1519](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1519) return forward_call(*args, **kwargs)
[1521](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1521) try:
[1522](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1522) result = None
File [/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:166](/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:166), in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
[164](/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:164) output = module._old_forward(*args, **kwargs)
[165](/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:165) else:
--> [166](/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:166) output = module._old_forward(*args, **kwargs)
[167](/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:167) return module._hf_hook.post_forward(module, output)
File [/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1164](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1164), in LlamaForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
[1161](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1161) return_dict = return_dict if return_dict is not None else self.config.use_return_dict
[1163](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1163) # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> [1164](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1164) outputs = self.model(
[1165](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1165) input_ids=input_ids,
[1166](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1166) attention_mask=attention_mask,
[1167](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1167) position_ids=position_ids,
[1168](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1168) past_key_values=past_key_values,
[1169](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1169) inputs_embeds=inputs_embeds,
[1170](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1170) use_cache=use_cache,
[1171](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1171) output_attentions=output_attentions,
[1172](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1172) output_hidden_states=output_hidden_states,
[1173](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1173) return_dict=return_dict,
[1174](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1174) cache_position=cache_position,
[1175](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1175) )
[1177](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1177) hidden_states = outputs[0]
[1178](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1178) if self.config.pretraining_tp > 1:
File [/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1510](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1510), in Module._wrapped_call_impl(self, *args, **kwargs)
[1508](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1508) return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
[1509](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1509) else:
-> [1510](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1510) return self._call_impl(*args, **kwargs)
File [/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1519](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1519), in Module._call_impl(self, *args, **kwargs)
[1514](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1514) # If we don't have any hooks, we want to skip the rest of the logic in
[1515](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1515) # this function, and just call forward.
[1516](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1516) if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
[1517](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1517) or _global_backward_pre_hooks or _global_backward_hooks
[1518](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518) or _global_forward_hooks or _global_forward_pre_hooks):
-> [1519](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1519) return forward_call(*args, **kwargs)
[1521](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1521) try:
[1522](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1522) result = None
File [/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:968](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:968), in LlamaModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
[957](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:957) layer_outputs = self._gradient_checkpointing_func(
[958](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:958) decoder_layer.__call__,
[959](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:959) hidden_states,
(...)
[965](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:965) cache_position,
[966](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:966) )
[967](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:967) else:
--> [968](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:968) layer_outputs = decoder_layer(
[969](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:969) hidden_states,
[970](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:970) attention_mask=causal_mask,
[971](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:971) position_ids=position_ids,
[972](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:972) past_key_value=past_key_values,
[973](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:973) output_attentions=output_attentions,
[974](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:974) use_cache=use_cache,
[975](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:975) cache_position=cache_position,
[976](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:976) )
[978](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:978) hidden_states = layer_outputs[0]
[980](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:980) if use_cache:
File [/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1510](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1510), in Module._wrapped_call_impl(self, *args, **kwargs)
[1508](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1508) return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
[1509](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1509) else:
-> [1510](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1510) return self._call_impl(*args, **kwargs)
File [/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1519](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1519), in Module._call_impl(self, *args, **kwargs)
[1514](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1514) # If we don't have any hooks, we want to skip the rest of the logic in
[1515](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1515) # this function, and just call forward.
[1516](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1516) if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
[1517](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1517) or _global_backward_pre_hooks or _global_backward_hooks
[1518](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518) or _global_forward_hooks or _global_forward_pre_hooks):
-> [1519](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1519) return forward_call(*args, **kwargs)
[1521](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1521) try:
[1522](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1522) result = None
File [/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:166](/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:166), in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
[164](/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:164) output = module._old_forward(*args, **kwargs)
[165](/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:165) else:
--> [166](/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:166) output = module._old_forward(*args, **kwargs)
[167](/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:167) return module._hf_hook.post_forward(module, output)
File [/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:710](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:710), in LlamaDecoderLayer.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position)
[694](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:694) """
[695](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:695) Args:
[696](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:696) hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
(...)
[706](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:706) past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
[707](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:707) """
[708](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:708) residual = hidden_states
--> [710](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:710) hidden_states = self.input_layernorm(hidden_states)
[712](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:712) # Self Attention
[713](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:713) hidden_states, self_attn_weights, present_key_value = self.self_attn(
[714](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:714) hidden_states=hidden_states,
[715](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:715) attention_mask=attention_mask,
(...)
[720](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:720) cache_position=cache_position,
[721](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:721) )
File [/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1510](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1510), in Module._wrapped_call_impl(self, *args, **kwargs)
[1508](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1508) return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
[1509](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1509) else:
-> [1510](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1510) return self._call_impl(*args, **kwargs)
File [/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1519](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1519), in Module._call_impl(self, *args, **kwargs)
[1514](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1514) # If we don't have any hooks, we want to skip the rest of the logic in
[1515](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1515) # this function, and just call forward.
[1516](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1516) if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
[1517](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1517) or _global_backward_pre_hooks or _global_backward_hooks
[1518](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518) or _global_forward_hooks or _global_forward_pre_hooks):
-> [1519](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1519) return forward_call(*args, **kwargs)
[1521](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1521) try:
[1522](/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1522) result = None
File [/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:166](/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:166), in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
[164](/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:164) output = module._old_forward(*args, **kwargs)
[165](/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:165) else:
--> [166](/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:166) output = module._old_forward(*args, **kwargs)
[167](/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:167) return module._hf_hook.post_forward(module, output)
File [/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:89](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:89), in LlamaRMSNorm.forward(self, hidden_states)
[87](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:87) variance = hidden_states.pow(2).mean(-1, keepdim=True)
[88](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:88) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
---> [89](/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:89) return self.weight * hidden_states.to(input_dtype)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1!
In the meantime, let me check installing from source.
From the looks of it, I'd say it's an issue with the way in which `accelerate` is used. Are you able to run `model.generate` in `transformers` with the same model without issues (i.e. loading the `transformers` `AutoModelForCausalLM` and running `generate` from it)?
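For reference, a minimal sketch of such a check is below (assuming the same lmsys/vicuna-7b-v1.1 checkpoint; the device_map value, dtype, and prompt are placeholders, not the exact settings from the report):
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "lmsys/vicuna-7b-v1.1"  # same checkpoint as in the report
tokenizer = AutoTokenizer.from_pretrained(model_name)
# device_map="auto" and float16 are placeholders for whatever placement/dtype
# was used originally; accelerate shards the layers across the available GPUs.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto",
)

# Plain transformers generation with inseq out of the loop: if this also raises
# the cuda:0/cuda:1 mismatch, the problem lies in the accelerate setup rather
# than in inseq's attribution hooks.
inputs = tokenizer("Summarize: Celtic won the match 3-1.", return_tensors="pt").to(model.device)
out = model.generate(**inputs, do_sample=False, max_new_tokens=20)
print(tokenizer.decode(out[0], skip_special_tokens=True))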
Hi @gsarti, I have a similar issue with `device_map="balanced_low_0"`. I can confirm that `model.generate()` works in this setting. Here is my code snippet for replicating the issue:
import torch
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login
from datasets import load_dataset
import inseq
# Load model and test data
model_name = "mistralai/Mistral-7B-Instruct-v0.2"
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map="balanced_low_0",
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
test_data = load_dataset("xsum", split="test")
doc = test_data[1]['document']
input_prompt = f"Summarise the document below: {doc}"
messages = [{
"role": "user",
"content": input_prompt
}]
prompt = tokenizer.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)
attr_type = "attention"
inseq_model = inseq.load_model(model, attr_type)
out = inseq_model.attribute(input_prompt, generation_args={"do_sample": False, "max_new_tokens": 100})
out.show()
I got the following error message:
RuntimeError                              Traceback (most recent call last)
Cell In[2], line 34
     32 attr_type = "attention"
     33 inseq_model = inseq.load_model(model, attr_type)
---> 34 out = inseq_model.attribute(input_prompt, generation_args={"do_sample": False, "max_new_tokens": 100})
     35 out.show()
File /mnt/ceph_rbd/miniconda3/envs/inseq/lib/python3.10/site-packages/inseq/models/attribution_model.py:438, in AttributionModel.attribute(self, input_texts, generated_texts, method, override_default_attribution, attr_pos_start, attr_pos_end, show_progress, pretty_progress, output_step_attributions, attribute_target, step_scores, include_eos_baseline, attributed_fn, device, batch_size, generate_from_target_prefix, skip_special_tokens, generation_args, **kwargs)
---> 438 generated_texts = self.generate(encoded_input, return_generation_output=False, batch_size=batch_size, **generation_args)
File /mnt/ceph_rbd/miniconda3/envs/inseq/lib/python3.10/site-packages/inseq/models/model_decorators.py:12, in unhooked.<locals>.attribution_free_wrapper(self, *args, **kwargs)
---> 12 out = f(self, *args, **kwargs)
File /mnt/ceph_rbd/miniconda3/envs/inseq/lib/python3.10/site-packages/inseq/attr/attribution_decorators.py:72, in batched.<locals>.batched_wrapper(self, batch_size, *args, **kwargs)
---> 72 out = f(self, *args, **kwargs)
File /mnt/ceph_rbd/miniconda3/envs/inseq/lib/python3.10/site-packages/inseq/models/huggingface_model.py:233, in HuggingfaceModel.generate(self, inputs, return_generation_output, skip_special_tokens, output_generated_only, **kwargs)
---> 233 generation_out = self.model.generate(inputs=inputs.input_ids, return_dict_in_generate=True, **kwargs)
File /mnt/ceph_rbd/miniconda3/envs/inseq/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
---> 115 return func(*args, **kwargs)
File /mnt/ceph_rbd/miniconda3/envs/inseq/lib/python3.10/site-packages/transformers/generation/utils.py:1758, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)
---> 1758 result = self._sample(input_ids, logits_processor=prepared_logits_processor, logits_warper=prepared_logits_warper, stopping_criteria=prepared_stopping_criteria, generation_config=generation_config, synced_gpus=synced_gpus, streamer=streamer, **model_kwargs)
File /mnt/ceph_rbd/miniconda3/envs/inseq/lib/python3.10/site-packages/transformers/generation/utils.py:2397, in GenerationMixin._sample(self, input_ids, logits_processor, stopping_criteria, generation_config, synced_gpus, streamer, logits_warper, **model_kwargs)
---> 2397 outputs = self(**model_inputs, return_dict=True, output_attentions=output_attentions, output_hidden_states=output_hidden_states)
File /mnt/ceph_rbd/miniconda3/envs/inseq/lib/python3.10/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
---> 1532 return self._call_impl(*args, **kwargs)
File /mnt/ceph_rbd/miniconda3/envs/inseq/lib/python3.10/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
---> 1541 return forward_call(*args, **kwargs)
File /mnt/ceph_rbd/miniconda3/envs/inseq/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py:1139, in MistralForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)
---> 1139 outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict)
File /mnt/ceph_rbd/miniconda3/envs/inseq/lib/python3.10/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
---> 1532 return self._call_impl(*args, **kwargs)
File /mnt/ceph_rbd/miniconda3/envs/inseq/lib/python3.10/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
---> 1541 return forward_call(*args, **kwargs)
File /mnt/ceph_rbd/miniconda3/envs/inseq/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py:968, in MistralModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict)
---> 968 inputs_embeds = self.embed_tokens(input_ids)
File /mnt/ceph_rbd/miniconda3/envs/inseq/lib/python3.10/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
---> 1532 return self._call_impl(*args, **kwargs)
File /mnt/ceph_rbd/miniconda3/envs/inseq/lib/python3.10/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
---> 1541 return forward_call(*args, **kwargs)
File /mnt/ceph_rbd/miniconda3/envs/inseq/lib/python3.10/site-packages/torch/nn/modules/sparse.py:163, in Embedding.forward(self, input)
---> 163 return F.embedding(input, self.weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse)
File /mnt/ceph_rbd/miniconda3/envs/inseq/lib/python3.10/site-packages/torch/nn/functional.py:2264, in embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
---> 2264 return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)
Agreed for balanced_low_0, but I really doubt that it is an issue with Inseq. For device_map=auto, the following code works for me on a 2x RTX A4000 16GB setup with pip install git+https://github.com/inseq-team/inseq.git@fix-multidevice installed:
import torch
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login
from datasets import load_dataset
import inseq
import logging
logging.basicConfig(level=logging.DEBUG)
# Load model and test data
model_name = "mistralai/Mistral-7B-Instruct-v0.2"
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
test_data = load_dataset("xsum", split="test")
doc = test_data[1]['document']
input_prompt = f"Summarise the document below: {doc}"
messages = [{
"role": "user",
"content": input_prompt
}]
prompt = tokenizer.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)
attr_type="attention"
inseq_model = inseq.load_model(model, attr_type)
out = inseq_model.attribute(prompt, generation_args={"do_sample": False, "max_new_tokens": 100}, skip_special_tokens=True)
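As a small optional follow-up (not part of the original snippet), one can display the attribution and print the module-to-device assignment that accelerate records at load time, which helps compare the auto and balanced_low_0 placements:
# Show the attribution result from the snippet above
out.show()
# hf_device_map is populated when the model is loaded with a device_map and
# lists which GPU each submodule was dispatched to
print(model.hf_device_map)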
Hi,
It seems this bug is still there.
Here's the testing code
Error
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1!
I also tried another approach and ran into the same error.
Please let me know how to fix this. Thanks.