ndif-team / nnsight

The nnsight package enables interpreting and manipulating the internals of deep learned models.
https://nnsight.net/
MIT License
351 stars 34 forks source link

nnsight does not work with pytorch 2.4.0 when llama model is not dispatched #190

Closed Butanium closed 4 weeks ago

Butanium commented 1 month ago

Using nnsight 0.2.21 with pytorch 2.4.0 yields the following error trace:

RuntimeError                              Traceback (most recent call last)
[<ipython-input-4-ddb8dfc9b7b4>](https://localhost:8080/#) in <cell line: 4>()
      2 model = LanguageModel('Maykeye/TinyLLama-v0',device_map='auto') # , dispatch=True)
      3 prompts = "The french translation for 'hello' is:\n"
----> 4 with model.trace(prompts) as trace:
      5     pass

20 frames
[/usr/local/lib/python3.10/dist-packages/nnsight/models/NNsightModel.py](https://localhost:8080/#) in trace(self, trace, invoker_args, scan, *inputs, **kwargs)
    192 
    193             # Otherwise open an invoker context with the give args.
--> 194             runner.invoke(*inputs, **invoker_args).__enter__()
    195 
    196         # If trace is False, you had to have provided an input.

[/usr/local/lib/python3.10/dist-packages/nnsight/contexts/Invoker.py](https://localhost:8080/#) in __enter__(self)
     67             ) as fake_mode:
     68                 with FakeCopyMode(fake_mode):
---> 69                     self.tracer._model._execute(
     70                         *copy.deepcopy(self.inputs),
     71                         **copy.deepcopy(self.tracer._kwargs),

[/usr/local/lib/python3.10/dist-packages/nnsight/models/mixins/Generation.py](https://localhost:8080/#) in _execute(self, prepared_inputs, generate, *args, **kwargs)
     19             return self._execute_generate(prepared_inputs, *args, **kwargs)
     20 
---> 21         return self._execute_forward(prepared_inputs, *args, **kwargs)
     22 
     23     def _scan(

[/usr/local/lib/python3.10/dist-packages/nnsight/models/LanguageModel.py](https://localhost:8080/#) in _execute_forward(self, prepared_inputs, *args, **kwargs)
    304         device = next(self._model.parameters()).device
    305 
--> 306         return self._model(
    307             *args,
    308             **prepared_inputs.to(device),

[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _wrapped_call_impl(self, *args, **kwargs)
   1551             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552         else:
-> 1553             return self._call_impl(*args, **kwargs)
   1554 
   1555     def _call_impl(self, *args, **kwargs):

[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _call_impl(self, *args, **kwargs)
   1601                 args = bw_hook.setup_input_hook(args)
   1602 
-> 1603             result = forward_call(*args, **kwargs)
   1604             if _global_forward_hooks or self._forward_hooks:
   1605                 for hook_id, hook in (

[/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py](https://localhost:8080/#) in forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
   1172 
   1173         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1174         outputs = self.model(
   1175             input_ids=input_ids,
   1176             attention_mask=attention_mask,

[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _wrapped_call_impl(self, *args, **kwargs)
   1551             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552         else:
-> 1553             return self._call_impl(*args, **kwargs)
   1554 
   1555     def _call_impl(self, *args, **kwargs):

[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _call_impl(self, *args, **kwargs)
   1601                 args = bw_hook.setup_input_hook(args)
   1602 
-> 1603             result = forward_call(*args, **kwargs)
   1604             if _global_forward_hooks or self._forward_hooks:
   1605                 for hook_id, hook in (

[/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py](https://localhost:8080/#) in forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
    976                 )
    977             else:
--> 978                 layer_outputs = decoder_layer(
    979                     hidden_states,
    980                     attention_mask=causal_mask,

[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _wrapped_call_impl(self, *args, **kwargs)
   1551             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552         else:
-> 1553             return self._call_impl(*args, **kwargs)
   1554 
   1555     def _call_impl(self, *args, **kwargs):

[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _call_impl(self, *args, **kwargs)
   1601                 args = bw_hook.setup_input_hook(args)
   1602 
-> 1603             result = forward_call(*args, **kwargs)
   1604             if _global_forward_hooks or self._forward_hooks:
   1605                 for hook_id, hook in (

[/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py](https://localhost:8080/#) in forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, **kwargs)
    716 
    717         # Self Attention
--> 718         hidden_states, self_attn_weights, present_key_value = self.self_attn(
    719             hidden_states=hidden_states,
    720             attention_mask=attention_mask,

[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _wrapped_call_impl(self, *args, **kwargs)
   1551             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552         else:
-> 1553             return self._call_impl(*args, **kwargs)
   1554 
   1555     def _call_impl(self, *args, **kwargs):

[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _call_impl(self, *args, **kwargs)
   1601                 args = bw_hook.setup_input_hook(args)
   1602 
-> 1603             result = forward_call(*args, **kwargs)
   1604             if _global_forward_hooks or self._forward_hooks:
   1605                 for hook_id, hook in (

[/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py](https://localhost:8080/#) in forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, **kwargs)
    620         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    621 
--> 622         cos, sin = self.rotary_emb(value_states, position_ids)
    623         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
    624 

[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _wrapped_call_impl(self, *args, **kwargs)
   1551             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552         else:
-> 1553             return self._call_impl(*args, **kwargs)
   1554 
   1555     def _call_impl(self, *args, **kwargs):

[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _call_impl(self, *args, **kwargs)
   1601                 args = bw_hook.setup_input_hook(args)
   1602 
-> 1603             result = forward_call(*args, **kwargs)
   1604             if _global_forward_hooks or self._forward_hooks:
   1605                 for hook_id, hook in (

[/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py](https://localhost:8080/#) in decorate_context(*args, **kwargs)
    114     def decorate_context(*args, **kwargs):
    115         with ctx_factory():
--> 116             return func(*args, **kwargs)
    117 
    118     return decorate_context

[/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py](https://localhost:8080/#) in forward(self, x, position_ids)
    113         device_type = x.device.type
    114         device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
--> 115         with torch.autocast(device_type=device_type, enabled=False):
    116             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
    117             emb = torch.cat((freqs, freqs), dim=-1)

[/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py](https://localhost:8080/#) in __enter__(self)
    339 
    340         self.prev_cache_enabled = torch.is_autocast_cache_enabled()
--> 341         self.prev = torch.is_autocast_enabled(self.device)
    342         self.prev_fastdtype = torch.get_autocast_dtype(self.device)
    343         torch.set_autocast_enabled(self.device, self._enabled)

RuntimeError: unknown device type for autocast in get_autocast_dispatch_key_from_device_type

Code to reproduce:

!pip install nnsight "torch>=2.4"
from nnsight import LanguageModel
model = LanguageModel('gpt2',device_map='auto') # , dispatch=True)
prompts = "The french translation for 'hello' is:\n"
with model.trace(prompts) as trace:
    pass

Adding dispatch=True fixes the issue; specifying device_map='cpu' does not fix it.

AdamBelfki3 commented 4 weeks ago

@Butanium I was not able to reproduce this error

JadenFiotto-Kaufman commented 4 weeks ago

@Butanium This works / is fixed on the 0.3 branch, which will be released 8/23. We needed to update the patch for torch.amp.autocast handling 'meta' tensors.