cg123 / bitnet

Modeling code for a BitNet b1.58 Llama-style model.
MIT License

InternalTorchDynamoError #2

Closed: geronimi73 closed this issue 6 months ago

geronimi73 commented 6 months ago

Thank you for the implementation!

Have you come across this error? InternalTorchDynamoError: 'NoneType' object is not subscriptable

The code is basically a hello world:

import torch

from bitnet.configuration_bitllama import BitLlamaConfig
from bitnet.modeling_bitllama import BitLlamaForCausalLM

# TinyLlama-1.1B-style configuration
tinyllama_config = {
    "bos_token_id": 1,
    "eos_token_id": 2,
    "hidden_size": 2048,
    "initializer_range": 0.02,
    "intermediate_size": 5632,
    "max_position_embeddings": 2048,
    "num_attention_heads": 32,
    "num_hidden_layers": 22,
    "num_key_value_heads": 4,
    "pretraining_tp": 1,
    "tie_word_embeddings": False,
    "use_cache": True,
    "vocab_size": 32000,
}

# build the model and run greedy generation on a single input token
model = BitLlamaForCausalLM(BitLlamaConfig(**tinyllama_config)).cuda()
model.generate(torch.tensor([[12]], device="cuda"))

Stacktrace

---------------------------------------------------------------------------
InternalTorchDynamoError                  Traceback (most recent call last)
Cell In[6], line 3
      1 import torch
      2 model.cuda()
----> 3 model.generate(torch.tensor([[12]], device="cuda"))

File ~/.local/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File ~/.local/lib/python3.10/site-packages/transformers/generation/utils.py:1544, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)
   1526     return self.assisted_decoding(
   1527         input_ids,
   1528         candidate_generator=candidate_generator,
   (...)
   1540         **model_kwargs,
   1541     )
   1542 if generation_mode == GenerationMode.GREEDY_SEARCH:
   1543     # 11. run greedy search
-> 1544     return self.greedy_search(
   1545         input_ids,
   1546         logits_processor=prepared_logits_processor,
   1547         stopping_criteria=prepared_stopping_criteria,
   1548         pad_token_id=generation_config.pad_token_id,
   1549         eos_token_id=generation_config.eos_token_id,
   1550         output_scores=generation_config.output_scores,
   1551         output_logits=generation_config.output_logits,
   1552         return_dict_in_generate=generation_config.return_dict_in_generate,
   1553         synced_gpus=synced_gpus,
   1554         streamer=streamer,
   1555         **model_kwargs,
   1556     )
   1558 elif generation_mode == GenerationMode.CONTRASTIVE_SEARCH:
   1559     if not model_kwargs["use_cache"]:

File ~/.local/lib/python3.10/site-packages/transformers/generation/utils.py:2404, in GenerationMixin.greedy_search(self, input_ids, logits_processor, stopping_criteria, max_length, pad_token_id, eos_token_id, output_attentions, output_hidden_states, output_scores, output_logits, return_dict_in_generate, synced_gpus, streamer, **model_kwargs)
   2401 model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
   2403 # forward pass to get next token
-> 2404 outputs = self(
   2405     **model_inputs,
   2406     return_dict=True,
   2407     output_attentions=output_attentions,
   2408     output_hidden_states=output_hidden_states,
   2409 )
   2411 if synced_gpus and this_peer_finished:
   2412     continue  # don't waste resources running the code we don't need

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/bitnet_cg/bitnet/modeling_bitllama.py:537, in BitLlamaForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)
    532 return_dict = (
    533     return_dict if return_dict is not None else self.config.use_return_dict
    534 )
    536 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
--> 537 outputs = self.model(
    538     input_ids=input_ids,
    539     attention_mask=attention_mask,
    540     position_ids=position_ids,
    541     past_key_values=past_key_values,
    542     inputs_embeds=inputs_embeds,
    543     use_cache=use_cache,
    544     output_attentions=output_attentions,
    545     output_hidden_states=output_hidden_states,
    546     return_dict=return_dict,
    547 )
    549 hidden_states = outputs[0]
    550 if self.config.pretraining_tp > 1:

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/bitnet_cg/bitnet/modeling_bitllama.py:419, in BitLlamaModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict)
    408     layer_outputs = self._gradient_checkpointing_func(
    409         decoder_layer.__call__,
    410         hidden_states,
   (...)
    416         _layer_idx,
    417     )
    418 else:
--> 419     layer_outputs = decoder_layer(
    420         hidden_states,
    421         attention_mask=attention_mask,
    422         position_ids=position_ids,
    423         past_key_value=past_key_values,
    424         output_attentions=output_attentions,
    425         use_cache=use_cache,
    426         effective_idx=_layer_idx,
    427     )
    429 hidden_states = layer_outputs[0]
    431 if use_cache:

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/bitnet_cg/bitnet/modeling_bitllama.py:197, in BitLlamaDecoderLayer.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, effective_idx, padding_mask)
    193 h0 = hidden_states
    195 hidden_states = self.input_layernorm(hidden_states)
--> 197 attention_output, self_attn_weights, present_key_value = self.self_attn(
    198     hidden_states=hidden_states,
    199     attention_mask=attention_mask,
    200     position_ids=position_ids,
    201     past_key_value=past_key_value,
    202     output_attentions=output_attentions,
    203     use_cache=use_cache,
    204     padding_mask=padding_mask,
    205 )
    206 hidden_states = h0 + attention_output
    208 if self.config.newton_steps:

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/.local/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:352, in LlamaAttention.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, **kwargs)
    349     value_states = torch.cat(value_states, dim=-1)
    351 else:
--> 352     query_states = self.q_proj(hidden_states)
    353     key_states = self.k_proj(hidden_states)
    354     value_states = self.v_proj(hidden_states)

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:489, in _TorchDynamoContext.__call__.<locals>._fn(*args, **kwargs)
    487     dynamo_config_ctx.__enter__()
    488 try:
--> 489     return fn(*args, **kwargs)
    490 finally:
    491     set_eval_frame(prior)

File ~/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:655, in catch_errors_wrapper.<locals>.catch_errors(frame, cache_entry, frame_state)
    652             return hijacked_callback(frame, cache_entry, hooks, frame_state)
    654 with compile_lock, _disable_current_modes():
--> 655     return callback(frame, cache_entry, hooks, frame_state)

File ~/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:727, in convert_frame.<locals>._convert_frame(frame, cache_entry, hooks, frame_state)
    725 counters["frames"]["total"] += 1
    726 try:
--> 727     result = inner_convert(frame, cache_entry, hooks, frame_state)
    728     counters["frames"]["ok"] += 1
    729     return result

File ~/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:383, in convert_frame_assert.<locals>._convert_frame_assert(frame, cache_entry, hooks, frame_state)
    370 signpost_event(
    371     "dynamo",
    372     "_convert_frame_assert._compile",
   (...)
    379     },
    380 )
    382 with config.patch(_patch_config_if_changed()):
--> 383     compiled_product = _compile(
    384         frame.f_code,
    385         frame.f_globals,
    386         frame.f_locals,
    387         frame.f_builtins,
    388         compiler_fn,
    389         one_graph,
    390         export,
    391         export_constraints,
    392         hooks,
    393         cache_size,
    394         frame,
    395         frame_state=frame_state,
    396         compile_id=compile_id,
    397     )
    398 return compiled_product

File ~/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:665, in _compile(code, globals, locals, builtins, compiler_fn, one_graph, export, export_constraints, hooks, cache_size, frame, frame_state, compile_id)
    663     fail_reason = str(e)
    664     exception_handler(e, code, frame, export=export)
--> 665     raise InternalTorchDynamoError(str(e)).with_traceback(
    666         e.__traceback__
    667     ) from None
    668 finally:
    669     from .utils import curr_frame

File ~/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:646, in _compile(code, globals, locals, builtins, compiler_fn, one_graph, export, export_constraints, hooks, cache_size, frame, frame_state, compile_id)
    644 with compile_context(CompileContext(compile_id)):
    645     try:
--> 646         guarded_code = compile_inner(code, one_graph, hooks, transform)
    647         return guarded_code
    648     except (
    649         Unsupported,
    650         TorchRuntimeError,
   (...)
    657         BisectValidationException,
    658     ) as e:

File ~/.local/lib/python3.10/site-packages/torch/_dynamo/utils.py:244, in dynamo_timed.<locals>.dynamo_timed_inner.<locals>.time_wrapper(*args, **kwargs)
    242 with torch.profiler.record_function(f"{key} (dynamo_timed)"):
    243     t0 = time.time()
--> 244     r = func(*args, **kwargs)
    245     time_spent = time.time() - t0
    246 compilation_time_metrics[key].append(time_spent)

File ~/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:626, in _compile.<locals>.compile_inner(code, one_graph, hooks, transform)
    624 assert output.guards is not None
    625 CleanupManager.instance[out_code] = output.cleanups
--> 626 check_fn = CheckFunctionManager(
    627     output,
    628     hooks.guard_fail_fn if hooks else None,
    629 )
    631 guarded_code = GuardedCode(out_code, check_fn.check_fn)
    633 if not output.is_empty_graph() and hooks.guard_export_fn is not None:
    634     # We should not run the guard_export_fn when Dynamo does not
    635     # generate any graph. This can happen in export when TorchDynamo
    636     # generated bytecode has some reconstruction logic for mutated
    637     # variables which can trigger TorchDynamo on the children frames but
    638     # they are benign and do not generate any new graphs.

File ~/.local/lib/python3.10/site-packages/torch/_dynamo/guards.py:1011, in CheckFunctionManager.__init__(self, output_graph, guard_fail_fn)
   1000     if (
   1001         not config.guard_nn_modules
   1002         and guard.is_nn_module()
   (...)
   1007         and (config.skip_nnmodule_hook_guards or "hooks" not in guard.name)
   1008     ):
   1009         continue
-> 1011     guard.create(builder)
   1012 self.check_fn = self.compile_check_fn(builder, guards, guard_fail_fn)
   1013 self._weakrefs.clear()

File ~/.local/lib/python3.10/site-packages/torch/_guards.py:246, in Guard.create(self, builder)
    244 def create(self, builder: GuardBuilderBase):
    245     try:
--> 246         return self.create_fn(builder, self)
    247     except Exception:
    248         log.error("Error while creating guard:\n%s", str(self).rstrip())

File ~/.local/lib/python3.10/site-packages/torch/_dynamo/guards.py:448, in GuardBuilder.CONSTANT_MATCH(self, guard)
    447 def CONSTANT_MATCH(self, guard: Guard):
--> 448     val = self.get(guard.name)
    449     if istype(val, (bool, type(None))):
    450         self.ID_MATCH(guard)

File ~/.local/lib/python3.10/site-packages/torch/_dynamo/guards.py:258, in GuardBuilder.get(self, name)
    257 def get(self, name: str) -> Any:
--> 258     return eval(name, self.scope, CLOSURE_VARS)

File <string>:1

InternalTorchDynamoError: 'NoneType' object is not subscriptable

You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

Packages

torch                             2.2.1
transformers                      4.38.2
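
For reference, the traceback itself points at a fallback: suppress Dynamo errors and run eager. A minimal sketch of that workaround (assuming the goal is only to get generate() running; this is not necessarily the fix intended by the repo author):

import torch
import torch._dynamo

from bitnet.configuration_bitllama import BitLlamaConfig
from bitnet.modeling_bitllama import BitLlamaForCausalLM

# Fall back to eager execution whenever Dynamo compilation fails,
# as suggested at the bottom of the traceback above.
torch._dynamo.config.suppress_errors = True

config = BitLlamaConfig(
    bos_token_id=1,
    eos_token_id=2,
    hidden_size=2048,
    intermediate_size=5632,
    max_position_embeddings=2048,
    num_attention_heads=32,
    num_hidden_layers=22,
    num_key_value_heads=4,
    vocab_size=32000,
)
model = BitLlamaForCausalLM(config).cuda()
model.generate(torch.tensor([[12]], device="cuda"))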
geronimi73 commented 6 months ago

OK, it seems you are already aware of this:

https://github.com/cg123/bitnet/blob/025e14173224fbe36d69ad5688dc338c3c7f010e/distil.py#L318
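
As an aside, a narrower alternative to the global suppress_errors flag is to disable Dynamo just for the offending call. A sketch (not taken from distil.py; `model` is the BitLlamaForCausalLM built in the snippet at the top of the issue):

import torch
import torch._dynamo

# Disable Dynamo (recursively) for this call only, leaving compilation
# enabled everywhere else, e.g. during training.
eager_generate = torch._dynamo.disable(model.generate)
eager_generate(torch.tensor([[12]], device="cuda"))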