Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0

Proxy renaming in general jit sometimes is skipped #946

Open IvanYashchuk opened 1 month ago

IvanYashchuk commented 1 month ago

🐛 Bug

Proxy renaming in the initial trace sometimes doesn't work. Let's check how the initial trace looks for the following example (taken from test_core.py::test_cse):

import thunder
import torch

from thunder import clang

def func(x, y, device):
    a = x * y
    b = y / x
    c = x * y
    d = y / x
    z = a * b
    w = c * d
    m = w * 1
    a = clang.uniform(w.shape, device=device, dtype=thunder.float16)
    return z, w, m, a

x = torch.randn(3, 4, device='cuda:0')
y = torch.randn(3, 4, device='cuda:0')
trace = thunder.trace()(func, x, y, 'cuda:0')
print(trace)
func = trace.python_callable()
jfunc = thunder.jit(func, executors=["torch"])
out = jfunc(x, y, device='cuda:0')
print(thunder.last_traces(jfunc)[0])

The initial trace in thunder.jit is

def computation(x, y):
  # x: "cuda:0 f32[3, 4]"
  # y: "cuda:0 f32[3, 4]"

  # thunder.func_39:15:       t0 = ltorch.mul(x, y)  # t0: "cuda:0 f32[3, 4]"
  t0 = ltorch.mul(x, y)  # t0: "cuda:0 f32[3, 4]"
    # t0 = prims.mul(x, y)  # t0: "cuda:0 f32[3, 4]"

  # thunder.func_39:16:       t1 = ltorch.true_divide(y, x)  # t1: "cuda:0 f32[3, 4]"
  t1 = ltorch.true_divide(y, x)  # t1: "cuda:0 f32[3, 4]"
    # t1 = prims.div(y, x)  # t1: "cuda:0 f32[3, 4]"

  # thunder.func_39:17:       t2 = ltorch.mul(x, y)  # t2: "cuda:0 f32[3, 4]"
  t2 = ltorch.mul(x, y)  # t2: "cuda:0 f32[3, 4]"
    # t2 = prims.mul(x, y)  # t2: "cuda:0 f32[3, 4]"

  # thunder.func_39:18:       t3 = ltorch.true_divide(y, x)  # t3: "cuda:0 f32[3, 4]"
  t3 = ltorch.true_divide(y, x)  # t3: "cuda:0 f32[3, 4]"
    # t3 = prims.div(y, x)  # t3: "cuda:0 f32[3, 4]"

  # thunder.func_39:19:       t4 = ltorch.mul(t0, t1)  # t4: "cuda:0 f32[3, 4]"
  t4 = ltorch.mul(t0, t1)  # t4: "cuda:0 f32[3, 4]"
    # t4 = prims.mul(t0, t1)  # t4: "cuda:0 f32[3, 4]"

  # thunder.func_39:20:       t5 = ltorch.mul(t2, t3)  # t5: "cuda:0 f32[3, 4]"
  t5 = ltorch.mul(t2, t3)  # t5: "cuda:0 f32[3, 4]"
    # t5 = prims.mul(t2, t3)  # t5: "cuda:0 f32[3, 4]"

  # thunder.func_39:21:       t6 = ltorch.mul(t5, 1)  # t6: "cuda:0 f32[3, 4]"
  t6 = ltorch.mul(t5, 1)  # t6: "cuda:0 f32[3, 4]"
    # _ = prims.convert_element_type(1, float)
    # t6 = prims.mul(t5, 1.0)  # t6: "cuda:0 f32[3, 4]"

  # thunder.func_39:22:       t7 = prims.uniform((3, 4), 0.0, 1.0, device=devices.Device("cuda:0"), dtype=dtypes.float16)  # t7: "cuda:0 f16[3, 4]"
  t14 = prims.uniform((3, 4), 0.0, 1.0, device=devices.Device("cuda:0"), dtype=dtypes.float16)  # t14: "cuda:0 f16[3, 4]"

  # /home/iyashchuk/dev/pytorch/main/torch/autograd/grad_mode.py:186:           torch._C._set_grad_enabled(mode)
  return (t4, t5, t6, t14)

Why is t14 not renamed to t7 while all the other variables are renamed? The renaming happens at https://github.com/Lightning-AI/lightning-thunder/blob/9f6e5b14e7a0fc6c96cca254540666d899df60b2/thunder/core/jit_ext.py#L1822-L1823
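
As a quick sanity check, here is a hypothetical snippet (my own sketch, reusing jfunc from the script above, not part of the original report) that lists which proxy names are actually bound in the printed trace; t7 never shows up as a binding while t14 does:

import re

# List the names actually bound in the last trace, skipping the comment lines.
trace_str = str(thunder.last_traces(jfunc)[0])
bound_names = [line.split("=")[0].strip()
               for line in trace_str.splitlines()
               if re.match(r"\s*t\d+ =", line)]
print(bound_names)  # expected to show t0..t6 and t14, with no t7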

t-vi commented 1 month ago

The goal of the renaming has been to use more of the user's variable names, not to get consecutive generic names. Is that what you want? (Also, with a bit of luck, #954 will change the trace above.)

IvanYashchuk commented 1 month ago

I understand the goal, but in the example above the renaming doesn't happen when it should. Here the trace object returned by thunder.trace() plays the role of a user script, so its variable names should be preserved.

t-vi commented 1 month ago

I think what often happens is that it tries to rename something to an internal name ("tos" etc.) and then finds that the name is already taken for whatever reason. With #954 (which is not the end of the story), you'll get t7, as in the trace below (see also the sketch after it).

@no_autocast
def computation(x, y):
  # x: "cpu f32[3, 4]"
  # y: "cpu f32[3, 4]"

  # thunder.func_0:15:    t0 = ltorch.mul(x, y)  # t0: "cpu f32[3, 4]"
  t0 = ltorch.mul(x, y)  # t0: "cpu f32[3, 4]"
    # t0 = prims.mul(x, y)  # t0: "cpu f32[3, 4]"

  # thunder.func_0:16:    t1 = ltorch.true_divide(y, x)  # t1: "cpu f32[3, 4]"
  t1 = ltorch.true_divide(y, x)  # t1: "cpu f32[3, 4]"
    # t1 = prims.div(y, x)  # t1: "cpu f32[3, 4]"

  # thunder.func_0:17:    t2 = ltorch.mul(x, y)  # t2: "cpu f32[3, 4]"
  t2 = ltorch.mul(x, y)  # t2: "cpu f32[3, 4]"
    # t2 = prims.mul(x, y)  # t2: "cpu f32[3, 4]"

  # thunder.func_0:18:    t3 = ltorch.true_divide(y, x)  # t3: "cpu f32[3, 4]"
  t3 = ltorch.true_divide(y, x)  # t3: "cpu f32[3, 4]"
    # t3 = prims.div(y, x)  # t3: "cpu f32[3, 4]"

  # thunder.func_0:19:    t4 = ltorch.mul(t0, t1)  # t4: "cpu f32[3, 4]"
  t4 = ltorch.mul(t0, t1)  # t4: "cpu f32[3, 4]"
    # t4 = prims.mul(t0, t1)  # t4: "cpu f32[3, 4]"

  # thunder.func_0:20:    t5 = ltorch.mul(t2, t3)  # t5: "cpu f32[3, 4]"
  t5 = ltorch.mul(t2, t3)  # t5: "cpu f32[3, 4]"
    # t5 = prims.mul(t2, t3)  # t5: "cpu f32[3, 4]"

  # thunder.func_0:21:    t6 = ltorch.mul(t5, 1)  # t6: "cpu f32[3, 4]"
  t6 = ltorch.mul(t5, 1)  # t6: "cpu f32[3, 4]"
    # _ = prims.convert_element_type(1, float)
    # t6 = prims.mul(t5, 1.0)  # t6: "cpu f32[3, 4]"

  # thunder.func_0:22:    t7 = prims.uniform((3, 4), 0.0, 1.0, device=devices.Device("cpu"), dtype=dtypes.float16)  # t7: "cpu f16[3, 4]"
  t7 = prims.uniform((3, 4), 0.0, 1.0, device=devices.Device("cpu"), dtype=dtypes.float16)  # t7: "cpu f16[3, 4]"

  # /usr/local/lib/python3.12/dist-packages/torch/autograd/grad_mode.py:186:            torch._C._set_grad_enabled(mode)
  return (t4, t5, t6, t7)
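
To make the skip-on-collision behavior described above concrete, here is a minimal standalone sketch (hypothetical helper names, not the actual jit_ext.py logic): a desired rename is dropped when the target name is already in use, which is exactly the pattern that would leave t14 in place when t7 is taken.

def rename_proxies(current_to_desired, used_names):
    # Rename each proxy to its desired name unless that name is already taken;
    # on a collision the proxy simply keeps its current (generic) name.
    new_names = {}
    for current, desired in current_to_desired.items():
        if desired in used_names:
            new_names[current] = current  # skip the rename
        else:
            new_names[current] = desired
            used_names.add(desired)
    return new_names

print(rename_proxies({"t14": "t7"}, used_names={"t7"}))  # {'t14': 't14'} -> rename skipped
print(rename_proxies({"t14": "t7"}, used_names=set()))   # {'t14': 't7'}  -> rename applied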