Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source to source compiler for PyTorch. It enables using different hardware executors at once; across one or thousands of GPUs.
Apache License 2.0

Inf recursion when running HF BERT #805

Closed t-vi closed 1 month ago

t-vi commented 1 month ago

After #804, when I run

import transformers, thunder, torch
@thunder.core.jit_ext.register_general_jit_lookaside(
    transformers.modeling_utils.PreTrainedModel.warn_if_padding_and_no_attention_mask
)
@thunder.core.jit_ext.interpreter_needs_wrap
def dummy(*args):
    pass

# Transformers 4.41+ adds some more non-essential data-dependent
# control flow behind a check for whether we are compiling
@thunder.core.jit_ext.register_general_jit_lookaside(
    torch._dynamo.is_compiling
)
@thunder.core.jit_ext.interpreter_needs_wrap
def dummy(*args):
    return True

m = transformers.BertForSequenceClassification(transformers.BertConfig())
inp = torch.randint(1, 20, (1, 32))
jm = thunder.jit(m)
jm(inp)

I get an infinite recursion in codeutils.to_printable:

File ~/data/firma/grid/thunder/lightning-thunder/thunder/core/codeutils.py:140, in to_printable(trace, x, import_ctx, object_ctx)
    138 printables = []
    139 for f in flat:
--> 140     printables.append(to_printable(trace, f, import_ctx=import_ctx, object_ctx=object_ctx))
    142 printable = tree_unflatten(printables, spec)
    143 return printable

    [... skipping similar frames: to_printable at line 140 (2949 times)]

File ~/data/firma/grid/thunder/lightning-thunder/thunder/core/codeutils.py:140, in to_printable(trace, x, import_ctx, object_ctx)
    138 printables = []
    139 for f in flat:
--> 140     printables.append(to_printable(trace, f, import_ctx=import_ctx, object_ctx=object_ctx))
    142 printable = tree_unflatten(printables, spec)
    143 return printable

File ~/data/firma/grid/thunder/lightning-thunder/thunder/core/codeutils.py:136, in to_printable(trace, x, import_ctx, object_ctx)
    133 if isinstance(x, ProxyInterface):
    134     return x
--> 136 if is_collection(x):
    137     flat, spec = tree_flatten(x)
    138     printables = []

File ~/data/firma/grid/thunder/lightning-thunder/thunder/core/baseutils.py:153, in is_collection(x)
    152 def is_collection(x: Any) -> bool:
--> 153     return isinstance(x, collections.abc.Collection) and not isinstance(x, (str, torch.Tensor, np.ndarray, torch.Size))

File /usr/lib/python3.11/abc.py:119, in ABCMeta.__instancecheck__(cls, instance)
    117 def __instancecheck__(cls, instance):
    118     """Override for isinstance(instance, cls)."""
--> 119     return _abc_instancecheck(cls, instance)

RecursionError: maximum recursion depth exceeded in comparison
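
For readers following the frames above, here is a minimal sketch of one way this call pattern (collection check, flatten, recurse on each leaf) can fail to terminate. It is a guess at the mechanism, not a confirmed diagnosis: it assumes the flattener hands back an object unchanged as its own single leaf while the collection check still accepts it. toy_flatten, OpaqueBag, naive_printable and guarded_printable are hypothetical names and do not exist in thunder.

```python
from collections.abc import Collection


def toy_flatten(x):
    """Flatten lists/tuples one level; treat every other object as a leaf."""
    if isinstance(x, (list, tuple)):
        return list(x)
    return [x]  # unknown type: the object itself is the only leaf


class OpaqueBag(Collection):
    """A Collection that the toy flattener does not know how to open."""

    def __init__(self, items):
        self._items = list(items)

    def __contains__(self, item):
        return item in self._items

    def __iter__(self):
        return iter(self._items)

    def __len__(self):
        return len(self._items)


def naive_printable(x):
    # Same shape as the traceback: collection check, flatten, recurse per leaf.
    if isinstance(x, Collection) and not isinstance(x, str):
        return [naive_printable(f) for f in toy_flatten(x)]
    return x


def guarded_printable(x):
    # Assumed mitigation: stop when flattening makes no progress.
    if isinstance(x, Collection) and not isinstance(x, str):
        flat = toy_flatten(x)
        if len(flat) == 1 and flat[0] is x:
            return x
        return [guarded_printable(f) for f in flat]
    return x
```

Calling naive_printable(OpaqueBag([1, 2])) recurses until Python raises RecursionError, while guarded_printable returns the bag unchanged. Whether the real failure involves a type like this is an open question in this report.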

Edit: Updated the repro after @k223kim's hint below that something was missing. Thank you.

cc @apaz-cli

k223kim commented 1 month ago

Hey Tom! Let me work on this issue. It'll be great if you can assign me :) @t-vi

k223kim commented 1 month ago

However, I get a different error when I run the script you provided above. Could you check your version of transformers? This looks like an indexing issue to me.

Exception has occurred: NotImplementedError
exception: no description
  File "/Users/kaeunkim/lightning-thunder/thunder/core/proxies.py", line 1402, in __bool__
    raise NotImplementedError
  File "/Users/kaeunkim/lightning-thunder/thunder/core/jit_ext.py", line 704, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/Users/kaeunkim/lightning-thunder/thunder/core/interpreter.py", line 1272, in wrapping_wrapper
    res = ufn(*uargs, **ukwargs)
          ^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/kaeunkim/lightning-thunder/thunder/core/interpreter.py", line 1362, in impl
    return dunder_bool(x)
    ^^^^^^^^^^^^^^^^^^
  File "/Users/kaeunkim/lightning-thunder/thunder/core/interpreter.py", line 1347, in impl
    if element:
  File "/Users/kaeunkim/lightning-thunder/thunder/core/interpreter.py", line 3686, in impl
    return any(v is tos1 or v == tos1 for v in tos)
    ^^^^^^^^^^^^^^^^^^
  File "/Users/kaeunkim/transformers/src/transformers/modeling_utils.py", line 4607, in warn_if_padding_and_no_attention_mask
    if self.config.pad_token_id in input_ids[:, [-1, 0]]:
    ^^^^^^^^^^^^^^^^^^
  File "/Users/kaeunkim/transformers/src/transformers/models/bert/modeling_bert.py", line 1056, in forward
    self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
    ^^^^^^^^^^^^^^^^^^
  File "/Users/kaeunkim/lightning-thunder/thunder/core/interpreter.py", line 6060, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
    ^^^^^^^^^^^^^^^^^^
  File "/Users/kaeunkim/lightning-thunder/thunder/core/interpreter.py", line 6060, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
    ^^^^^^^^^^^^^^^^^^
  File "/Users/kaeunkim/lightning-thunder/thunder/core/interpreter.py", line 6060, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
    ^^^^^^^^^^^^^^^^^^
  File "/Users/kaeunkim/transformers/src/transformers/models/bert/modeling_bert.py", line 1695, in forward
    outputs = self.bert(
    ^^^^^^^^^^^^^^^^^^
  File "/Users/kaeunkim/lightning-thunder/thunder/core/interpreter.py", line 6060, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
    ^^^^^^^^^^^^^^^^^^
  File "/Users/kaeunkim/lightning-thunder/thunder/core/interpreter.py", line 6060, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
    ^^^^^^^^^^^^^^^^^^
  File "/Users/kaeunkim/lightning-thunder/thunder/core/interpreter.py", line 6060, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
    ^^^^^^^^^^^^^^^^^^
  File "/Users/kaeunkim/lightning-thunder/thunder/core/interpreter.py", line 6737, in fn_2
    return fn(*args, **kwargs)
    ^^^^^^^^^^^^^^^^^^
  File "/Users/kaeunkim/lightning-thunder/thunder/core/interpreter.py", line 6769, in fn_
    raise e
  File "/Users/kaeunkim/lightning-thunder/thunder/core/jit_ext.py", line 1768, in thunder_general_jit
    result = jfn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^
  File "/Users/kaeunkim/lightning-thunder/thunder/__init__.py", line 213, in _general_frontend
    return thunder_general_jit(fn, args, kwargs, sharp_edges=sharp_edges, record_history=record_history)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/kaeunkim/lightning-thunder/thunder/__init__.py", line 503, in get_computation_and_inputs
    jit_results: TraceResults = interpreter(
                                ^^^^^^^^^^^^
  File "/Users/kaeunkim/lightning-thunder/thunder/__init__.py", line 225, in cache_info_wrapper
    res = fn(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^
  File "/Users/kaeunkim/lightning-thunder/thunder/__init__.py", line 683, in fn_
    cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/kaeunkim/lightning-thunder/thunder/core/module.py", line 61, in forward
    res = self._forward_fn(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/kaeunkim/lightning-thunder/kaeun_test.py", line 5, in <module>
    jm(inp)
NotImplementedError:

t-vi commented 1 month ago

Ahaha. Yeah, you need a noop-lookaside that I have not posted in the repro. Sorry about that. I'll update the repro.
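
To illustrate why the no-op lookaside matters, here is a toy sketch of the failure in the traceback above: a membership test on traced values needs a concrete truth value, which a proxy cannot provide. ToyProxy, warn_if_pad, LOOKASIDES and call_traced are hypothetical stand-ins for illustration only; they are not thunder's classes or its registration API.

```python
class ToyProxy:
    """Stand-in for a traced value whose concrete contents are unknown."""

    def __eq__(self, other):
        return ToyProxy()  # comparison yields another symbolic value

    def __bool__(self):
        raise NotImplementedError  # no concrete truth value while tracing


def warn_if_pad(pad_id, last_and_first_tokens):
    # Data-dependent branch, like the `in` check near the bottom of the traceback.
    if pad_id in last_and_first_tokens:
        print("padding detected without attention mask")


def noop(*args, **kwargs):
    """Replacement that skips the non-essential check entirely."""
    return None


# Hypothetical lookaside table: original callable -> replacement used while tracing.
LOOKASIDES = {warn_if_pad: noop}


def call_traced(fn, *args, **kwargs):
    # Dispatch to the replacement if one is registered, else call fn itself.
    return LOOKASIDES.get(fn, fn)(*args, **kwargs)
```

With tokens = [ToyProxy(), ToyProxy()], calling warn_if_pad(0, tokens) directly raises NotImplementedError because the `in` test converts each comparison result to bool, whereas call_traced(warn_if_pad, 0, tokens) returns None without touching the proxies.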