Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0

math.xxx calls in a function on NumberProxy are not being traced. #526

Open jjsjann123 opened 5 months ago

jjsjann123 commented 5 months ago

🐛 Bug

As the title suggests, with the symbolic values cache, NumberProxy operations performed via math.xxx calls are not traced.

To Reproduce

import torch
import thunder
import math

def foo(a, b):
    return math.fmod(a, b)

a = 1.2
b = 2.0

jfoo = thunder.jit(foo, cache="symbolic values")
out = jfoo(a, b)

print("\n\tcompute last trace:\n", thunder.last_traces(jfoo)[-1])

Looking at the trace below, prims.fmod does not appear anywhere in the trace; the result is simply baked in as a constant, which is wrong.

        compute last trace:
 # Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation():
  return 1.2
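For contrast, a rough sketch of the kind of trace one would expect with symbolic values: the numbers stay NumberProxys and the fmod shows up as a prim (illustrative only, not actual Thunder output; proxy names are made up):

@torch.no_grad()
@no_autocast
def computation(f0, f1):
  # f0: "float 1.2"
  # f1: "float 2.0"
  f2 = prims.fmod(f0, f1)  # f2: "float 1.2"
  return f2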
jjsjann123 commented 3 months ago

Seems I've found another issue, this time with basic indexing.

import torch
import thunder

def foo(a, size0, size1):
    return a[:, :, size0:size1]

size0 = 8
size1 = 16
a = torch.randn(2, 4, 30, device="cuda")

a_ref = a.detach()
outs_ref = foo(a_ref, size0, size1)

jfoo = thunder.jit(foo, cache="symbolic values")
outs = jfoo(a, size0, size1)

Even though the trace uses size0 and size1, they are not passed as inputs from the prologue trace to the computation trace.

def computation(a):
  # a: "cuda:0 f32[2, 4, 30]"

  # /volume/thunder_jit_reshape.py:52:          return a[:, :, size0:size1]
  t18 = ltorch.getitem(a, (slice(None, None, None), slice(None, None, None), slice([IntegerProxy name=i0, value=8, static=CONSTRAINT.CONSTRAINABLE], [IntegerProxy name=i1, value=16, static=CONSTRAINT.CONSTRAINABLE], None)))  # t18: "cuda:0 f32[2, 4, 8]"
    # b2 = prims.lt(size0, 0)  # b2: "bool False"
    # b3 = prims.ge(size0, 0)  # b3: "bool True"
    # b4 = prims.lt(size1, 0)  # b4: "bool False"
    # b5 = prims.ge(size1, 0)  # b5: "bool True"
    # b6 = prims.gt(size0, size1)  # b6: "bool False"
    # b7 = prims.ge(size0, 30)  # b7: "bool False"
    # b8 = prims.ge(size1, 30)  # b8: "bool False"
    # b9 = prims.eq(size0, 0)  # b9: "bool False"
    # t18 = prims.slice_prim(a, [0, 0, size0], [2, 4, size1], [1, 1, 1])  # t18: "cuda:0 f32[2, 4, 8]"
  return t18
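Judging from the commented-out prims above, the computation trace presumably ought to take the size proxies from the prologue as arguments, roughly along these lines (illustrative sketch, not actual Thunder output):

def computation(a, i0, i1):
  # a: "cuda:0 f32[2, 4, 30]"
  # i0, i1: IntegerProxys for size0 and size1, produced by the prologue
  t18 = prims.slice_prim(a, [0, 0, i0], [2, 4, i1], [1, 1, 1])  # t18: "cuda:0 f32[2, 4, 8]"
  return t18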