Lightning-AI / lightning-thunder


Remove all occurrences of thunder.compile and TestExecutor.make_callable_legacy #198

Closed: IvanYashchuk closed this issue 1 week ago

IvanYashchuk commented 5 months ago

🐛 Bug

All thunder.compile usage in the repo should be replaced with the equivalent thunder.jit calls.

Two places in the test framework use thunder.compile today; the tests should be rewritten to use thunder.jit instead: https://github.com/Lightning-AI/lightning-thunder/blob/6c64fb93b04672b731180afc4b63d5df55dae92f/thunder/tests/framework.py#L134-L137 https://github.com/Lightning-AI/lightning-thunder/blob/6c64fb93b04672b731180afc4b63d5df55dae92f/thunder/tests/framework.py#L152-L157
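
For reference, here is a minimal sketch of the intended migration; the toy function and tensors are made up for illustration, and the only thing taken from this issue is that thunder.jit replaces thunder.compile as the entry point:

```python
import torch
import thunder

# Hypothetical toy function; any traceable PyTorch function would do.
def fn(x, y):
    return torch.nn.functional.relu(x @ y)

x = torch.randn(2, 2)
y = torch.randn(2, 2)

# Legacy entry point slated for removal:
# cfn = thunder.compile(fn)

# Replacement:
cfn = thunder.jit(fn)
out = cfn(x, y)
```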

t-vi commented 5 months ago

Thank you! So one prominent pattern is make_callable_legacy(vjp(fn)). I could be wrong, but I think this will likely need work on signature-changing transforms / updating of prologues to make something like vjp(jit(fn)) possible (which might be easier than jit(vjp(fn))).

IvanYashchuk commented 5 months ago

I think jit(vjp(fn), disable_torch_autograd=True) should work. The JIT sees only vjp(fn) and not fn itself, so no change to the prologue is needed.
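
A minimal sketch of that composition, under a couple of assumptions: vjp is importable from thunder.core.transforms, and (as in the check_vjp patch below) the transformed function takes a tuple of primals and a tuple of cotangents and returns the primal outputs plus the pullback results. Note that kiya00 shows below that this pattern does not always work:

```python
import torch
import thunder
from thunder.core.transforms import vjp  # assumed import path

# Hypothetical toy function standing in for fn.
def fn(x):
    return x.sin()

x = torch.randn(2, 2, dtype=torch.float64)
v = torch.randn(2, 2, dtype=torch.float64)  # cotangent with the shape of fn(x)

# The JIT interprets vjp(fn) rather than fn itself, so the prologue is built
# for the (primals, cotangents) calling convention.
jitted = thunder.jit(vjp(fn), disable_torch_autograd=True)
outs, pullback = jitted((x,), (v,))
```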

kiya00 commented 4 months ago

jit(vjp(fn), disable_torch_autograd=True) doesn't always work. For example, with the patch below, running pytest test_grad.py -vs -k test_vjp_correctness_abs_nvfuser_cuda_float64 hits a NotImplementedError: we get an OPAQUE whose first input is not PseudoInst.CONSTANT, and the error is thrown at https://github.com/Lightning-AI/lightning-thunder/blob/e0ab64867a5be914d0548c195a3f850a76c8c397/thunder/core/jit_ext.py#L1279

The error message, including the provenance record of the `PseudoInst.OPAQUE` that throws:

```
NotImplementedError: Exception occured unpacking object from ProvenanceRecord(
E i1 = INPUT_FN()
E i2 = LOAD_ATTR(i1, '__globals__')
E i3 = BINARY_SUBSCR(i2, 'flatten_func')
E i4 = LOAD_ATTR(i3, '__globals__')
E i5 = BINARY_SUBSCR(i4, 'tree_flatten')
E i6 = LOAD_ATTR(i5, 'func')
E i7 = LOAD_ATTR(i6, '__globals__')
E i8 = BINARY_SUBSCR(i7, '_C')
E i9 = LOAD_ATTR(i8, 'flatten')
E i10 = LOAD_ATTR(i5, 'keywords')
E i11 = BINARY_SUBSCR(i10, 'none_is_leaf')
E i12 = INPUT_ARGS()
E i13 = BINARY_SUBSCR(i12, 0)
E i14 = BUILD_TUPLE(CONSTANT({}), i13)
E i15 = BUILD_TUPLE('', i11, None, i14)
E i16 = OPAQUE(i9, i15, CONSTANT({}))
E i17 = BINARY_SUBSCR(i16, 0)
E i18 = BINARY_SUBSCR(i17, 0)
E )
```

the patch:

diff --git a/thunder/tests/test_grad.py b/thunder/tests/test_grad.py
index 7ccfa5e..1740a30 100644
--- a/thunder/tests/test_grad.py
+++ b/thunder/tests/test_grad.py
@@ -301,12 +301,12 @@ def check_vjp(f, *primals, executor="torch", atol=1e-5, rtol=1.3e-6):
     make = partial(make_tensor_like, low=0, high=1)

     u = tree_map(make, primals)
-    outs_p, J_u = numerical_jvp(executor.make_callable_legacy(f, disable_torch_autograd_support=True))(primals, u)
+    outs_p, J_u = numerical_jvp(executor.make_callable(f, disable_torch_autograd=True))(primals, u)

     multiple_results = isinstance(outs_p, Sequence)

     v = tree_map(make, outs_p)
-    _, J_star_v = executor.make_callable_legacy(vjp(f), disable_torch_autograd_support=True)(primals, v)
+    _, J_star_v = executor.make_callable(vjp(f), disable_torch_autograd=True)(primals, v)

     if not multiple_results:
         v = (v,)

I'm not quite sure how to fix this, @t-vi @IvanYashchuk, and there are quite a few cases in test_grad.py that I guess fail for the same reason.

kiya00 commented 4 months ago

There are some cases I don't know how to fix; I'm going to put the specific case to reproduce into a separate issue and get some help from the JIT experts.

t-vi commented 4 months ago

We do have a functionality gap for signature-changing transforms, as they require the prologue to be adapted.

mruberry commented 4 months ago

Triage review: to be addressed in the design review for distributed transforms.

IvanYashchuk commented 4 months ago

I don't understand how this is related to distributed transforms, signature-changing transforms, or the prologue. The main problem is that thunder.jit cannot deal with the pattern used to implement autocast/vmap/etc.: building a side trace and reinterpreting that trace.
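
For context, here is a schematic of the pattern being described, distilled from the autocast wrapper shown further below; the construct_trace / eval_trace names are taken from that diff and the import path is an assumption, so treat this as an illustration rather than the exact implementation:

```python
from functools import wraps

# Assumed import location; the autocast diff below lives in thunder/core/transforms.py.
from thunder.core.transforms import construct_trace, eval_trace

def trace_reinterpreting_transform(func, symbol_mapper):
    # The transform does not rewrite `func` directly. Calling the wrapper
    # (1) records a side trace of the original function and then
    # (2) re-evaluates that trace symbol by symbol through `symbol_mapper`.
    # thunder.jit only ever sees the wrapper, never `func` itself.
    @wraps(func)
    def wrapper(*args, **kwargs):
        trace = construct_trace()(func, *args, **kwargs)
        return eval_trace(trace, *args, **kwargs, symbol_mapper=symbol_mapper)
    return wrapper
```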

IvanYashchuk commented 4 months ago

Let's first try moving the autocast tests from thunder.compile to thunder.jit:

diff --git a/thunder/tests/test_autocast.py b/thunder/tests/test_autocast.py
index 644179a9..979ec018 100644
--- a/thunder/tests/test_autocast.py
+++ b/thunder/tests/test_autocast.py
@@ -53,7 +53,7 @@ def test_thunder_autocast_transform(executor, device, dtype):
     ):
         autocast_torch_dtype = ltorch.to_torch_dtype(autocast_dtype)
         x, y, z = (torch.randn((2, 2), device=device, dtype=torch_dtype) for _ in range(3))
-        compiled = executor.make_callable_legacy(autocast(func, dtype=autocast_dtype))
+        compiled = executor.make_callable(autocast(func, dtype=autocast_dtype), disable_torch_autograd=True)
         out = compiled(x, y, z)

         devicetype = torch.device(device).type

This leads to a weird error:

RuntimeError: Expected a.dtype=float32 and b.dtype=bfloat16 to be the same

raised from thunder/core/prims.py:3512 in matmul_meta. What's the trace for which it occurs?

@torch.no_grad()
@no_autocast
def computation(a, b, t_2):
  # a: "cuda:0 f32[2, 2]"
  # b: "cuda:0 f32[2, 2]"
  # t_2: "cuda:0 f32[2, 2]"
  tree = prims.convert_element_type(a, dtypes.bfloat16)  # tree: "cuda:0 f32[2, 2]"
  t2 = prims.convert_element_type(b, dtypes.bfloat16)  # t2: "cuda:0 bf16[2, 2]"
  result = prims.matmul(tree, t2)  # result: "cuda:0 bf16[2, 2]"
  t5 = ltorch.add(result, t_2, alpha=None)  # t5: "cuda:0 f32[2, 2]"
    # t4 = prims.convert_element_type(result, dtypes.float32)  # t4: "cuda:0 f32[2, 2]"
    # t5 = prims.add(t4, t_2)  # t5: "cuda:0 f32[2, 2]"
  return t5

The line `tree = prims.convert_element_type(a, dtypes.bfloat16)  # tree: "cuda:0 f32[2, 2]"` is very suspicious. Why is the output named "tree"? Why is the output type fp32 even though the operation casts to bf16? There are no problems in the autocast transform itself, which can be verified with:

diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py
index bc70ae3f..791f157f 100644
--- a/thunder/core/transforms.py
+++ b/thunder/core/transforms.py
@@ -4026,6 +4026,13 @@ def autocast(func: Callable, dtype: dtypes.dtype):
     @wraps(func)
     def wrapper(*args, **kwargs):
         trace = construct_trace()(func, *args, **kwargs)
+
+        def fn(*args, **kwargs):
+            return eval_trace(trace, *args, **kwargs, symbol_mapper=partial(autocast_symbol_mapper, dtype=dtype))
+
+        test_trace = construct_trace()(fn, *args, **kwargs)
+        print(test_trace)
+
         return eval_trace(trace, *args, **kwargs, symbol_mapper=partial(autocast_symbol_mapper, dtype=dtype))

would show us the trace:

@torch.no_grad()
@no_autocast
def fn(*args):
  # args: "Collection"
  t0, t1, t2, = args
  t3 = prims.convert_element_type(t0, dtypes.bfloat16)  # t3: "cuda:0 bf16[2, 2]"
  t4 = prims.convert_element_type(t1, dtypes.bfloat16)  # t4: "cuda:0 bf16[2, 2]"
  t5 = prims.matmul(t3, t4)  # t5: "cuda:0 bf16[2, 2]"
  t7 = ltorch.add(t5, t2, alpha=None)  # t7: "cuda:0 f32[2, 2]"
    # t6 = prims.convert_element_type(t5, dtypes.float32)  # t6: "cuda:0 f32[2, 2]"
    # t7 = prims.add(t6, t2)  # t7: "cuda:0 f32[2, 2]"
  return t7

All is good with this trace. Let's jump into the JIT interpreter code now. Inspecting the state right before and after this line, https://github.com/Lightning-AI/lightning-thunder/blob/67be468326b49b0f5c0a11df90ff21544d042ac8/thunder/core/jit_ext.py#L1579, reveals the problem. Here's the trace right before L1579:

@torch.no_grad()
@no_autocast
def computation(t_0, t_1, t_2):
  # t_0: "cuda:0 f32[2, 2]"
  # t_1: "cuda:0 f32[2, 2]"
  # t_2: "cuda:0 f32[2, 2]"
  t1 = prims.convert_element_type(t_0, dtypes.bfloat16)  # t1: "cuda:0 bf16[2, 2]"
  t2 = prims.convert_element_type(t_1, dtypes.bfloat16)  # t2: "cuda:0 bf16[2, 2]"
  t3 = prims.matmul(t1, t2)  # t3: "cuda:0 bf16[2, 2]"
  t5 = ltorch.add(t3, t_2, alpha=None)  # t5: "cuda:0 f32[2, 2]"
    # t4 = prims.convert_element_type(t3, dtypes.float32)  # t4: "cuda:0 f32[2, 2]"
    # t5 = prims.add(t4, t_2)  # t5: "cuda:0 f32[2, 2]"
  return t5

and here's one right after that line:

@torch.no_grad()
@no_autocast
def computation(a, b, t_2):
  # a: "cuda:0 f32[2, 2]"
  # b: "cuda:0 f32[2, 2]"
  # t_2: "cuda:0 f32[2, 2]"
  tree = prims.convert_element_type(a, dtypes.bfloat16)  # tree: "cuda:0 f32[2, 2]"
  t2 = prims.convert_element_type(b, dtypes.bfloat16)  # t2: "cuda:0 bf16[2, 2]"
  result = prims.matmul(tree, t2)  # result: "cuda:0 bf16[2, 2]"
  t5 = ltorch.add(result, t_2, alpha=None)  # t5: "cuda:0 f32[2, 2]"
    # t4 = prims.convert_element_type(result, dtypes.float32)  # t4: "cuda:0 f32[2, 2]"
    # t5 = prims.add(t4, t_2)  # t5: "cuda:0 f32[2, 2]"
  return t5

@t-vi, what's happening at that line? Just commenting this line out makes the test pass.

kiya00 commented 4 months ago

There's a similar issue here: https://github.com/Lightning-AI/lightning-thunder/issues/283 cc: @IvanYashchuk @t-vi

IvanYashchuk commented 4 months ago

I missed that the problem described in the post above was already reported in https://github.com/Lightning-AI/lightning-thunder/issues/283.

IvanYashchuk commented 1 month ago

Here's a diff that should remove thunder.compile usage in one test:

diff --git a/thunder/tests/test_nvfuser_remat.py b/thunder/tests/test_nvfuser_remat.py
index e017ec20..19eb1847 100644
--- a/thunder/tests/test_nvfuser_remat.py
+++ b/thunder/tests/test_nvfuser_remat.py
@@ -62,7 +62,9 @@ def disable_rematerialization_in_nvfuser_fusion(func):
 def test_find_producer_symbols(executor, device, _):
     # We will try to find a subgraph for rematerializing __c and __d
     t0 = make_tensor(2, 2, dtype=torch.float32, device=device)
-    compiled_func = thunder.compile(func, disable_preprocessing=True)
+    initial_trace = thunder.trace()(func, t0)
+    compiled_func = thunder.jit(initial_trace.python_callable())
+    # compiled_func = thunder.compile(func, disable_preprocessing=True)
     _ = compiled_func(t0)
     traces = thunder.last_traces(compiled_func)
     trace = traces[-1]

but it's currently blocked by https://github.com/Lightning-AI/lightning-thunder/issues/831.
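
For reference, the same replacement pattern on a self-contained toy function; the function below is hypothetical (the test's own func is not reproduced here) and is written against thunder.torch, since tracing skips Python preprocessing. As noted above, this route is blocked by #831, so it may not run yet:

```python
import torch
import thunder
import thunder.torch as ltorch

# Hypothetical stand-in for the test's `func`.
def func(t0):
    t1 = ltorch.exp(t0)
    return ltorch.cos(t1)

t0 = torch.randn(2, 2, dtype=torch.float32)

# Trace first (taking the place of disable_preprocessing=True),
# then jit the traced callable and inspect the resulting traces.
initial_trace = thunder.trace()(func, t0)
compiled_func = thunder.jit(initial_trace.python_callable())
_ = compiled_func(t0)
traces = thunder.last_traces(compiled_func)
```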

t-vi commented 1 week ago

Closed by #1114, built on #837.