ndif-team / nnsight

The nnsight package enables interpreting and manipulating the internals of deep learned models.
https://nnsight.net/
MIT License

Can't run llama architectures #83

Closed: Butanium closed this issue 6 months ago

Butanium commented 6 months ago

I can't run nnsight on Llama models: tracing fails with RuntimeError: User specified an unsupported autocast device_type 'meta'. MWE:

from nnsight import LanguageModel
model = LanguageModel('Maykeye/TinyLLama-v0',device_map='auto')
prompt = "The french translation for 'hello' is:\n"
with model.trace(prompt) as trace:
    pass

I tested:

Full stack trace:

RuntimeError                              Traceback (most recent call last)
<ipython-input-5-c8af8fe5fc7f> in <cell line: 4>()
      2 model = LanguageModel('Maykeye/TinyLLama-v0',device_map='cuda:0')
      3 prompts = "The french translation for 'hello' is:\n"
----> 4 with model.trace(prompts) as trace:
      5     pass

/usr/local/lib/python3.10/dist-packages/nnsight/models/NNsightModel.py in trace(self, trace, invoker_args, scan, *inputs, **kwargs)
    194 
    195             # Otherwise open an invoker context with the give args.
--> 196             runner.invoke(*inputs, **invoker_args).__enter__()
    197 
    198         # If trace is False, you had to have provided an input.

/usr/local/lib/python3.10/dist-packages/nnsight/contexts/Invoker.py in __enter__(self)
     65             ) as fake_mode:
     66                 with FakeCopyMode(fake_mode):
---> 67                     self.tracer._model._execute(
     68                         *copy.deepcopy(self.inputs),
     69                         **copy.deepcopy(self.tracer._kwargs),

/usr/local/lib/python3.10/dist-packages/nnsight/models/mixins/Generation.py in _execute(self, prepared_inputs, generate, *args, **kwargs)
     19             return self._execute_generate(prepared_inputs, *args, **kwargs)
     20 
---> 21         return self._execute_forward(prepared_inputs, *args, **kwargs)
     22 
     23     def _scan(

/usr/local/lib/python3.10/dist-packages/nnsight/models/LanguageModel.py in _execute_forward(self, prepared_inputs, *args, **kwargs)
    274         device = next(self._model.parameters()).device
    275 
--> 276         return self._model(
    277             *args,
    278             **prepared_inputs.to(device),

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1516             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517         else:
-> 1518             return self._call_impl(*args, **kwargs)
   1519 
   1520     def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1566                 args = bw_hook.setup_input_hook(args)
   1567 
-> 1568             result = forward_call(*args, **kwargs)
   1569             if _global_forward_hooks or self._forward_hooks:
   1570                 for hook_id, hook in (

/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py in forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
   1174 
   1175         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1176         outputs = self.model(
   1177             input_ids=input_ids,
   1178             attention_mask=attention_mask,

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1516             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517         else:
-> 1518             return self._call_impl(*args, **kwargs)
   1519 
   1520     def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1566                 args = bw_hook.setup_input_hook(args)
   1567 
-> 1568             result = forward_call(*args, **kwargs)
   1569             if _global_forward_hooks or self._forward_hooks:
   1570                 for hook_id, hook in (

/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py in forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
   1017                 )
   1018             else:
-> 1019                 layer_outputs = decoder_layer(
   1020                     hidden_states,
   1021                     attention_mask=causal_mask,

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1516             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517         else:
-> 1518             return self._call_impl(*args, **kwargs)
   1519 
   1520     def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1566                 args = bw_hook.setup_input_hook(args)
   1567 
-> 1568             result = forward_call(*args, **kwargs)
   1569             if _global_forward_hooks or self._forward_hooks:
   1570                 for hook_id, hook in (

/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py in forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, **kwargs)
    738 
    739         # Self Attention
--> 740         hidden_states, self_attn_weights, present_key_value = self.self_attn(
    741             hidden_states=hidden_states,
    742             attention_mask=attention_mask,

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1516             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517         else:
-> 1518             return self._call_impl(*args, **kwargs)
   1519 
   1520     def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1566                 args = bw_hook.setup_input_hook(args)
   1567 
-> 1568             result = forward_call(*args, **kwargs)
   1569             if _global_forward_hooks or self._forward_hooks:
   1570                 for hook_id, hook in (

/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py in forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, **kwargs)
    359 
    360         past_key_value = getattr(self, "past_key_value", past_key_value)
--> 361         cos, sin = self.rotary_emb(value_states, position_ids)
    362         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
    363 

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1516             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517         else:
-> 1518             return self._call_impl(*args, **kwargs)
   1519 
   1520     def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1566                 args = bw_hook.setup_input_hook(args)
   1567 
-> 1568             result = forward_call(*args, **kwargs)
   1569             if _global_forward_hooks or self._forward_hooks:
   1570                 for hook_id, hook in (

/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py in decorate_context(*args, **kwargs)
    113     def decorate_context(*args, **kwargs):
    114         with ctx_factory():
--> 115             return func(*args, **kwargs)
    116 
    117     return decorate_context

/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py in forward(self, x, position_ids, seq_len)
    139         device_type = x.device.type
    140         device_type = device_type if isinstance(device_type, str) else "cpu"
--> 141         with torch.autocast(device_type=device_type, enabled=False):
    142             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
    143             emb = torch.cat((freqs, freqs), dim=-1)

/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py in __init__(self, device_type, dtype, enabled, cache_enabled)
    239             self.fast_dtype = self.custom_device_mod.get_autocast_dtype()
    240         else:
--> 241             raise RuntimeError(
    242                 f"User specified an unsupported autocast device_type '{self.device}'"
    243             )

RuntimeError: User specified an unsupported autocast device_type 'meta'
Butanium commented 6 months ago

Running the model directly with Hugging Face transformers works:

from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("Maykeye/TinyLLama-v0")
model = AutoModelForCausalLM.from_pretrained("Maykeye/TinyLLama-v0", device_map="cuda")

input_text = "Write me a poem about Machine Learning."
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

outputs = model.generate(**input_ids, max_new_tokens=50)
print(tokenizer.decode(outputs[0]))
arjunguha commented 6 months ago

Same error with Code Llama

cadentj commented 6 months ago

@arjunguha Can you try upgrading the transformers package? See the NDIF Discord for more.

arjunguha commented 6 months ago

Yup, fixed. It works with a released (non-git) transformers and nnsight>0.2.

Butanium commented 6 months ago

Another workaround, suggested by Jaden: dispatch the model at initialization so it is not on the 'meta' device:

from nnsight import LanguageModel
model = LanguageModel('Maykeye/TinyLLama-v0',device_map='auto', dispatch=True)
prompt = "The french translation for 'hello' is:\n"
with model.trace(prompt) as trace:
    pass
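To confirm that a model really left the meta device, you can check its parameters. A minimal sketch with a hypothetical helper, has_meta_parameters, demonstrated on two small torch.nn.Linear layers rather than on the nnsight model itself:

```python
import torch


def has_meta_parameters(module: torch.nn.Module) -> bool:
    """Hypothetical helper: True if any parameter of `module` is on the meta device."""
    return any(p.is_meta for p in module.parameters())


# A layer created on the meta device has shapes but no real weights,
# similar to a model that has not been dispatched yet.
lazy_layer = torch.nn.Linear(4, 4, device="meta")
# A layer created on a real device (CPU here) holds actual weights.
real_layer = torch.nn.Linear(4, 4, device="cpu")

print(has_meta_parameters(lazy_layer))  # True
print(has_meta_parameters(real_layer))  # False
```

Assuming the wrapped Hugging Face module is reachable as model._model (the attribute used by _execute_forward in the traceback), has_meta_parameters(model._model) should return False after LanguageModel(..., dispatch=True).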
Butanium commented 6 months ago

Another workaround (for example, for those who want to run the model remotely) is to disable scanning by passing scan=False: with model.trace(prompt, scan=False)

Disabling scan doesn't seem like a big deal. From the docstring:

scan: whether to execute the model using FakeTensor in order to update the potential sizes/dtypes of all modules' Envoys' inputs/outputs, as well as to validate that things work correctly. Scanning is not free computation-wise, so you may want to set this to False when running in a loop. When making interventions, you may get shape errors if scan is False, as scanning validates operations based on shapes; so for looped calls where shapes are consistent, you may want to have scan=True for the first loop. Defaults to True.
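The docstring's advice (scan once, then skip scanning while shapes stay the same) can be turned into a per-call flag. A minimal sketch with a hypothetical helper, scan_flags; extending "the first loop" to "the first call with each new input shape" is an assumption, not something the docstring states:

```python
from typing import Hashable, Iterable, List


def scan_flags(input_shapes: Iterable[Hashable]) -> List[bool]:
    """Hypothetical helper: choose the `scan` argument for each call in a loop.

    Scan (True) the first time a given input shape appears, and skip
    scanning (False) for later calls whose shape was already validated.
    """
    seen = set()
    flags = []
    for shape in input_shapes:
        flags.append(shape not in seen)
        seen.add(shape)
    return flags


# Three same-shaped inputs, then a longer sequence, then the first shape again.
shapes = [(1, 8), (1, 8), (1, 8), (1, 12), (1, 8)]
print(scan_flags(shapes))  # [True, False, False, True, False]
```

Each flag would then be passed as with model.trace(prompt, scan=flag):. Any call that does scan still runs the fake-tensor pass, so with the affected transformers version it would still need dispatch=True or an upgrade.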