Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source to source compiler for PyTorch. It enables using different hardware executors at once; across one or thousands of GPUs.
Apache License 2.0

Constraints to insert static numbers #463

Open jjsjann123 opened 1 month ago

jjsjann123 commented 1 month ago

🚀 Feature

Prim operations should be able to insert constraints on the prologue trace.

Motivation

Issue first brought up here: https://github.com/Lightning-AI/lightning-thunder/issues/262#issuecomment-2127729530

  1. Problem arising from slice -> dynamic shape support?

Looking at the computation trace here for a slice_prim:

    # t18 = prims.slice_prim(a, [0, 0, 0], [8, 16, i8], [1, 1, 1])  # t18: "cuda:0 f32[8, 16, [IntegerProxy name=i17, value=2]]"

We notice that the output TensorProxy has shape [8, 16, [IntegerProxy name=i17, value=2]], but no operation in the trace produces the last dimension, i17. This is because the prim computes the output shape with arithmetic inside its meta function, as defined here: https://github.com/Lightning-AI/lightning-thunder/blob/b7154dc77b5bf8a02bcbc2e56310c62fd75808ff/thunder/core/prims.py#L3102

Ideally, to support dynamic shapes, we would expose that logic in the trace by lifting the shape computation out of the meta function and into clang. (See note 1 in Alternatives.)
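A toy sketch (plain Python, no Thunder) of what lifting the shape logic buys. If the slice-length arithmetic runs on a symbolic integer that records its operations, the output dimension gets a producer in the trace, unlike i17 above. Trace, ToyInt, and slice_dim_length are hypothetical names. The length formula assumes 0 <= start <= stop and a positive step. This is not Thunder's meta function.

```python
# Hypothetical toy tracer: Trace and ToyInt are illustrative names, not Thunder APIs.
class Trace:
    def __init__(self):
        self.lines = []  # recorded operations, one string per line
        self._count = 0

    def fresh(self):
        self._count += 1
        return f"i{self._count}"

    def input_int(self, value):
        # A symbolic integer input of the trace (like i8 in the slice example).
        return ToyInt(self, self.fresh(), value)


class ToyInt:
    """Symbolic integer that records each arithmetic op instead of hiding it."""

    def __init__(self, trace, name, value):
        self.trace, self.name, self.value = trace, name, value

    def _binop(self, other, symbol, fn):
        is_sym = isinstance(other, ToyInt)
        other_name = other.name if is_sym else repr(other)
        other_value = other.value if is_sym else other
        out = ToyInt(self.trace, self.trace.fresh(), fn(self.value, other_value))
        self.trace.lines.append(f"{out.name} = {self.name} {symbol} {other_name}")
        return out

    def __add__(self, other):
        return self._binop(other, "+", lambda a, b: a + b)

    def __sub__(self, other):
        return self._binop(other, "-", lambda a, b: a - b)

    def __floordiv__(self, other):
        return self._binop(other, "//", lambda a, b: a // b)


def slice_dim_length(start, stop, step):
    # Length of one sliced dimension, assuming 0 <= start <= stop and step > 0.
    # When `stop` is symbolic, every intermediate becomes a named value in the trace.
    return (stop - start + (step - 1)) // step


trace = Trace()
stop = trace.input_int(2)
length = slice_dim_length(0, stop, 1)
print("\n".join(trace.lines))
print(length.name, length.value)
```

Here the final length is an explicit trace value (i4) with a visible chain of producers. That is the property the trace needs for dynamic shapes.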

  2. Operations that would need to insert static constraints

Operations should be able to insert constraints by marking certain NumberProxy/TensorProxy dimensions as statically shaped. For example, the axes argument of a reduction could be a dynamic NumberProxy that we would rather bake in as a static value.

This would also give us a temporary workaround for prim operations that don't yet have dynamic shape support.

In the slice_prim example above, we could mark i8 as static and add the constraint i8 == 2. This would let slice work without proper dynamic shape support, which would help keep #262 moving forward.
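A minimal sketch of the "mark as static" idea for the slice case. It assumes a hypothetical `static` flag on a toy proxy object; Thunder's NumberProxy is not claimed to have such a field. Baking i8 in returns a plain int for the prim to use and records the equality that the prologue would later need to check.

```python
from dataclasses import dataclass, field


@dataclass
class ToyNumberProxy:
    """Hypothetical stand-in for a NumberProxy such as i8."""
    name: str
    value: int
    static: bool = False  # hypothetical flag, not an existing Thunder attribute


@dataclass
class StaticConstraints:
    """Collects the (name, value) pairs a prologue would need to check."""
    checks: list = field(default_factory=list)

    def bake_in(self, proxy):
        # Mark the proxy static once and remember the equality it must satisfy.
        if not proxy.static:
            proxy.static = True
            self.checks.append((proxy.name, proxy.value))
        # Downstream code (e.g. a prim without dynamic-shape support) sees a plain int.
        return proxy.value

    def holds_for(self, inputs):
        # In the spirit of check_number_type_and_value: type and value must both match.
        return all(
            type(inputs[name]) is type(value) and inputs[name] == value
            for name, value in self.checks
        )


constraints = StaticConstraints()
i8 = ToyNumberProxy("i8", 2)
stop = constraints.bake_in(i8)
print(list(range(16))[0:stop], constraints.checks)
print(constraints.holds_for({"i8": 2}), constraints.holds_for({"i8": 3}))
```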

Pitches

We can use the check_number_type_and_value prim for the check. The challenge is propagating the right check into the prologue trace.

Taking the slice example above, the operation would want to emit check_number_type_and_value(i8, pyval(i8)). But there's no guarantee that this is a legal operation in the prologue trace, since it requires i8 to be a direct input. So we would want to project the static constraint from i8 onto its producers, which could be NumberProxys as well as TensorProxy attributes.
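One conservative way to do that projection, sketched on a toy producer graph. Node, project_to_inputs, and render_prologue are hypothetical names. The sketch walks back from i8 to the values the prologue can see directly and pins each of them to its traced value. The rendered lines only mimic the shape of check_number_type_and_value calls; they are strings, not calls into Thunder.

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    """Hypothetical traced value; `inputs` lists the values it was computed from."""
    name: str
    value: object
    inputs: list = field(default_factory=list)  # empty means the prologue sees it directly


def project_to_inputs(node):
    """Walk producers back to prologue-visible values and pin each to its traced value."""
    checks, seen, stack = {}, set(), [node]
    while stack:
        current = stack.pop()
        if id(current) in seen:
            continue
        seen.add(id(current))
        if current.inputs:
            stack.extend(current.inputs)
        else:
            checks[current.name] = current.value
    return checks


def render_prologue(checks):
    # Text shaped like the prim named in the issue; these are strings, not Thunder calls.
    return [f"check_number_type_and_value({name}, {value!r})" for name, value in sorted(checks.items())]


# Hypothetical producer graph: i8 = min(i5, t0.shape[2]) with i5 == 2 and t0.shape[2] == 4.
i5 = Node("i5", 2)
t0_dim2 = Node("t0.shape[2]", 4)
i8 = Node("i8", 2, inputs=[i5, t0_dim2])
for line in render_prologue(project_to_inputs(i8)):
    print(line)
```

Pinning every input is sound as long as the intermediate ops are deterministic, but it over-constrains. In this example t0.shape[2] is pinned even though i8 only depends on it through a min. A finer projection would have to reason about each producing op.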

Alternatives

note 1: (Alternatively, we could add logic to trace through prim meta functions.)

jjsjann123 commented 3 weeks ago

Example moved from https://github.com/Lightning-AI/lightning-thunder/issues/262#issuecomment-2160763593

import thunder
import torch

def foo(flag):
    if flag > 5:
        return torch.ones(1)
    return torch.zeros(1)

jfoo = thunder.jit(foo, cache=thunder.CACHE_OPTIONS.SYMBOLIC_VALUES)

# Currently, the output is incorrect: the second call reuses the trace for flag > 5.
print(jfoo(6))  # tensor([1.])
print(jfoo(0))  # tensor([1.]), expected tensor([0.])

# pro_trace = thunder.last_prologue_traces(jfoo)[-1]
# trace = thunder.last_traces(jfoo)[-1]

# print(pro_trace)
# print(trace)

We need to bake NumberProxy values used in control flow into the prologue trace as static values.
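The toy below mimics both the failure and the proposed fix without Thunder or PyTorch. SymInt and toy_jit are hypothetical stand-ins for an IntegerProxy and a symbolic-values cache. Once a number is used in a Python branch, its exact value is pinned. The "prologue" then rejects inputs that do not match, so a new trace is made instead of reusing the wrong branch.

```python
class SymInt:
    """Hypothetical stand-in for an IntegerProxy used while tracing."""

    def __init__(self, name, value, pins):
        self.name, self.value, self._pins = name, value, pins

    def __gt__(self, other):
        # The result feeds Python control flow, so bake this number in as static.
        self._pins[self.name] = self.value
        return self.value > other


def toy_jit(fn):
    """Hypothetical symbolic-values cache with a 'prologue' that checks pinned inputs."""
    cache = []  # list of (pinned input values, traced result)

    def wrapper(flag):
        inputs = {"flag": flag}
        for pins, result in cache:
            # Prologue: reuse an entry only if every pinned input still matches.
            if all(inputs[name] == value for name, value in pins.items()):
                return result
        pins = {}
        result = fn(SymInt("flag", flag, pins))
        cache.append((pins, result))
        return result

    wrapper.cache = cache
    return wrapper


def foo(flag):
    if flag > 5:
        return "ones"
    return "zeros"


jfoo = toy_jit(foo)
print(jfoo(6), jfoo(0), jfoo(6), len(jfoo.cache))
```

Pinning the exact value has a cost: in this toy, flag=7 takes the same branch as flag=6 but still misses the cache and triggers a retrace.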

jjsjann123 commented 3 weeks ago

Another example coming from https://github.com/Lightning-AI/lightning-thunder/pull/575#issuecomment-2161257731

import thunder
import torch

def foo(dev, idx):
    return torch.ones(1, device=torch.device(dev, idx))

# jfoo = thunder.jit(foo)  # works
jfoo = thunder.jit(foo, cache=thunder.CACHE_OPTIONS.SYMBOLIC_VALUES)
print(jfoo("cuda", 0))  # fails: the current implementation can't handle an `IntegerProxy` here.
print(jfoo("cuda", 1))

# print(thunder.last_prologue_traces(jfoo)[-1])
# print(thunder.last_traces(jfoo)[-1])

jjsjann123 commented 3 weeks ago

Brainstorming notes to myself:

I'm thinking about adding a flag to NumberProxy to mark it as a static constraint.

For control flow, we need to insert those constraints on proxies as we encounter them. Take the example above:

if flag > 5:

which translates to bytecode like

  5           0 LOAD_FAST                0 (flag)
              2 LOAD_CONST               1 (5)
              4 COMPARE_OP               4 (>)
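A listing like the one above can be reproduced with the standard-library dis module. Opcode names, arguments, and offsets differ between CPython versions, so the exact output is version-dependent; this version of foo returns plain ints to avoid the torch dependency.

```python
import dis


def foo(flag):
    if flag > 5:
        return 1
    return 0


# Disassemble foo; the comparison shows up as a COMPARE_OP instruction.
dis.dis(foo)

compare_ops = [ins for ins in dis.get_instructions(foo) if ins.opname == "COMPARE_OP"]
print(compare_ops[0].argrepr)
```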

https://github.com/Lightning-AI/lightning-thunder/blob/main/thunder/core/interpreter.py#L5082

There, we need to pop the operands off the stack, check whether each one is a proxy, and set its static_constraint flag before converting it to a Python value that gets baked into the control flow.
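A sketch of that handler on a toy value stack. It assumes a hypothetical static_constraint field on the proxy and is not Thunder's interpreter code; the real COMPARE_OP handling lives at the link above.

```python
import operator
from dataclasses import dataclass

# Comparison symbols as they appear in COMPARE_OP, mapped to Python operators.
COMPARE = {
    "<": operator.lt, "<=": operator.le, "==": operator.eq,
    "!=": operator.ne, ">": operator.gt, ">=": operator.ge,
}


@dataclass
class ToyProxy:
    """Hypothetical proxy carrying the static_constraint flag from the brainstorm."""
    name: str
    value: object
    static_constraint: bool = False


def unwrap_and_mark(x):
    # A proxy that reaches control flow gets flagged, then replaced by its concrete value.
    if isinstance(x, ToyProxy):
        x.static_constraint = True
        return x.value
    return x


def compare_op_handler(stack, op):
    # Like COMPARE_OP: the right operand is on top of the stack, the left one below it.
    rhs = unwrap_and_mark(stack.pop())
    lhs = unwrap_and_mark(stack.pop())
    stack.append(COMPARE[op](lhs, rhs))


flag = ToyProxy("flag", 6)
stack = [flag, 5]  # after LOAD_FAST flag; LOAD_CONST 5
compare_op_handler(stack, ">")
print(stack, flag.static_constraint)
```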

t-vi commented 3 weeks ago

So our idea was to make "evaluate to bool" the point at which we record the constraint: for this example, we would check whether flag > 5 and record either flag > 5 or not (flag > 5) as the constraint. (Unfortunately, not (flag > 5) is true more often than flag <= 5 because of NaN, and the comparison might be overridden.) Note that we absolutely need to implement mapping comparison ops to __lt__ and friends so that we can jit through them. (And then you can use lookasides.)
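A sketch of recording the constraint at bool-evaluation time. SymNumber, SymBool, and prologue_ok are hypothetical names. Each record stores the comparison and whether it held, instead of rewriting not (flag > 5) as flag <= 5, so a NaN input is judged the same way the traced branch was.

```python
import math
import operator

# Comparison symbols this toy supports, mapped to their Python operators.
OPS = {">": operator.gt, "<": operator.lt}


class SymBool:
    """Hypothetical symbolic comparison result; records a constraint when turned into a bool."""

    def __init__(self, name, op, other, value, records):
        self.key = (name, op, other)
        self.value = value
        self._records = records

    def __bool__(self):
        # Store the comparison and its outcome rather than flipping the operator,
        # because `not (x > 5)` and `x <= 5` disagree when x is NaN.
        self._records.append((*self.key, self.value))
        return self.value


class SymNumber:
    """Hypothetical stand-in for a NumberProxy during tracing."""

    def __init__(self, name, value, records):
        self.name, self.value, self._records = name, value, records

    def _compare(self, op, other):
        return SymBool(self.name, op, other, OPS[op](self.value, other), self._records)

    def __gt__(self, other):
        return self._compare(">", other)

    def __lt__(self, other):
        return self._compare("<", other)


def prologue_ok(records, inputs):
    """Re-run every recorded comparison on new inputs; each outcome must match the trace."""
    return all(OPS[op](inputs[name], other) == held for name, op, other, held in records)


records = []
flag = SymNumber("flag", 0, records)
branch = "ones" if flag > 5 else "zeros"  # evaluating the condition records the constraint
print(branch, records)
print(prologue_ok(records, {"flag": 3}), prologue_ok(records, {"flag": 6}))
print(not (math.nan > 5), math.nan <= 5)
```

Compared with pinning flag to its exact value, the recorded predicate lets flag=3 reuse the trace while flag=6 fails the check.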

jjsjann123 commented 3 weeks ago

Started #586 to fix the control flow issue; it seems to get this example working now: https://github.com/Lightning-AI/lightning-thunder/issues/463#issuecomment-2161242845

jjsjann123 commented 1 week ago

Linking comment: https://github.com/Lightning-AI/lightning-thunder/pull/451#issuecomment-2186631228

I think this is a case where we should temporarily fix it to a static shape.