Description

Gemma 2b fails to work with nnsight. If you run the following script:

from nnsight import LanguageModel
import torch

model = LanguageModel("google/gemma-2b", device_map="cuda:1", torch_dtype=torch.bfloat16)

with model.trace("Hello World"):
    outputs = model.output.save()

You get the following error:
Full Stacktrace:
`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[1], line 5
2 import torch
4 model = LanguageModel("google/gemma-2b", device_map="cuda:1", torch_dtype=torch.bfloat16)
----> 5 with model.trace("Hello World"):
6 outputs = model.output.save()
File .../anaconda3/lib/python3.12/site-packages/nnsight/models/NNsightModel.py:194, in NNsight.trace(self, trace, invoker_args, scan, *inputs, **kwargs)
191 return output.value
193 # Otherwise open an invoker context with the give args.
--> 194 runner.invoke(*inputs, **invoker_args).__enter__()
196 # If trace is False, you had to have provided an input.
197 if not trace:
File .../anaconda3/lib/python3.12/site-packages/nnsight/contexts/Invoker.py:69, in Invoker.__enter__(self)
64 with FakeTensorMode(
65 allow_non_fake_inputs=True,
66 shape_env=ShapeEnv(assume_static_by_default=True),
67 ) as fake_mode:
68 with FakeCopyMode(fake_mode):
---> 69 self.tracer._model._execute(
70 *copy.deepcopy(self.inputs),
71 **copy.deepcopy(self.tracer._kwargs),
72 )
74 self.scanning = False
76 else:
File .../anaconda3/lib/python3.12/site-packages/nnsight/models/mixins/Generation.py:21, in GenerationMixin._execute(self, prepared_inputs, generate, *args, **kwargs)
17 if generate:
19 return self._execute_generate(prepared_inputs, *args, **kwargs)
---> 21 return self._execute_forward(prepared_inputs, *args, **kwargs)
File .../anaconda3/lib/python3.12/site-packages/nnsight/models/LanguageModel.py:306, in LanguageModel._execute_forward(self, prepared_inputs, *args, **kwargs)
302 def _execute_forward(self, prepared_inputs: Any, *args, **kwargs):
304 device = next(self._model.parameters()).device
--> 306 return self._model(
307 *args,
308 **prepared_inputs.to(device),
309 **kwargs,
310 )
File .../anaconda3/lib/python3.12/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
1551 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1552 else:
-> 1553 return self._call_impl(*args, **kwargs)
File .../anaconda3/lib/python3.12/site-packages/torch/nn/modules/module.py:1603, in Module._call_impl(self, *args, **kwargs)
1600 bw_hook = hooks.BackwardHook(self, full_backward_hooks, backward_pre_hooks)
1601 args = bw_hook.setup_input_hook(args)
-> 1603 result = forward_call(*args, **kwargs)
1604 if _global_forward_hooks or self._forward_hooks:
1605 for hook_id, hook in (
1606 *_global_forward_hooks.items(),
1607 *self._forward_hooks.items(),
1608 ):
1609 # mark that always called hook is run
File .../anaconda3/lib/python3.12/site-packages/transformers/models/gemma/modeling_gemma.py:1127, in GemmaForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
1124 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1126 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1127 outputs = self.model(
1128 input_ids=input_ids,
1129 attention_mask=attention_mask,
1130 position_ids=position_ids,
1131 past_key_values=past_key_values,
1132 inputs_embeds=inputs_embeds,
1133 use_cache=use_cache,
1134 output_attentions=output_attentions,
1135 output_hidden_states=output_hidden_states,
1136 return_dict=return_dict,
1137 cache_position=cache_position,
1138 )
1140 hidden_states = outputs[0]
1141 logits = self.lm_head(hidden_states)
File .../anaconda3/lib/python3.12/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
1551 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1552 else:
-> 1553 return self._call_impl(*args, **kwargs)
File .../anaconda3/lib/python3.12/site-packages/torch/nn/modules/module.py:1603, in Module._call_impl(self, *args, **kwargs)
1600 bw_hook = hooks.BackwardHook(self, full_backward_hooks, backward_pre_hooks)
1601 args = bw_hook.setup_input_hook(args)
-> 1603 result = forward_call(*args, **kwargs)
1604 if _global_forward_hooks or self._forward_hooks:
1605 for hook_id, hook in (
1606 *_global_forward_hooks.items(),
1607 *self._forward_hooks.items(),
1608 ):
1609 # mark that always called hook is run
File .../anaconda3/lib/python3.12/site-packages/transformers/models/gemma/modeling_gemma.py:931, in GemmaModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
920 layer_outputs = self._gradient_checkpointing_func(
921 decoder_layer.__call__,
922 hidden_states,
(...)
928 cache_position,
929 )
930 else:
--> 931 layer_outputs = decoder_layer(
932 hidden_states,
933 attention_mask=causal_mask,
934 position_ids=position_ids,
935 past_key_value=past_key_values,
936 output_attentions=output_attentions,
937 use_cache=use_cache,
938 cache_position=cache_position,
939 )
941 hidden_states = layer_outputs[0]
943 if use_cache:
File .../anaconda3/lib/python3.12/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
1551 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1552 else:
-> 1553 return self._call_impl(*args, **kwargs)
File .../anaconda3/lib/python3.12/site-packages/torch/nn/modules/module.py:1603, in Module._call_impl(self, *args, **kwargs)
1600 bw_hook = hooks.BackwardHook(self, full_backward_hooks, backward_pre_hooks)
1601 args = bw_hook.setup_input_hook(args)
-> 1603 result = forward_call(*args, **kwargs)
1604 if _global_forward_hooks or self._forward_hooks:
1605 for hook_id, hook in (
1606 *_global_forward_hooks.items(),
1607 *self._forward_hooks.items(),
1608 ):
1609 # mark that always called hook is run
File .../anaconda3/lib/python3.12/site-packages/transformers/models/gemma/modeling_gemma.py:658, in GemmaDecoderLayer.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, **kwargs)
655 hidden_states = self.input_layernorm(hidden_states)
657 # Self Attention
--> 658 hidden_states, self_attn_weights, present_key_value = self.self_attn(
659 hidden_states=hidden_states,
660 attention_mask=attention_mask,
661 position_ids=position_ids,
662 past_key_value=past_key_value,
663 output_attentions=output_attentions,
664 use_cache=use_cache,
665 cache_position=cache_position,
666 )
667 hidden_states = residual + hidden_states
669 # Fully Connected
File .../anaconda3/lib/python3.12/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
1551 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1552 else:
-> 1553 return self._call_impl(*args, **kwargs)
File .../anaconda3/lib/python3.12/site-packages/torch/nn/modules/module.py:1603, in Module._call_impl(self, *args, **kwargs)
1600 bw_hook = hooks.BackwardHook(self, full_backward_hooks, backward_pre_hooks)
1601 args = bw_hook.setup_input_hook(args)
-> 1603 result = forward_call(*args, **kwargs)
1604 if _global_forward_hooks or self._forward_hooks:
1605 for hook_id, hook in (
1606 *_global_forward_hooks.items(),
1607 *self._forward_hooks.items(),
1608 ):
1609 # mark that always called hook is run
File .../anaconda3/lib/python3.12/site-packages/transformers/models/gemma/modeling_gemma.py:562, in GemmaSdpaAttention.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position)
559 key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
560 value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
--> 562 cos, sin = self.rotary_emb(value_states, position_ids)
563 query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
565 if past_key_value is not None:
566 # sin and cos are specific to RoPE models; cache_position needed for the static cache
File .../anaconda3/lib/python3.12/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
1551 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1552 else:
-> 1553 return self._call_impl(*args, **kwargs)
File .../anaconda3/lib/python3.12/site-packages/torch/nn/modules/module.py:1603, in Module._call_impl(self, *args, **kwargs)
1600 bw_hook = hooks.BackwardHook(self, full_backward_hooks, backward_pre_hooks)
1601 args = bw_hook.setup_input_hook(args)
-> 1603 result = forward_call(*args, **kwargs)
1604 if _global_forward_hooks or self._forward_hooks:
1605 for hook_id, hook in (
1606 *_global_forward_hooks.items(),
1607 *self._forward_hooks.items(),
1608 ):
1609 # mark that always called hook is run
File .../anaconda3/lib/python3.12/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
113 @functools.wraps(func)
114 def decorate_context(*args, **kwargs):
115 with ctx_factory():
--> 116 return func(*args, **kwargs)
File .../anaconda3/lib/python3.12/site-packages/transformers/models/gemma/modeling_gemma.py:113, in GemmaRotaryEmbedding.forward(self, x, position_ids, seq_len)
111 device_type = x.device.type
112 device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
--> 113 with torch.autocast(device_type=device_type, enabled=False):
114 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
115 emb = torch.cat((freqs, freqs), dim=-1)
File .../anaconda3/lib/python3.12/site-packages/torch/amp/autocast_mode.py:341, in autocast.__enter__(self)
338 return self
340 self.prev_cache_enabled = torch.is_autocast_cache_enabled()
--> 341 self.prev = torch.is_autocast_enabled(self.device)
342 self.prev_fastdtype = torch.get_autocast_dtype(self.device)
343 torch.set_autocast_enabled(self.device, self._enabled)
RuntimeError: unknown device type for autocast in get_autocast_dispatch_key_from_device_type
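(Side note: the `config.hidden_activation` warnings at the top of the output are unrelated to the crash; `gelu_pytorch_tanh` is the value Gemma falls back to anyway. If you want to silence them, a minimal sketch with plain transformers is to set the field explicitly on the config before loading, assuming a transformers release that already includes PR #29402. Whether the same `config` kwarg can be passed through nnsight's `LanguageModel` is an assumption, not something verified here.)

from transformers import AutoConfig, AutoModelForCausalLM

# Pin the activation explicitly so transformers stops warning.
# This matches the default, so model behaviour is unchanged.
config = AutoConfig.from_pretrained("google/gemma-2b")
config.hidden_activation = "gelu_pytorch_tanh"

model_hf = AutoModelForCausalLM.from_pretrained("google/gemma-2b", config=config)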
@uSaiPrashanth This will be fixed in the 0.3 release (by the end of the month). For now, set scan and validate to False:
from nnsight import LanguageModel
import torch

model = LanguageModel("google/gemma-2b", device_map="cuda:1", torch_dtype=torch.bfloat16)

with model.trace("Hello World", scan=False, validate=False):
    outputs = model.output.save()
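Once the trace context exits, the saved proxy is resolved and its `.value` attribute holds the concrete forward output (a usage sketch, assuming the current nnsight 0.2.x behaviour):

# Continuing from the snippet above, after the with block has exited:
output = outputs.value        # concrete model output (CausalLMOutputWithPast)
print(output.logits.shape)    # (batch_size, sequence_length, vocab_size)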