Thank you! So one prominent pattern is `make_callable_legacy(vjp(fn))`. I could be wrong, but I think this will likely need work on signature-changing transforms / updating of prologues to make something like `vjp(jit(fn))` possible (which might be easier than `jit(vjp(fn))`).

I think `jit(vjp(fn), disable_torch_autograd=True)` should work. JIT sees only `vjp(fn)` and not `fn` itself, so there's no change to the prologue happening.
`jit(vjp(fn), disable_torch_autograd=True)` doesn't always work. For example, with the patch below, running `pytest test_grad.py -vs -k test_vjp_correctness_abs_nvfuser_cuda_float64` hits a `NotImplementedError`: we get an `OPAQUE` whose first input is not `PseudoInst.CONSTANT`, which throws at https://github.com/Lightning-AI/lightning-thunder/blob/e0ab64867a5be914d0548c195a3f850a76c8c397/thunder/core/jit_ext.py#L1279
the patch:

```diff
diff --git a/thunder/tests/test_grad.py b/thunder/tests/test_grad.py
index 7ccfa5e..1740a30 100644
--- a/thunder/tests/test_grad.py
+++ b/thunder/tests/test_grad.py
@@ -301,12 +301,12 @@ def check_vjp(f, *primals, executor="torch", atol=1e-5, rtol=1.3e-6):
     make = partial(make_tensor_like, low=0, high=1)
     u = tree_map(make, primals)
-    outs_p, J_u = numerical_jvp(executor.make_callable_legacy(f, disable_torch_autograd_support=True))(primals, u)
+    outs_p, J_u = numerical_jvp(executor.make_callable(f, disable_torch_autograd=True))(primals, u)
     multiple_results = isinstance(outs_p, Sequence)
     v = tree_map(make, outs_p)
-    _, J_star_v = executor.make_callable_legacy(vjp(f), disable_torch_autograd_support=True)(primals, v)
+    _, J_star_v = executor.make_callable(vjp(f), disable_torch_autograd=True)(primals, v)
     if not multiple_results:
         v = (v,)
```
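For reference, a standalone repro extracted from the failing test id (a sketch: the device, dtype, and `abs` follow `test_vjp_correctness_abs_nvfuser_cuda_float64`, and the error is reported to surface under the nvFuser executor):

```python
import torch
import thunder
from thunder.core.transforms import vjp

def f(x):
    return x.abs()

x = torch.randn(2, 2, device="cuda", dtype=torch.float64)
v = torch.randn(2, 2, device="cuda", dtype=torch.float64)

jitted = thunder.jit(vjp(f), disable_torch_autograd=True)
# Reported to raise NotImplementedError from jit_ext.py: an OPAQUE
# whose first input is not PseudoInst.CONSTANT.
outs, grads = jitted((x,), (v,))
```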
I'm not quite sure how to fix this, @t-vi @IvanYashchuk, and there are quite a few cases in `test_grad.py` that I suspect fail for the same reason.

There are some cases I don't know how to fix; I'm going to put a specific reproducer in a separate issue and get some help from the JIT experts.
We do have a functionality gap for signature-changing transforms, as they need the prologue to be adapted.
triage review — to be addressed in design review for distributed transforms
I don't understand how this is related to distributed transforms, signature-changing transforms, or prologues. The main problem is that `thunder.jit` cannot deal with the pattern by which autocast/vmap/etc. are implemented: building a side trace and then reinterpreting that trace.
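Schematically, the pattern looks like this (simplified from `thunder.core.transforms.autocast`, which is shown in full further down; the import path for the helpers is an assumption, they live alongside `autocast` in `thunder/core/transforms.py`):

```python
from functools import partial, wraps

# assumed import path; these helpers are defined in thunder/core/transforms.py
from thunder.core.transforms import construct_trace, eval_trace, autocast_symbol_mapper

def autocast(func, dtype):
    @wraps(func)
    def wrapper(*args, **kwargs):
        # 1. Build a side trace by tracing the original function...
        trace = construct_trace()(func, *args, **kwargs)
        # 2. ...then reinterpret that trace, remapping every symbol through
        #    the autocast rules. thunder.jit only ever interprets `wrapper`,
        #    i.e. this eval_trace call; it never sees `func` itself.
        return eval_trace(
            trace, *args, **kwargs,
            symbol_mapper=partial(autocast_symbol_mapper, dtype=dtype),
        )
    return wrapper
```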
Let's try moving the first autocast test from `thunder.compile` to `thunder.jit`:
```diff
diff --git a/thunder/tests/test_autocast.py b/thunder/tests/test_autocast.py
index 644179a9..979ec018 100644
--- a/thunder/tests/test_autocast.py
+++ b/thunder/tests/test_autocast.py
@@ -53,7 +53,7 @@ def test_thunder_autocast_transform(executor, device, dtype):
     ):
         autocast_torch_dtype = ltorch.to_torch_dtype(autocast_dtype)
         x, y, z = (torch.randn((2, 2), device=device, dtype=torch_dtype) for _ in range(3))
-        compiled = executor.make_callable_legacy(autocast(func, dtype=autocast_dtype))
+        compiled = executor.make_callable(autocast(func, dtype=autocast_dtype), disable_torch_autograd=True)
         out = compiled(x, y, z)
         devicetype = torch.device(device).type
```
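In isolation the patched call amounts to the following (a sketch: `func` is reconstructed as matmul-plus-add from the traces below, and the device/dtypes follow one parametrization of the test):

```python
import torch
import thunder
from thunder.core import dtypes
from thunder.core.transforms import autocast

def func(x, y, z):
    return x @ y + z

x, y, z = (torch.randn((2, 2), device="cuda", dtype=torch.float32) for _ in range(3))
compiled = thunder.jit(autocast(func, dtype=dtypes.bfloat16), disable_torch_autograd=True)
out = compiled(x, y, z)  # fails as described next
```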
This leads to a weird error:

```
RuntimeError: Expected a.dtype=float32 and b.dtype=bfloat16 to be the same
```

raised from `thunder/core/prims.py:3512` in `matmul_meta`. What's the trace for which it occurs?
```python
@torch.no_grad()
@no_autocast
def computation(a, b, t_2):
  # a: "cuda:0 f32[2, 2]"
  # b: "cuda:0 f32[2, 2]"
  # t_2: "cuda:0 f32[2, 2]"
  tree = prims.convert_element_type(a, dtypes.bfloat16)  # tree: "cuda:0 f32[2, 2]"
  t2 = prims.convert_element_type(b, dtypes.bfloat16)  # t2: "cuda:0 bf16[2, 2]"
  result = prims.matmul(tree, t2)  # result: "cuda:0 bf16[2, 2]"
  t5 = ltorch.add(result, t_2, alpha=None)  # t5: "cuda:0 f32[2, 2]"
    # t4 = prims.convert_element_type(result, dtypes.float32)  # t4: "cuda:0 f32[2, 2]"
    # t5 = prims.add(t4, t_2)  # t5: "cuda:0 f32[2, 2]"
  return t5
```
The line `tree = prims.convert_element_type(a, dtypes.bfloat16)  # tree: "cuda:0 f32[2, 2]"` is very suspicious. Why is the output named `tree`? And why is its annotated dtype fp32 even though the operation casts it to bf16?
There are no problems in the autocast transform itself, which can be verified with:
```diff
diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py
index bc70ae3f..791f157f 100644
--- a/thunder/core/transforms.py
+++ b/thunder/core/transforms.py
@@ -4026,6 +4026,13 @@ def autocast(func: Callable, dtype: dtypes.dtype):
     @wraps(func)
     def wrapper(*args, **kwargs):
         trace = construct_trace()(func, *args, **kwargs)
+
+        def fn(*args, **kwargs):
+            return eval_trace(trace, *args, **kwargs, symbol_mapper=partial(autocast_symbol_mapper, dtype=dtype))
+
+        test_trace = construct_trace()(fn, *args, **kwargs)
+        print(test_trace)
+
         return eval_trace(trace, *args, **kwargs, symbol_mapper=partial(autocast_symbol_mapper, dtype=dtype))
```
This would show us the trace:
```python
@torch.no_grad()
@no_autocast
def fn(*args):
  # args: "Collection"
  t0, t1, t2, = args
  t3 = prims.convert_element_type(t0, dtypes.bfloat16)  # t3: "cuda:0 bf16[2, 2]"
  t4 = prims.convert_element_type(t1, dtypes.bfloat16)  # t4: "cuda:0 bf16[2, 2]"
  t5 = prims.matmul(t3, t4)  # t5: "cuda:0 bf16[2, 2]"
  t7 = ltorch.add(t5, t2, alpha=None)  # t7: "cuda:0 f32[2, 2]"
    # t6 = prims.convert_element_type(t5, dtypes.float32)  # t6: "cuda:0 f32[2, 2]"
    # t7 = prims.add(t6, t2)  # t7: "cuda:0 f32[2, 2]"
  return t7
```
All is good with this trace. Let's jump into the JIT interpreter code now. Inspecting the state right before and after this line reveals the problem: https://github.com/Lightning-AI/lightning-thunder/blob/67be468326b49b0f5c0a11df90ff21544d042ac8/thunder/core/jit_ext.py#L1579. Here's the trace right before L1579:
```python
@torch.no_grad()
@no_autocast
def computation(t_0, t_1, t_2):
  # t_0: "cuda:0 f32[2, 2]"
  # t_1: "cuda:0 f32[2, 2]"
  # t_2: "cuda:0 f32[2, 2]"
  t1 = prims.convert_element_type(t_0, dtypes.bfloat16)  # t1: "cuda:0 bf16[2, 2]"
  t2 = prims.convert_element_type(t_1, dtypes.bfloat16)  # t2: "cuda:0 bf16[2, 2]"
  t3 = prims.matmul(t1, t2)  # t3: "cuda:0 bf16[2, 2]"
  t5 = ltorch.add(t3, t_2, alpha=None)  # t5: "cuda:0 f32[2, 2]"
    # t4 = prims.convert_element_type(t3, dtypes.float32)  # t4: "cuda:0 f32[2, 2]"
    # t5 = prims.add(t4, t_2)  # t5: "cuda:0 f32[2, 2]"
  return t5
```
and here's the one right after that line:
```python
@torch.no_grad()
@no_autocast
def computation(a, b, t_2):
  # a: "cuda:0 f32[2, 2]"
  # b: "cuda:0 f32[2, 2]"
  # t_2: "cuda:0 f32[2, 2]"
  tree = prims.convert_element_type(a, dtypes.bfloat16)  # tree: "cuda:0 f32[2, 2]"
  t2 = prims.convert_element_type(b, dtypes.bfloat16)  # t2: "cuda:0 bf16[2, 2]"
  result = prims.matmul(tree, t2)  # result: "cuda:0 bf16[2, 2]"
  t5 = ltorch.add(result, t_2, alpha=None)  # t5: "cuda:0 f32[2, 2]"
    # t4 = prims.convert_element_type(result, dtypes.float32)  # t4: "cuda:0 f32[2, 2]"
    # t5 = prims.add(t4, t_2)  # t5: "cuda:0 f32[2, 2]"
  return t5
```
@t-vi, what's happening at that line? Just commenting this line out makes the test pass.
There's a similar issue here: https://github.com/Lightning-AI/lightning-thunder/issues/283 cc: @IvanYashchuk @t-vi
I missed that the problem described in the post above was already reported in https://github.com/Lightning-AI/lightning-thunder/issues/283
Here's a diff that should remove `thunder.compile` usage in one test:
```diff
diff --git a/thunder/tests/test_nvfuser_remat.py b/thunder/tests/test_nvfuser_remat.py
index e017ec20..19eb1847 100644
--- a/thunder/tests/test_nvfuser_remat.py
+++ b/thunder/tests/test_nvfuser_remat.py
@@ -62,7 +62,9 @@ def disable_rematerialization_in_nvfuser_fusion(func):
 def test_find_producer_symbols(executor, device, _):
     # We will try to find a subgraph for rematerializing __c and __d
     t0 = make_tensor(2, 2, dtype=torch.float32, device=device)
-    compiled_func = thunder.compile(func, disable_preprocessing=True)
+    initial_trace = thunder.trace()(func, t0)
+    compiled_func = thunder.jit(initial_trace.python_callable())
+    # compiled_func = thunder.compile(func, disable_preprocessing=True)
     _ = compiled_func(t0)
     traces = thunder.last_traces(compiled_func)
     trace = traces[-1]
```
but it's currently blocked by https://github.com/Lightning-AI/lightning-thunder/issues/831.
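The replacement pattern used there, shown in isolation (a sketch: `func` is a hypothetical stand-in, and it assumes `thunder.trace()` and the trace's `python_callable()` behave as in the diff; per the note above, this is still blocked by issue #831):

```python
import torch
import thunder

def func(t0):
    # hypothetical stand-in for the function under test
    return t0.exp().cos()

t0 = torch.randn(2, 2)

# Trace once to get a trace object, turn the trace back into a plain Python
# callable, and jit that callable instead of using thunder.compile with
# disable_preprocessing=True.
initial_trace = thunder.trace()(func, t0)
compiled_func = thunder.jit(initial_trace.python_callable())
out = compiled_func(t0)
```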
Closed by #1114, built on #837.
🐛 Bug

All `thunder.compile` usage in the repo should be replaced with equivalent `thunder.jit`. Two places in the test framework use `thunder.compile` today; the tests should be rewritten to use `thunder.jit` instead.
https://github.com/Lightning-AI/lightning-thunder/blob/6c64fb93b04672b731180afc4b63d5df55dae92f/thunder/tests/framework.py#L134-L137
https://github.com/Lightning-AI/lightning-thunder/blob/6c64fb93b04672b731180afc4b63d5df55dae92f/thunder/tests/framework.py#L152-L157
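Schematically, the requested replacement looks like this (a sketch with a hypothetical `fn`; the exact keyword arguments differ per call site, as the diffs above show):

```python
import torch
import thunder

def fn(x):
    return x.sin() + x

x = torch.randn(2, 2)

# before (legacy API being removed):
#   cfn = thunder.compile(fn)
# after:
cfn = thunder.jit(fn)
out = cfn(x)
```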