Open jjsjann123 opened 1 month ago
Example moved from https://github.com/Lightning-AI/lightning-thunder/issues/262#issuecomment-2160763593
import thunder
import torch
def foo(flag):
    if flag > 5:
        return torch.ones(1)
    return torch.zeros(1)
jfoo = thunder.jit(foo, cache=thunder.CACHE_OPTIONS.SYMBOLIC_VALUES)
# Currently, output is incorrect.
print(jfoo(6)) # tensor([1.])
print(jfoo(0)) # tensor([1.]), but tensor([0.]) is expected
# pro_trace = thunder.last_prologue_traces(jfoo)[-1]
# trace = thunder.last_traces(jfoo)[-1]
# print(pro_trace)
# print(trace)
We need to bake in any NumberProxy used in control flow as a static value in the prologue trace.
Another example coming from https://github.com/Lightning-AI/lightning-thunder/pull/575#issuecomment-2161257731
import thunder
import torch
def foo(dev, idx):
    return torch.ones(1, device=torch.device(dev, idx))
# jfoo = thunder.jit(foo) # works
jfoo = thunder.jit(foo, cache=thunder.CACHE_OPTIONS.SYMBOLIC_VALUES)
print(jfoo("cuda", 0)) # fails as our current implementation can't handle `IntegerProxy`.
print(jfoo("cuda", 1))
# print(thunder.last_prologue_traces(jfoo)[-1])
# print(thunder.last_traces(jfoo)[-1])
Brainstorming to myself:
I'm thinking about adding a flag to NumberProxy to mark it as static_constraint.
That is, for control flow, we need to insert those constraints on proxies as we see fit. Take the example above with
if flag > 5:
which translates to something like
5 0 LOAD_FAST 0 (flag)
2 LOAD_CONST 1 (5)
4 COMPARE_OP 4 (>)
https://github.com/Lightning-AI/lightning-thunder/blob/main/thunder/core/interpreter.py#L5082
We need to pop the stack, inspect it for a proxy, and set the proxy's static_constraint flag before wrapping it as a Python value to bake in the control flow.
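A minimal sketch of that idea in plain Python, assuming a toy proxy and a toy single-argument cache. `ToyNumberProxy`, `toy_jit`, and the guard layout are hypothetical names made up for illustration and are not Thunder's implementation; the cached "result" stands in for a compiled trace.

```python
class ToyNumberProxy:
    """Hypothetical stand-in for a NumberProxy; not Thunder's actual class."""

    def __init__(self, name, value):
        self.name = name
        self.value = value
        # Flipped to True once control flow depends on the concrete value.
        self.static_constraint = False

    def __gt__(self, other):
        # A plain bool leaves here and decides the branch, so the proxy is
        # marked static before its Python value escapes.
        self.static_constraint = True
        return self.value > other


def toy_jit(fn):
    """Toy cache for a one-argument function (assumption: a single number input)."""
    cache = []  # entries of (guarded_value_or_None, result)

    def wrapper(flag):
        for guarded, result in cache:
            # None means "no constraint": the entry is valid for any input.
            if guarded is None or guarded == flag:
                return result
        proxy = ToyNumberProxy("flag", flag)
        result = fn(proxy)
        # Guard on the concrete value only if the proxy was marked static.
        cache.append((flag if proxy.static_constraint else None, result))
        return result

    wrapper.cache = cache
    return wrapper


def foo(flag):
    if flag > 5:
        return "ones"
    return "zeros"


jfoo = toy_jit(foo)
print(jfoo(6))  # ones
print(jfoo(0))  # zeros
```

This is the coarse version: the guard pins the exact value, so `jfoo(7)` would also retrace even though it takes the same branch as `jfoo(6)`.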
So our idea was to make "evaluate to bool" the point at which we record the constraint, i.e. here we would check whether flag > 5 holds and record either flag > 5 or not (flag > 5) as a constraint. (Unfortunately, not (flag > 5) is true more often than flag <= 5 due to NaN, and the comparison might be overridden.)
Note that we absolutely need to implement mapping comparison ops to __lt__ and friends to jit through them. (And then you can use lookasides.)
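To illustrate recording the constraint at bool-evaluation time, here is a small plain-Python sketch. `ToyLazyProxy`, `ToyComparison`, and `recorded_constraints` are hypothetical names for illustration only and do not reflect Thunder's classes. The last two lines show the NaN caveat with ordinary floats.

```python
recorded_constraints = []  # hypothetical constraint log, for illustration only


class ToyComparison:
    """Lazy result of `proxy > bound`; nothing is recorded until bool() runs."""

    def __init__(self, name, value, bound):
        self.name = name
        self.value = value
        self.bound = bound

    def __bool__(self):
        outcome = self.value > self.bound
        text = f"{self.name} > {self.bound}"
        # Record the comparison itself or its negation, depending on the outcome.
        recorded_constraints.append(text if outcome else f"not ({text})")
        return outcome


class ToyLazyProxy:
    """Hypothetical number proxy whose comparison defers the constraint."""

    def __init__(self, name, value):
        self.name = name
        self.value = value

    def __gt__(self, bound):
        return ToyComparison(self.name, self.value, bound)


def foo(flag):
    if flag > 5:  # `if` calls bool() on the ToyComparison
        return "ones"
    return "zeros"


foo(ToyLazyProxy("flag", 6))
foo(ToyLazyProxy("flag", 0))
print(recorded_constraints)  # ['flag > 5', 'not (flag > 5)']

nan = float("nan")
print(not (nan > 5), nan <= 5)  # True False
```

Recording `not (flag > 5)` rather than `flag <= 5` keeps the constraint faithful to what was actually evaluated, which matters for NaN as the last line shows.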
Started #586 to fix the control flow stuff. It seems to have this example working now. https://github.com/Lightning-AI/lightning-thunder/issues/463#issuecomment-2161242845
Linking comment: https://github.com/Lightning-AI/lightning-thunder/pull/451#issuecomment-2186631228 I think this is a case where we should temporarily fix it to a static shape.
🚀 Feature
Prim operations should be able to insert constraints on the prologue trace.
Motivation
Issue first brought up here: https://github.com/Lightning-AI/lightning-thunder/issues/262#issuecomment-2127729530
Looking at the computation trace here for a slice_prim: we notice that the output TensorProxy has a shape of [8, 16, [IntegerProxy name=i17, value=2]], where the last dimension, with size i17, is not produced in the trace. This is because the prim has arithmetic operations defined inside its meta function, as defined here: https://github.com/Lightning-AI/lightning-thunder/blob/b7154dc77b5bf8a02bcbc2e56310c62fd75808ff/thunder/core/prims.py#L3102
Ideally we would want to expose that logic in the trace in order to support dynamic shapes, so we should lift the shape logic to clang instead of meta. (See note 1 in Alternatives.)
Operations should be able to insert constraints by marking certain NumberProxy/TensorProxy dimensions as static shaped. E.g. the axes args for a reduction operation could be a dynamic NumberProxy, which we would instead want to bake in as static.
This would also allow us to temporarily work around issues with prim operations for which we have not yet added dynamic shape support.
In the slice_prim example above, we could have marked i8 as static and added the constraint i8 == 2. This would allow us to have slice working without proper dynamic shape support, which would be nice to keep #262 moving forward.
Pitches
We can use the check_number_type_and_value prim for the check. The challenge here is to propagate the proper check into the prologue trace.
Taking the slice example above, the operation would want to check_number_type_and_value(i8, pyval(i8)). But there's no guarantee that this is a legitimate operation in the prologue trace, since that would require i8 to be a direct input. So we would want to project static_constraints from i8 to its producers, which could be NumberProxy as well as TensorProxy's attributes.
Alternatives
Note 1: Alternatively, we can think about having logic to trace through prim_meta functions.
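The projection of static constraints from a derived proxy to its producers, as described in the pitch, could look roughly like the plain-Python sketch below. Everything in it is an assumption for illustration: the producer map, the proxy names other than i8, the example arithmetic, and `toy_check_number_type_and_value`, which is a toy stand-in and does not claim to match the signature of Thunder's prim.

```python
# Hypothetical producer map: each derived proxy lists the proxies it was
# computed from. Names missing from the map are direct prologue inputs,
# e.g. a number argument or a dimension of an input tensor's shape.
producers = {
    "i8": ["i5", "start"],          # e.g. i8 = i5 - start
    "i5": ["t0.shape[2]", "end"],   # e.g. i5 = min(t0.shape[2], end)
}


def project_to_inputs(name, producers):
    """Walk back from a proxy to the direct inputs it depends on."""
    if name not in producers:
        return {name}
    leaves = set()
    for parent in producers[name]:
        leaves |= project_to_inputs(parent, producers)
    return leaves


def toy_check_number_type_and_value(actual, expected):
    """Toy stand-in for the check; raises if type or value differs."""
    if type(actual) is not type(expected) or actual != expected:
        raise ValueError(f"expected {expected!r}, got {actual!r}")


def make_prologue(static_values):
    """Build a toy prologue that guards every projected input."""

    def prologue(inputs):
        for name, expected in static_values.items():
            toy_check_number_type_and_value(inputs[name], expected)

    return prologue


# Concrete values seen while tracing (made up so that i8 == 2).
traced_inputs = {"t0.shape[2]": 4, "end": 3, "start": 1}

static_leaves = project_to_inputs("i8", producers)
prologue = make_prologue({name: traced_inputs[name] for name in static_leaves})

prologue(traced_inputs)  # passes: same values as at trace time
try:
    prologue({"t0.shape[2]": 4, "end": 3, "start": 0})
except ValueError as err:
    print("retrace needed:", err)
```

Guarding every leaf is stricter than guarding i8 == 2 itself, since different inputs could produce the same i8, but it only needs checks on values the prologue actually receives.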