Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0

Dynamic constraints and NumberProxies #262

Open jjsjann123 opened 2 months ago

jjsjann123 commented 2 months ago

🚀 Feature

We'd like thunder.jit to support dynamic constraints so that not every number in a program gets baked in as a compile-time constant. This should allow us to reuse compiled programs and avoid endless recompilation with dynamic shapes.
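
As a rough illustration of the recompilation problem (a minimal sketch; `scale` is a made-up function, and the SYMBOLIC_VALUES behavior shown in the comments is the goal of this issue, not guaranteed current behavior):

import thunder
import torch

def scale(x, alpha):
    return x * alpha

x = torch.randn(4)

# With the default CONSTANT_VALUES cache, alpha is baked into the trace as a
# compile-time constant, so each new value of alpha misses the cache and retraces.
jscale = thunder.jit(scale)
jscale(x, 0.5)
jscale(x, 0.25)  # different scalar -> another specialization

# Goal of this issue: with SYMBOLIC_VALUES, alpha stays a NumberProxy input to the
# compute trace, so one compiled program can be reused across scalar values.
jscale_sym = thunder.jit(scale, cache=thunder.CACHE_OPTIONS.SYMBOLIC_VALUES)
jscale_sym(x, 0.5)
print(thunder.last_traces(jscale_sym)[-1])  # alpha should appear as a proxy, not 0.5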

Pitch

We'd want to:

  1. keep scalar inputs (including tensor properties like shape) as NumberProxy in the trace and feed them into the compute trace as inputs.
  2. have the prologue trace insert checks to ensure program safety (a rough sketch of both traces follows this list).
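
For reference, a hand-written sketch of what the two traces could look like for something like `def mul_by(x, n): return x * n` (the checks shown are illustrative, not actual thunder prims):

# Hypothetical prologue trace: validate the inputs, but only check the type of n,
# not its value, so the compute trace below can be reused for any int n.
def prologue(x, n):
    assert isinstance(n, int)   # illustrative check; real prologues use check prims
    return (x, n)

# Hypothetical compute trace: n flows in as a NumberProxy operand instead of being
# baked in as a constant, so changing its value does not force a recompile.
def computation(x, n):
    t0 = x * n
    return t0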

prototyping PRs:

Issues:

- ~number proxy is no number: #272~
- executor-specific caching rule: #263
- prim should be able to insert constraints to bake in static numbers: #463
- dynamic shape needs to be modeled in trace: #471
- (to be opened) utils.check on NumberProxy needs to be sanitized in prim/clang/torch.
- NumberProxy handling in grad transform: #541 (will be addressed in #244)
- NumberProxy inconsistency is introduced by grad transform: #541

Progress

~NumberProxy inheritance PR merged #286.~

Currently working on enabling the symbolic values cache option to allow traces to handle dynamic scalar inputs (WIP in PR #250). This is going to be the next milestone. A simple prototype is working where a NumberProxy is used to represent a dynamic scalar input to a trace as a number operand.

I tried to test the waters by setting symbolic values as the default cache option and CI exploded. I'm still going through all the failures. Aside from minor code logic patches here and there (since NumberProxy isn't widely used in the existing code base), one of the main challenges I'm seeing right now is the accidental exposure of dynamic shapes (see issues #471 and #463).

Ideally we should resolve #471 and model dynamic shapes in our trace. That's going to be a longer endeavor to pull through, and I should get some help once we decide on a plan and start working on it. Meanwhile, I think in the short term we could push for #463, where we'll just bake in static shapes & constraints for NumberProxy to avoid dynamic shapes arising from dynamic scalar inputs. I need to further evaluate this solution to figure out whether it's enough. Nevertheless, it's still a feature that we might want for the nvFuser integration, i.e. reduction axes need to be baked in anyway.
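
To make the short-term #463 direction concrete, the idea is roughly as follows (hand-written sketch; the checks shown are illustrative, not actual thunder prologue output):

# Hypothetical prologue trace under the #463 approach: the scalar input is pinned to
# the value it was traced with, so the compute trace can treat it as a static number.
def prologue(a, dim0):
    assert isinstance(dim0, int)
    assert dim0 == 2   # constraint baked in; a different dim0 triggers retracing
    return (a, dim0)

# Hypothetical compute trace: dim0 has been folded into the constant 2, which keeps
# downstream shapes static (e.g. reduction axes for the nvFuser integration).
def computation(a, dim0):
    return a[:, :, :2]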

jjsjann123 commented 1 month ago

As I work on my current prototype in #250, I realized that thunder/torch/__init__.py has a few complexities:

  1. value-based control flow (see the sketch at the end of this comment): https://github.com/Lightning-AI/lightning-thunder/blob/82185e3a55d5b3f0bea8a7366d74a275dbe34acd/thunder/torch/__init__.py#L953
  2. there are also semantic checks: https://github.com/Lightning-AI/lightning-thunder/blob/82185e3a55d5b3f0bea8a7366d74a275dbe34acd/thunder/torch/__init__.py#L752-L753

[Action items] We need to allow language-based constraints to be injected and placed in the prologue trace.

Right now some of those checks show up in the compute trace, which could be excessive, but we should be able to const-fold and merge some of them as an optimization later.
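
A generic illustration of the first kind of complexity (made-up function, not the exact code at the linked lines):

import torch

# Value-based control flow on a scalar input: when dim0 is a NumberProxy, the branch
# taken during tracing is only valid for inputs satisfying the same condition, so a
# constraint such as "dim0 >= 0" has to be recorded and checked in the prologue trace.
def pad_or_trim(x, dim0):
    if dim0 >= 0:
        return torch.nn.functional.pad(x, (0, dim0))
    return x[..., :dim0]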

jjsjann123 commented 1 month ago

Unintentionally walked into a dynamic shapes case.

https://github.com/Lightning-AI/lightning-thunder/blob/b7154dc77b5bf8a02bcbc2e56310c62fd75808ff/thunder/core/prims.py#L3102

Operations inside prims are not traced, so we won't be able to compute new_shape the way it's currently done with NumberProxies, because the new number proxy won't be visible in the trace. The shape computation should be done at the clang level instead (a sketch of that direction follows the trace below).

But for now we can just mark everything as constant shapes.

Marking down the concrete example here:

def foo(a, dim0, dim1):
    # there's an assert with this example (see the commented-out code at the end)
    dim2 = a.size(2) - dim0 - dim1
    x = torch.split(a, (dim0, dim1, dim2), dim=2)

    return x
    # code below triggers an assert:
    # dim3 = x[2].size(2) + dim2
    # return x, dim3

In the trace below, you can see that prims.slice_prim gives us output tensors with dynamic shapes. The issue here is that the trace never shows how the new shape (e.g. t18: [8, 16, [IntegerProxy name=i17, value=2]]) is computed; i17 is a magic number that pops up from nowhere.

@torch.no_grad()
@no_autocast
def computation(a, dim0, dim1):
  # a: "cuda:0 f32[8, 16, 32]"
  # dim0: "int 2"
  # dim1: "int 8"
  result = operator.sub(32, dim0)  # result: "int 30"
    # result = prims.sub(32, dim0)  # result: "int 30"
  dim2 = operator.sub(result, dim1)  # dim2: "int 22"
    # dim2 = prims.sub(result, dim1)  # dim2: "int 22"
  del result
  [t18, t33, t45] = nvFusion0(a, dim0, dim1, dim2)
    # i8 = prims.add(0, dim0)  # i8: "int 2"
    # i9 = prims.add(i8, dim1)  # i9: "int 10"
    # t18 = prims.slice_prim(a, [0, 0, 0], [8, 16, i8], [1, 1, 1])  # t18: "cuda:0 f32[8, 16, [IntegerProxy name=i17, value=2]]"
    # t33 = prims.slice_prim(a, [0, 0, i8], [8, 16, i9], [1, 1, 1])  # t33: "cuda:0 f32[8, 16, [IntegerProxy name=i32, value=8]]"
    # t45 = prims.slice_prim(a, [0, 0, i9], [8, 16, 32], [1, 1, 1])  # t45: "cuda:0 f32[8, 16, [IntegerProxy name=i44, value=22]]"
  del a, dim0, dim1, dim2
  return (t18, t33, t45)
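
One possible shape-at-clang direction, sketched by hand (split_via_slices is hypothetical and not thunder's actual clang API):

from thunder.core import prims

# Hypothetical clang-level decomposition of split: compute each output extent with
# traced NumberProxy arithmetic *before* calling the prim, so extents like i17 get a
# visible definition in the trace instead of popping up from nowhere inside slice_prim.
def split_via_slices(a, sizes, dim):
    outputs, start = [], 0
    for size in sizes:             # size may be an IntegerProxy
        stop = start + size        # traced add -> recorded in the trace
        starts = [0] * a.ndim
        stops = list(a.shape)
        starts[dim], stops[dim] = start, stop
        outputs.append(prims.slice_prim(a, starts, stops, [1] * a.ndim))
        start = stop
    return tuple(outputs)
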
jjsjann123 commented 1 month ago

Note for myself.

Some tricky shape usage in the trace that leads to control flow. https://github.com/Lightning-AI/lightning-thunder/pull/260#discussion_r1628259774

kshitij12345 commented 3 weeks ago

Will CACHE_OPTIONS.SYMBOLIC_VALUES be able to deal with the following snippet? If yes, should this be a separate issue for tracking (as it produces incorrect output currently)?

import thunder
import torch

def foo(flag):
    if flag > 5:
        return torch.ones(1)
    return torch.zeros(1)

jfoo = thunder.jit(foo, cache=thunder.CACHE_OPTIONS.SYMBOLIC_VALUES)

# Currently, output is incorrect.
print(jfoo(6))  # tensor([1.])
print(jfoo(0))  # tensor([1.])

# pro_trace = thunder.last_prologue_traces(jfoo)[-1]
# trace = thunder.last_traces(jfoo)[-1]

# print(pro_trace)
# print(trace)

Currently, the prologue trace doesn't have any check on the input. This works fine with the default cache option of CONSTANT_VALUES.
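
For reference, the expectation under SYMBOLIC_VALUES would be for the prologue to record the branch condition the trace was specialized on, roughly (hand-written sketch; the check is illustrative):

# Hypothetical prologue after tracing jfoo(6): the condition flag > 5 that drove the
# control flow is recorded as a constraint, so jfoo(0) fails the check and retraces
# instead of silently reusing the tensor([1.]) branch.
def prologue(flag):
    assert flag > 5, "constraint violated; retrace with the other branch"
    return (flag,)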

jjsjann123 commented 3 weeks ago

Thanks for bringing that up. Yes, it should be able to handle that. I do have another issue tracking this specific problem: #463.

Let me move the example there.