Open IvanYashchuk opened 2 months ago
@IvanYashchuk - where do you hit this issue?
A tidy reproducible script is shared in the description. I confirmed that the error can be reproduced, though typically with a slightly different KeyError message.
> @IvanYashchuk - where do you hit this issue?
I discovered this bug when trying to support PyTorch's and Dynamo's activation checkpointing implementation in https://github.com/Lightning-AI/lightning-thunder/pull/1127. Currently that PR works only for simple functions that have exactly the same implementation in PyTorch and Thunder (for example `a.cos() + a.exp()`). Fixing this bug would enable supporting any PyTorch function.
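For context, a conceptual sketch of what activation checkpointing buys you (this is an illustration of the idea only, not Thunder's or PyTorch's implementation; all names below are hypothetical): instead of storing every intermediate activation for the backward pass, a checkpointed segment stores only its inputs and recomputes the intermediates when they are needed.

```python
# Conceptual sketch of activation checkpointing (illustrative names,
# not Thunder's or PyTorch's API).

def forward_segment(x):
    # Hypothetical two-step computation whose intermediates we would
    # normally keep alive for the backward pass.
    a = x * 2      # intermediate activation
    b = a + 3      # segment output
    return b

class Checkpoint:
    """Stores only the segment input; recomputes intermediates on demand."""

    def __init__(self, fn, x):
        self.fn = fn
        self.saved_input = x          # O(1) memory: just the input
        self.output = fn(x)           # intermediates are discarded here

    def recompute(self):
        # Rerun the segment to regenerate intermediates for backward.
        return self.fn(self.saved_input)

cp = Checkpoint(forward_segment, 5)
assert cp.output == 13
assert cp.recompute() == 13   # recomputation matches the original forward
```

The trade-off is the usual one: peak memory drops because intermediates are freed after the forward pass, at the cost of rerunning the segment during backward.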
@t-vi will open issues for detecting this case and raising a meaningful error message first.
Related:
Issues for the steps:
After #1220 is solved, we could use the present issue to track the remainder of the work.
Inside the tracing, it is not unlikely that some form of `_interpret_call` can help you; see the `torch.autograd.Function` lookaside in JIT-ext for an advanced example.
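To make the lookaside suggestion above concrete, here is a generic sketch of the pattern (the names `LOOKASIDES`, `register_lookaside`, and `interpret_call` are illustrative, not Thunder's actual API): the interpreter consults a table of replacement implementations before executing a call, and a lookaside may itself re-enter the interpreter on an inner function, which is exactly the nested-interpretation step this issue is about.

```python
# Generic sketch of the "lookaside" pattern. All names are hypothetical;
# Thunder's real machinery lives in thunder.core.jit_ext.

LOOKASIDES = {}

def register_lookaside(target):
    """Register a replacement implementation for `target`."""
    def deco(replacement):
        LOOKASIDES[target] = replacement
        return replacement
    return deco

def interpret_call(fn, *args):
    # Consult the lookaside table first; otherwise call through directly.
    replacement = LOOKASIDES.get(fn)
    if replacement is not None:
        return replacement(*args)
    return fn(*args)

def checkpoint(fn, x):
    # Stand-in for a higher-order op like activation checkpointing.
    return fn(x)

@register_lookaside(checkpoint)
def checkpoint_lookaside(fn, x):
    # Re-enter the interpreter on the inner function: this nested
    # interpretation is what currently fails with thunder_general_jit.
    return interpret_call(fn, x)

assert interpret_call(checkpoint, lambda v: v + 1, 41) == 42
```

The key design point is that the lookaside for a higher-order op cannot just call the inner function directly; it has to hand it back to the interpreter so that the inner PyTorch calls are traced too.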
🐛 Bug
I want to convert a Python function that might contain PyTorch calls into a Thunder function inside the lookaside function. I wasn't successful at using `thunder.core.interpreter.interpret`, so I resorted to `thunder_general_jit`. The inner function `interpreted_fn` does the correct thing. However, something stands in the way of correct nested usage of `thunder_general_jit`, and I see the following error. Script to reproduce it:
In general, support for nested JIT tracing of higher-order operations is discussed in https://github.com/Lightning-AI/lightning-thunder/issues/1134.
cc @t-vi