ndif-team / nnsight

The nnsight package enables interpreting and manipulating the internals of deep learned models.
https://nnsight.net/
MIT License
356 stars 34 forks source link

Issues with gemma-2b #185

Closed — uSaiPrashanth closed this issue 1 month ago

uSaiPrashanth commented 1 month ago

Description

Tracing `google/gemma-2b` with nnsight fails. Running the following script:

from nnsight import LanguageModel
import torch

model = LanguageModel("google/gemma-2b", device_map="cuda:1", torch_dtype=torch.bfloat16)
with model.trace("Hello World"):
    outputs = model.output.save()

You get the following error:

RuntimeError: unknown device type for autocast in get_autocast_dispatch_key_from_device_type

Version info:

nnsight    0.2.21
transformers    4.42.4
torch    2.4.0
uSaiPrashanth commented 1 month ago

Full stack trace:

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[1], line 5
      2 import torch
      4 model = LanguageModel("google/gemma-2b", device_map="cuda:1", torch_dtype=torch.bfloat16)
----> 5 with model.trace("Hello World"):
      6     outputs = model.output.save()

File .../anaconda3/lib/python3.12/site-packages/nnsight/models/NNsightModel.py:194, in NNsight.trace(self, trace, invoker_args, scan, *inputs, **kwargs)
    191         return output.value
    193     # Otherwise open an invoker context with the give args.
--> 194     runner.invoke(*inputs, **invoker_args).__enter__()
    196 # If trace is False, you had to have provided an input.
    197 if not trace:

File .../anaconda3/lib/python3.12/site-packages/nnsight/contexts/Invoker.py:69, in Invoker.__enter__(self)
     64     with FakeTensorMode(
     65         allow_non_fake_inputs=True,
     66         shape_env=ShapeEnv(assume_static_by_default=True),
     67     ) as fake_mode:
     68         with FakeCopyMode(fake_mode):
---> 69             self.tracer._model._execute(
     70                 *copy.deepcopy(self.inputs),
     71                 **copy.deepcopy(self.tracer._kwargs),
     72             )
     74     self.scanning = False
     76 else:

File .../anaconda3/lib/python3.12/site-packages/nnsight/models/mixins/Generation.py:21, in GenerationMixin._execute(self, prepared_inputs, generate, *args, **kwargs)
     17 if generate:
     19     return self._execute_generate(prepared_inputs, *args, **kwargs)
---> 21 return self._execute_forward(prepared_inputs, *args, **kwargs)

File .../anaconda3/lib/python3.12/site-packages/nnsight/models/LanguageModel.py:306, in LanguageModel._execute_forward(self, prepared_inputs, *args, **kwargs)
    302 def _execute_forward(self, prepared_inputs: Any, *args, **kwargs):
    304     device = next(self._model.parameters()).device
--> 306     return self._model(
    307         *args,
    308         **prepared_inputs.to(device),
    309         **kwargs,
    310     )

File .../anaconda3/lib/python3.12/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File .../anaconda3/lib/python3.12/site-packages/torch/nn/modules/module.py:1603, in Module._call_impl(self, *args, **kwargs)
   1600     bw_hook = hooks.BackwardHook(self, full_backward_hooks, backward_pre_hooks)
   1601     args = bw_hook.setup_input_hook(args)
-> 1603 result = forward_call(*args, **kwargs)
   1604 if _global_forward_hooks or self._forward_hooks:
   1605     for hook_id, hook in (
   1606         *_global_forward_hooks.items(),
   1607         *self._forward_hooks.items(),
   1608     ):
   1609         # mark that always called hook is run

File .../anaconda3/lib/python3.12/site-packages/transformers/models/gemma/modeling_gemma.py:1127, in GemmaForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
   1124 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
   1126 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1127 outputs = self.model(
   1128     input_ids=input_ids,
   1129     attention_mask=attention_mask,
   1130     position_ids=position_ids,
   1131     past_key_values=past_key_values,
   1132     inputs_embeds=inputs_embeds,
   1133     use_cache=use_cache,
   1134     output_attentions=output_attentions,
   1135     output_hidden_states=output_hidden_states,
   1136     return_dict=return_dict,
   1137     cache_position=cache_position,
   1138 )
   1140 hidden_states = outputs[0]
   1141 logits = self.lm_head(hidden_states)

File .../anaconda3/lib/python3.12/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File .../anaconda3/lib/python3.12/site-packages/torch/nn/modules/module.py:1603, in Module._call_impl(self, *args, **kwargs)
   1600     bw_hook = hooks.BackwardHook(self, full_backward_hooks, backward_pre_hooks)
   1601     args = bw_hook.setup_input_hook(args)
-> 1603 result = forward_call(*args, **kwargs)
   1604 if _global_forward_hooks or self._forward_hooks:
   1605     for hook_id, hook in (
   1606         *_global_forward_hooks.items(),
   1607         *self._forward_hooks.items(),
   1608     ):
   1609         # mark that always called hook is run

File .../anaconda3/lib/python3.12/site-packages/transformers/models/gemma/modeling_gemma.py:931, in GemmaModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
    920     layer_outputs = self._gradient_checkpointing_func(
    921         decoder_layer.__call__,
    922         hidden_states,
   (...)
    928         cache_position,
    929     )
    930 else:
--> 931     layer_outputs = decoder_layer(
    932         hidden_states,
    933         attention_mask=causal_mask,
    934         position_ids=position_ids,
    935         past_key_value=past_key_values,
    936         output_attentions=output_attentions,
    937         use_cache=use_cache,
    938         cache_position=cache_position,
    939     )
    941 hidden_states = layer_outputs[0]
    943 if use_cache:

File .../anaconda3/lib/python3.12/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File .../anaconda3/lib/python3.12/site-packages/torch/nn/modules/module.py:1603, in Module._call_impl(self, *args, **kwargs)
   1600     bw_hook = hooks.BackwardHook(self, full_backward_hooks, backward_pre_hooks)
   1601     args = bw_hook.setup_input_hook(args)
-> 1603 result = forward_call(*args, **kwargs)
   1604 if _global_forward_hooks or self._forward_hooks:
   1605     for hook_id, hook in (
   1606         *_global_forward_hooks.items(),
   1607         *self._forward_hooks.items(),
   1608     ):
   1609         # mark that always called hook is run

File .../anaconda3/lib/python3.12/site-packages/transformers/models/gemma/modeling_gemma.py:658, in GemmaDecoderLayer.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, **kwargs)
    655 hidden_states = self.input_layernorm(hidden_states)
    657 # Self Attention
--> 658 hidden_states, self_attn_weights, present_key_value = self.self_attn(
    659     hidden_states=hidden_states,
    660     attention_mask=attention_mask,
    661     position_ids=position_ids,
    662     past_key_value=past_key_value,
    663     output_attentions=output_attentions,
    664     use_cache=use_cache,
    665     cache_position=cache_position,
    666 )
    667 hidden_states = residual + hidden_states
    669 # Fully Connected

File .../anaconda3/lib/python3.12/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File .../anaconda3/lib/python3.12/site-packages/torch/nn/modules/module.py:1603, in Module._call_impl(self, *args, **kwargs)
   1600     bw_hook = hooks.BackwardHook(self, full_backward_hooks, backward_pre_hooks)
   1601     args = bw_hook.setup_input_hook(args)
-> 1603 result = forward_call(*args, **kwargs)
   1604 if _global_forward_hooks or self._forward_hooks:
   1605     for hook_id, hook in (
   1606         *_global_forward_hooks.items(),
   1607         *self._forward_hooks.items(),
   1608     ):
   1609         # mark that always called hook is run

File .../anaconda3/lib/python3.12/site-packages/transformers/models/gemma/modeling_gemma.py:562, in GemmaSdpaAttention.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position)
    559 key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    560 value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
--> 562 cos, sin = self.rotary_emb(value_states, position_ids)
    563 query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
    565 if past_key_value is not None:
    566     # sin and cos are specific to RoPE models; cache_position needed for the static cache

File .../anaconda3/lib/python3.12/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File .../anaconda3/lib/python3.12/site-packages/torch/nn/modules/module.py:1603, in Module._call_impl(self, *args, **kwargs)
   1600     bw_hook = hooks.BackwardHook(self, full_backward_hooks, backward_pre_hooks)
   1601     args = bw_hook.setup_input_hook(args)
-> 1603 result = forward_call(*args, **kwargs)
   1604 if _global_forward_hooks or self._forward_hooks:
   1605     for hook_id, hook in (
   1606         *_global_forward_hooks.items(),
   1607         *self._forward_hooks.items(),
   1608     ):
   1609         # mark that always called hook is run

File .../anaconda3/lib/python3.12/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    113 @functools.wraps(func)
    114 def decorate_context(*args, **kwargs):
    115     with ctx_factory():
--> 116         return func(*args, **kwargs)

File .../anaconda3/lib/python3.12/site-packages/transformers/models/gemma/modeling_gemma.py:113, in GemmaRotaryEmbedding.forward(self, x, position_ids, seq_len)
    111 device_type = x.device.type
    112 device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
--> 113 with torch.autocast(device_type=device_type, enabled=False):
    114     freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
    115     emb = torch.cat((freqs, freqs), dim=-1)

File .../anaconda3/lib/python3.12/site-packages/torch/amp/autocast_mode.py:341, in autocast.__enter__(self)
    338     return self
    340 self.prev_cache_enabled = torch.is_autocast_cache_enabled()
--> 341 self.prev = torch.is_autocast_enabled(self.device)
    342 self.prev_fastdtype = torch.get_autocast_dtype(self.device)
    343 torch.set_autocast_enabled(self.device, self._enabled)

RuntimeError: unknown device type for autocast in get_autocast_dispatch_key_from_device_type
JadenFiotto-Kaufman commented 1 month ago

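@uSaiPrashanth This will be fixed in the 0.3 release (by the end of the month). The error is raised during nnsight's scanning pass. That pass runs the model under `FakeTensorMode` (see `Invoker.__enter__` in the trace), and Gemma's rotary embedding then enters `torch.autocast`, which rejects the device type it is given. Disabling scanning and validation skips that pass. For now, set `scan` and `validate` to `False`: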

from nnsight import LanguageModel
import torch

model = LanguageModel("google/gemma-2b", device_map="cuda:1", torch_dtype=torch.bfloat16)
with model.trace("Hello World", scan=False, validate=False):
    outputs = model.output.save()