Lightning-AI / lightning-thunder


Calling thunder_general_jit inside lookasides doesn't work #1126

Open IvanYashchuk opened 2 months ago

IvanYashchuk commented 2 months ago

🐛 Bug

I want to convert a Python function that might contain PyTorch calls into a Thunder function inside a lookaside function. I wasn't successful using thunder.core.interpreter.interpret, so I resorted to thunder_general_jit. The inner function interpreted_fn does the correct thing. However, something stands in the way of correct nested usage of thunder_general_jit, and I see the following error:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Cell In[1], line 60
     58 x = torch.randn(3, 4, requires_grad=True)
     59 jf = thunder.jit(f)
---> 60 out = jf(x)

File ~/dev/lightning-thunder/thunder/__init__.py:704, in jit.<locals>.fn_(*args, **kwargs)
    701 cs.last_trace_host_start = time.perf_counter_ns()
    702 cs.calls += 1
--> 704 cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
    705 cs.last_trace_host_execution_start = time.perf_counter_ns()
    707 if cache_entry.vanilla_tensor_args:

File ~/dev/lightning-thunder/thunder/core/langctxs.py:136, in langctx.__call__.<locals>._fn(*args, **kwargs)
    134 try:
    135     tok = set_langctx(self.langctx)
--> 136     result = fn(*args, **kwargs)
    137     return result
    138 finally:

File ~/dev/lightning-thunder/thunder/__init__.py:213, in _with_cache_info_ctx.<locals>.cache_info_wrapper(*args, **kwargs)
    211 tok = _cache_info_ctx.set({})
    212 try:
--> 213     res = fn(*args, **kwargs)
    214 finally:
    215     _cache_info_ctx.reset(tok)

File ~/dev/lightning-thunder/thunder/__init__.py:500, in jit.<locals>.get_computation_and_inputs(*args, **kwargs)
    498 prologue_trc: TraceCtx
    499 computation_trc: TraceCtx
--> 500 jit_results: TraceResults = thunder_general_jit(
    501     fn, args, kwargs, record_history=record_history, sharp_edges=cd.sharp_edges
    502 )
    503 prologue_trc = jit_results.prologue_trace
    504 computation_trc = jit_results.computation_trace

File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1562, in thunder_general_jit(fn, args, kwargs, record_history, sharp_edges)
   1559 else:
   1560     epilogue_trace = None
-> 1562 pro_to_comp_proxies, pro_to_epi_proxies = unpack_inputs(ctx, prologue_trace, pro_to_comp, pro_to_epi, args, kwargs)
   1564 proxy_order = {id(p): i for i, p in enumerate(pro_to_comp_proxies)}
   1565 pro_to_comp = tuple(sorted(pro_to_comp, key=lambda v: proxy_order[id(v.proxy)]))

File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1367, in unpack_inputs(ctx, prologue_trace, pro_to_comp_inps, pro_to_epi_inps, args, kwargs)
   1365 print(f"pro_to_comp_inps: {pro_to_comp_inps}")
   1366 pro_to_epi = tuple(sorted((unpack(v) for v in pro_to_epi_inps), key=lambda x: param_ordering[id(x)][1]))
-> 1367 pro_to_comp = tuple(sorted((unpack(v) for v in pro_to_comp_inps), key=lambda x: param_ordering[id(x)][1]))
   1369 with tracectx(prologue_trace):
   1370     for prim, *args in ctx._constraints:

File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1367, in unpack_inputs.<locals>.<lambda>(x)
   1365 print(f"pro_to_comp_inps: {pro_to_comp_inps}")
   1366 pro_to_epi = tuple(sorted((unpack(v) for v in pro_to_epi_inps), key=lambda x: param_ordering[id(x)][1]))
-> 1367 pro_to_comp = tuple(sorted((unpack(v) for v in pro_to_comp_inps), key=lambda x: param_ordering[id(x)][1]))
   1369 with tracectx(prologue_trace):
   1370     for prim, *args in ctx._constraints:

KeyError: 140260996191264

Script to reproduce it:

from thunder.core.jit_ext import (
    compile_data_and_stats,
    CompileData,
    get_compile_data,
    interpreter_needs_wrap,
    SHARP_EDGES_OPTIONS,
    thunder_general_jit,
    unwrap,
    register_general_jit_lookaside,
    TraceResults,
    wrap_const,
)
import torch
import thunder

def my_call(fn, *args, **kwargs):
    return fn(*args, **kwargs)

@register_general_jit_lookaside(my_call)
def _lookaside(
    fn,
    *args,
    **kwargs,
):
    # Translate possibly PyTorch function into Thunder function
    def interpreted_fn(*args, **kwargs):
        # NOTE: Using thunder.jit with get_computation_and_inputs instead of
        # thunder_general_jit results in the same error
        unwrapped_function = unwrap(fn)
        cd = CompileData(
            fn=unwrapped_function,
            disable_preprocessing=True,
            executor_lookasides=get_compile_data().executor_lookasides,
        )
        with compile_data_and_stats(cd, None):
            jit_results: TraceResults = thunder_general_jit(
                unwrapped_function,
                args,
                kwargs,
                sharp_edges=SHARP_EDGES_OPTIONS.ALLOW,
            )
            # inps, pro_to_epi = jit_results.prologue_trace.python_callable()(*args, **kwargs)
            # result = jit_results.computation_trace.python_callable()(*inps)
            result = jit_results.computation_trace.python_callable()(*args)
            return result

    wrapped_thunder_function = wrap_const(interpreted_fn)
    result = interpreter_needs_wrap(my_call)(
        wrapped_thunder_function, *args, **kwargs
    )
    return result

def f(x):
    return my_call(lambda x: torch.sin(torch.cos(x)), x)

x = torch.randn(3, 4, requires_grad=True)
jf = thunder.jit(f)
out = jf(x)

In general, support for nested JIT tracing for higher-order operations is discussed in https://github.com/Lightning-AI/lightning-thunder/issues/1134.

cc @t-vi

nvMelissa commented 1 month ago

@IvanYashchuk - where do you hit this issue?

crcrpar commented 1 month ago

A tidy reproduction script is shared in the description. I confirmed that the error can be reproduced (most likely with a slightly different KeyError value, since the failing key is an object id).

IvanYashchuk commented 1 month ago

@IvanYashchuk - where do you hit this issue?

I discovered this bug when trying to support PyTorch's and Dynamo's activation checkpointing implementation in https://github.com/Lightning-AI/lightning-thunder/pull/1127. Currently that PR works only for simple functions that have exactly the same implementation in PyTorch and Thunder (for example a.cos() + a.exp()). Fixing this bug would enable supporting any PyTorch function.
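
For context, here is a minimal sketch of the scenario that PR handles today (assuming the checkpointing support from that PR; the function name checkpointed below is only illustrative):

from torch.utils.checkpoint import checkpoint
import torch
import thunder

def checkpointed(x):
    # The checkpointed callee uses only ops whose PyTorch and Thunder
    # implementations coincide, e.g. a.cos() + a.exp().
    return checkpoint(lambda a: a.cos() + a.exp(), x, use_reentrant=False)

x = torch.randn(3, 4, requires_grad=True)
jf = thunder.jit(checkpointed)
out = jf(x)

With the present issue fixed, the callee passed to checkpoint could be an arbitrary PyTorch function instead of being restricted to this op-for-op overlap.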

nvMelissa commented 1 month ago

@t-vi will open issues for detecting and raising a meaningful error message first.
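
A hypothetical illustration of what such a guard could look like (the flag name and placement are assumptions for illustration, not existing Thunder code):

_general_jit_active = False  # hypothetical module-level flag in thunder/core/jit_ext.py

def thunder_general_jit(fn, args, kwargs, *, record_history=False, sharp_edges=None):
    global _general_jit_active
    if _general_jit_active:
        raise NotImplementedError(
            "thunder_general_jit was called while another general-jit trace is active "
            "(e.g. from inside a lookaside); nested tracing is not supported yet, see #1126."
        )
    _general_jit_active = True
    try:
        ...  # existing tracing logic unchanged
    finally:
        _general_jit_active = False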

t-vi commented 1 month ago

Related:

Issues for the steps:

After #1220 is solved, we could use the present issue to track the remainder of the work. Inside the tracing, it is not unlikely that some form of _interpret_call can help you; see the torch.autograd.Function lookaside in jit_ext.py for an advanced example.
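
A minimal, untested sketch of that direction, reusing my_call from the reproduction script above (assuming _interpret_call from thunder.core.interpreter accepts the wrapped callable and arguments the lookaside receives):

from thunder.core.interpreter import _interpret_call
from thunder.core.jit_ext import register_general_jit_lookaside

def my_call(fn, *args, **kwargs):
    return fn(*args, **kwargs)

@register_general_jit_lookaside(my_call)
def _my_call_lookaside(fn, *args, **kwargs):
    # Let the already-running interpreter trace the callee in place instead of
    # starting a second, nested thunder_general_jit session.
    return _interpret_call(fn, *args, **kwargs)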