Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0

InterpreterError: Encountered exception TypeError: missing a required argument: 'value' while tracing #757

Open IvanYashchuk opened 1 month ago

IvanYashchuk commented 1 month ago

🐛 Bug

A minimal repro for the previously fixed issue (https://github.com/Lightning-AI/lightning-thunder/issues/461#issuecomment-2178023346) no longer works; it now fails inside Thunder's interpreter:

import transformers
import torch
import thunder

def fn(x):
    return transformers.modeling_outputs.BaseModelOutput(x)

jfn = thunder.jit(fn)

x = torch.randn(5, 5)

print(jfn(x))
TypeError: missing a required argument: 'value'

The above exception was the direct cause of the following exception:

InterpreterError                          Traceback (most recent call last)
Cell In[1], line 12
      8 jfn = thunder.jit(fn)
     10 x = torch.randn(5, 5)
---> 12 print(jfn(x))

File ~/dev/lightning-thunder/thunder/__init__.py:669, in jit.<locals>.fn_(*args, **kwargs)
    666 cs.last_trace_host_start = time.perf_counter_ns()
    667 cs.calls += 1
--> 669 cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
    670 cs.last_trace_host_execution_start = time.perf_counter_ns()
    672 result = cache_entry.computation_fn(*inps)

File ~/dev/lightning-thunder/thunder/__init__.py:223, in _with_cache_info_ctx.<locals>.cache_info_wrapper(*args, **kwargs)
    221 tok = _cache_info_ctx.set({})
    222 try:
--> 223     res = fn(*args, **kwargs)
    224 finally:
    225     _cache_info_ctx.reset(tok)

File ~/dev/lightning-thunder/thunder/__init__.py:503, in jit.<locals>.get_computation_and_inputs(*args, **kwargs)
    501 prologue_trc: TraceCtx
    502 computation_trc: TraceCtx
--> 503 jit_results: TraceResults = interpreter(
    504     fn, args, kwargs, record_history=record_history, sharp_edges=cd.sharp_edges
    505 )
    506 prologue_trc = jit_results.prologue_trace
    507 computation_trc = jit_results.computation_trace

File ~/dev/lightning-thunder/thunder/__init__.py:211, in _general_frontend(fn, args, kwargs, record_history, sharp_edges)
    202 def _general_frontend(
    203     fn: Callable,
    204     args: tuple[Any, ...],
   (...)
    209     sharp_edges: SHARP_EDGES_OPTIONS,
    210 ) -> TraceResults:
--> 211     return thunder_general_jit(fn, args, kwargs, sharp_edges=sharp_edges, record_history=record_history)

File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1743, in thunder_general_jit(fn, args, kwargs, record_history, sharp_edges)
   1741 with general_jit_ctx(ctx):
   1742     with tracectx(computation_trace):
-> 1743         result = jfn(*args, **kwargs)
   1744         prims.python_return(result)
   1745         computation_trace.set_current_source_location(None, None)

File ~/dev/lightning-thunder/thunder/core/interpreter.py:6686, in interpret.<locals>.fn_(*args, **kwargs)
   6682     traceback_str = os.linesep.join(f.format_with_source() for f in runtimectx.frame_stack)
   6683     msg = (
   6684         f"Encountered exception {type(e).__name__}: {e} while tracing {fn}:{os.linesep}" f"{traceback_str}"
   6685     )
-> 6686     raise InterpreterError(msg) from e
   6688 # NOTE: Wrapped functions are valid to assign new attributes to.
   6689 fn_._last_interpreter_log = runtimectx.interp_log  # type: ignore

InterpreterError: Encountered exception TypeError: missing a required argument: 'value' while tracing <function fn at 0x7f2e5c9be170>:

I used transformers-4.35.0.

t-vi commented 1 month ago

Note that #461 was about avoiding the infinite recursion. This one seems to come from the dataclass decorator(?) setting __dataclass_params__ on the BaseModelOutput class (to _DataclassParams(init=True,repr=True,eq=True,order=False,unsafe_hash=False,frozen=False)).
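For reference, a minimal sketch of the plain standard-library behavior being traced here (nothing Thunder-specific; the class and field names are made up): the dataclass decorator itself assigns __dataclass_params__ and the generated methods onto the class via attribute assignment, which is exactly the kind of class-level setattr the interpreter ends up seeing.

import dataclasses

@dataclasses.dataclass
class Output:
    x: int

# The decorator stores its configuration as a class attribute:
print(Output.__dataclass_params__)
# _DataclassParams(init=True,repr=True,eq=True,order=False,unsafe_hash=False,frozen=False)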

The trouble very likely stems either from _setattr_lookaside not properly handling assignment to classes and erroneously calling __setattr__ (see the dance that _getattr_lookaside does), or from the unbinding in _call_dispatch going wrong, since the "missing a required argument" could well mean that we are missing "self", the object being assigned to.
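For comparison, this is how plain CPython dispatches the two cases (a minimal illustration, independent of Thunder): assignment on an instance goes through the class's __setattr__, while assignment on the class itself goes through type.__setattr__ and never touches the class's own __setattr__.

class B:
    def __setattr__(self, name, value):
        print("B.__setattr__ called with", name)
        super().__setattr__(name, value)

b = B()
b.x = 1   # looked up on type(b) == B, so B.__setattr__ runs with self=b
B.y = 2   # looked up on type(B) == type, so type.__setattr__ runs; B.__setattr__ is never called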

t-vi commented 1 month ago

Note also that the above (the creation of the BaseModelOutput class) is triggered by the lazy loading of transformers, which leads to the import of the modeling_outputs module being done by the interpreter. One option might be to deliberately not trace through the lazy importing and instead make it opaque.
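As a possible workaround sketch (an assumption, not verified against Thunder): importing the submodule eagerly would mean BaseModelOutput and its dataclass machinery are created before the interpreter runs, so the jitted function only performs attribute lookups.

import torch
import thunder
from transformers.modeling_outputs import BaseModelOutput  # resolves transformers' lazy import up front

def fn(x):
    return BaseModelOutput(x)

jfn = thunder.jit(fn)
print(jfn(torch.randn(5, 5)))  # whether this sidesteps the interpreter failure is untested here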

t-vi commented 1 month ago

Here is an even more minimal repro (the error message differs because a different __setattr__ method is involved):

import thunder

class A:
    pass

def fn(x):
    A.x = x

fn(1)  # works as expected
print(A.x)

jfn = thunder.jit(fn)
jfn(2)  # fails because it calls the `__setattr__` intended for A-objects on the A-class
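The same mistake is easy to reproduce in plain Python (an illustration of the suspected failure mode, not the interpreter's actual call): applying the instance-level __setattr__ to the class object itself is rejected by CPython, whereas the correct dispatch for a class-level assignment is type.__setattr__.

object.__setattr__(A, "x", 3)
# TypeError: can't apply this __setattr__ to type object 'A'

type.__setattr__(A, "x", 3)  # this is what A.x = 3 actually does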