Closed mpatel31415 closed 2 weeks ago
Applying the patch from https://github.com/Lightning-AI/lightning-thunder/pull/1075 allows this to run
cc: @IvanYashchuk
What are the unsupported operations that are executed with a fallback in the splitter?
These are the reasons for the split -
Honestly, I think the splits are due to a bug in is_node_supported, which allows an otherwise failing graph to work because of the splitting 😅
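For context on what the splitter does: the Dynamo frontend hands Thunder an FX graph, a per-node predicate (is_node_supported) decides which nodes Thunder can execute, and any unsupported node forces a graph split so that piece can run through a fallback. Below is a minimal, hypothetical sketch of that shape; it is not Thunder's actual splitter, and the supported-op sets are made up purely for illustration.

import operator
import torch
import torch.fx as fx

# Hypothetical support predicate; the real is_node_supported in Thunder's
# Dynamo integration consults Thunder's operator coverage instead of these
# hard-coded sets.
SUPPORTED_FUNCTIONS = {torch.topk, operator.getitem}
SUPPORTED_METHODS = {"softmax"}

def is_node_supported(node: fx.Node) -> bool:
    if node.op == "call_function":
        return node.target in SUPPORTED_FUNCTIONS
    if node.op == "call_method":
        return node.target in SUPPORTED_METHODS
    # placeholder, get_attr and output nodes do not force a split here
    return True

def split_reasons(gm: fx.GraphModule) -> list[str]:
    # Every unsupported node is a reason to split the graph and fall back.
    return [f"unsupported node: {node.op} {node.target}"
            for node in gm.graph.nodes if not is_node_supported(node)]

A misclassification in such a predicate changes where the graph gets split, which is what the comment above suspects is happening in is_node_supported.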
On digging a bit more, I found the culprit subgraph and here is the minimal repro.
import torch
import thunder
from torch import device

def forward(L_stack0_: "bf16[32768, 8]", L_x_: "bf16[32768, 4096]"):
    l_stack0_ = L_stack0_
    l_x_ = L_x_
    topk = torch.topk(l_stack0_, 2); l_stack0_ = None
    probs: "bf16[32768, 2]" = topk[0]
    indices: "i64[32768, 2]" = topk[1]; topk = None
    softmax: "f32[32768, 2]" = probs.softmax(dim = 1, dtype = torch.float32); probs = None
    return softmax

jforward = thunder.jit(forward)

L_stack0_ = torch.randn([32768, 8], dtype=torch.bfloat16, device='cuda', requires_grad=True)
L_x_ = torch.randn([32768, 4096], dtype=torch.bfloat16, device='cuda', requires_grad=True)
jforward(L_stack0_, L_x_)
@IvanYashchuk
While investigating this, I think I have stumbled upon a rematerialization issue. Is my understanding correct that for rematerialization, there is an implicit expectation that the forward and backward trace shouldn't have tensor proxy name collision except for the proxies passed as saved_for_backward?
Context - For the script below, the patch from https://github.com/Lightning-AI/lightning-thunder/pull/1067 is applied (without it I see another remat error).
import torch
import thunder
from torch import device

def forward(L_stack0_: "bf16[32768, 8]", L_x_: "bf16[32768, 4096]"):
    l_stack0_ = L_stack0_
    probs = l_stack0_
    softmax: "f32[32768, 8]" = probs.softmax(dim = 1, dtype=torch.float); probs = None
    return softmax

jforward = thunder.jit(forward)

L_stack0_ = torch.randn([32768, 8], dtype=torch.bfloat16, device='cuda', requires_grad=True)
L_x_ = torch.randn([32768, 4096], dtype=torch.bfloat16, device='cuda', requires_grad=True)
jforward(L_stack0_, L_x_)
This is the joint trace generated (before remat):

def joint_fn(args, kwargs, cotangents):
  # L_stack0_: "cuda:0 bf16[32768, 8]"
  [t9] = nvFusion0(L_stack0_)
    # t0 = prims.convert_element_type(L_stack0_, dtypes.float32) # t0: "cuda:0 f32[32768, 8]"
    # t1 = prims.amax(t0, (1,)) # t1: "cuda:0 f32[32768]"
    # t2 = prims.broadcast_in_dim(t1, [32768, 1], [0]) # t2: "cuda:0 f32[32768, 1]"
    # t3 = prims.broadcast_in_dim(t2, (32768, 8), (0, 1)) # t3: "cuda:0 f32[32768, 8]"
    # t4 = prims.sub(t0, t3) # t4: "cuda:0 f32[32768, 8]"
    # t5 = prims.exp(t4) # t5: "cuda:0 f32[32768, 8]"
    # t6 = prims.sum(t5, (1,)) # t6: "cuda:0 f32[32768]"
    # t7 = prims.broadcast_in_dim(t6, [32768, 1], [0]) # t7: "cuda:0 f32[32768, 1]"
    # t8 = prims.broadcast_in_dim(t7, (32768, 8), (0, 1)) # t8: "cuda:0 f32[32768, 8]"
    # t9 = prims.div(t5, t8) # t9: "cuda:0 f32[32768, 8]"
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  C0, _, = saved_for_backward
  t10, = cotangents
  t9, = C0
  [t7] = nvFusion0(t9, t10)
    # t2 = prims.mul(t9, t10) # t2: "cuda:0 f32[32768, 8]"
    # t3 = prims.sum(t2, (1,)) # t3: "cuda:0 f32[32768]"
    # t4 = prims.broadcast_in_dim(t3, [32768, 1], [0]) # t4: "cuda:0 f32[32768, 1]"
    # t5 = prims.broadcast_in_dim(t4, (32768, 8), (0, 1)) # t5: "cuda:0 f32[32768, 8]"
    # t6 = prims.sub(t10, t5) # t6: "cuda:0 f32[32768, 8]"
    # t7 = prims.mul(t9, t6) # t7: "cuda:0 f32[32768, 8]"
  return {'output': t9, 'flat_args': [L_stack0_], 'flat_output': (t9,)}, ((t7,),)
Remat finds a cut at t4. We can see that t4 appears in both the forward and the backward section, but the two are derived from different computations.
It errors with:
File "/home/kkalambarkar/lightning-thunder/thunder/core/rematerialization.py", line 610, in rematerialize_forward_and_backward
joint_extrace = rematerialize(joint_extrace)
File "/home/kkalambarkar/lightning-thunder/thunder/core/rematerialization.py", line 559, in rematerialize
updated_consumer = apply_rematerialization_for_consumer(current_producer, current_consumer, cut)
File "/home/kkalambarkar/lightning-thunder/thunder/core/rematerialization.py", line 193, in apply_rematerialization_for_consumer
new_consumer_args = tuple(sorted(new_consumer_args, key=lambda x: proxy_order[x.name]))
File "/home/kkalambarkar/lightning-thunder/thunder/core/rematerialization.py", line 193, in <lambda>
new_consumer_args = tuple(sorted(new_consumer_args, key=lambda x: proxy_order[x.name]))
KeyError: 't4'
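To make the failure concrete, here is a small, self-contained sketch (simplified, not Thunder's actual rematerialization code) of how a name-keyed ordering map breaks once the cut contains the forward trace's t4 while the backward consumer only knows its own, unrelated proxies:

from dataclasses import dataclass

@dataclass(frozen=True)
class FakeProxy:
    # Stand-in for a TensorProxy; for ordering purposes only the name matters.
    name: str

def sort_new_consumer_args(consumer_known_proxies, new_consumer_args):
    # Mirrors the failing line: proxy_order is keyed by name, and sorting
    # looks every new argument up by name.
    proxy_order = {p.name: i for i, p in enumerate(consumer_known_proxies)}
    return tuple(sorted(new_consumer_args, key=lambda x: proxy_order[x.name]))

# Proxies the backward consumer knows about.
backward_proxies = [FakeProxy("t9"), FakeProxy("t10")]

# The cut found on the joint trace contains the forward trace's t4, which the
# backward consumer has never seen; its own t4 is a different computation.
cut_proxies = [FakeProxy("t9"), FakeProxy("t4")]

sort_new_consumer_args(backward_proxies, cut_proxies)  # raises KeyError: 't4'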
To verify that the name collision is the issue, I added a patch which prevents names from the forward_trace from reappearing in the backward_trace.
diff --git a/thunder/common.py b/thunder/common.py
index 7683f60c..b93987c0 100644
--- a/thunder/common.py
+++ b/thunder/common.py
@@ -1,5 +1,5 @@
 import dis
-from typing import Any
+from typing import Any, Set
 from collections.abc import Callable, Generator, Hashable, Sequence
 from collections import deque, defaultdict
 import time
@@ -511,6 +511,7 @@ def trace(
     include_return_statement: bool = True,
     use_dce: bool = True,
     insert_ddp_syncs: bool = False,
+    used_names: Set[str] | None = None
 ) -> Callable:
     @make_opaque
     def _trace(
@@ -533,6 +534,9 @@ def trace(
             return fn(*args, **kwargs)
         trace = TraceCtx(fn)
+        if used_names is not None:
+            for name in used_names:
+                trace.add_name(name)
         tracectx_tok = set_tracectx(trace)
         proxyargs, proxykwargs = args, kwargs
diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py
index 5b056da8..f9b3679c 100644
--- a/thunder/core/transforms.py
+++ b/thunder/core/transforms.py
@@ -2178,13 +2178,13 @@ def softmax_aug_fwd(a: Proxy, dim: int, dtype: dtypes.dtype | None = None) -> VJ
     from thunder.torch import softmax
     primal = softmax(a, dim, dtype=dtype)
-    residuals = (primal, dim)
+    residuals = (primal, dim, a.dtype)
     return VJPDual(primal, residuals)
 @register_backward("torch.softmax")
-def softmax_backward(primal, dim, g):
-    return primal * (g - (primal * g).sum(dim, keepdim=True))
+def softmax_backward(primal, dim, original_dtype, g):
+    return (primal * (g - (primal * g).sum(dim, keepdim=True))).to(original_dtype)
 def iter_bound_symbols(bound_symbols):
@@ -2998,7 +2998,7 @@ def forward_and_backward_from_trace(trace: Trace, torch_autograd=False) -> Forwa
         out = tree_flatten(out)[0]
         return out
-    backward_trace = construct_trace(rename_proxies=False)(backward_fn, saved_for_backward, cotangents)
+    backward_trace = construct_trace(rename_proxies=False, used_names=forward_trace.names)(backward_fn, saved_for_backward, cotangents)
     # We are done with constructing the forward and backward passes at this
     # stage. The following is not strictly necessary, but it's good to filter
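The idea behind the used_names plumbing can be shown with a toy name generator (a simplified sketch, not Thunder's actual TraceCtx naming logic): if the backward trace is seeded with every name the forward trace already produced, it can never hand out a colliding name such as t4.

class ToyNameRegistry:
    # Simplified stand-in for per-trace name generation.
    def __init__(self, reserved=()):
        self._used = set(reserved)   # names this trace must never emit
        self._counter = 0

    def new_name(self) -> str:
        # Skip any candidate that is already reserved (e.g. a forward-trace name).
        while (name := f"t{self._counter}") in self._used:
            self._counter += 1
        self._used.add(name)
        self._counter += 1
        return name

forward_names = {"t0", "t1", "t2", "t3", "t4"}        # produced by the forward trace
backward_names = ToyNameRegistry(reserved=forward_names)
print(backward_names.new_name())                      # 't5', so no clash with the forward 't4'

Seeding the reserved set is what the trace.add_name(name) loop in the patch above does for the real TraceCtx.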
With this everything works; even the patch from https://github.com/Lightning-AI/lightning-thunder/pull/1067 is not required (and it also fixes the issue that #1067 tries to solve).
EDIT - I tried running test_networks.py, test_nvfuser_remat.py and some tests from test_grad.py, and they all worked.
@riccardofelluga do you think https://github.com/Lightning-AI/lightning-thunder/pull/1067 could be related to the above comment?
It might be related. However, even if this is the real issue, it still does not make #1067 obsolete, because if that check is valid we can return early and save some cycles. I find it interesting that you discovered it's about naming; I couldn't have guessed. My question now would be: what caused the duplicate name in the first place?
Could you also test the following with your patch and without #1067?
pytest thunder/benchmarks/targets.py -k test_nanogpt_cross_entropy[forward-thunder]
I am not very confident of my understanding of rematerialization, so it might be that my understanding is incorrect.
what caused the duplicate name in the first place?
I think we don't constrain the names during the backward trace creation (assuming remat requires unique names between traces). Currently, I don't think there is any infra to specify the names that a trace shouldn't take.
Could you also test the following with your patch and without https://github.com/Lightning-AI/lightning-thunder/pull/1067?
Yes, it works. I can run pytest thunder/benchmarks/targets.py -k test_nanogpt_cross_entropy[forward-thunder] successfully with the patch, while it fails on main with KeyError.
Is my understanding correct that for rematerialization, there is an implicit expectation that the forward and backward trace shouldn't have tensor proxy name collision except for the proxies passed as saved_for_backward
When the pass was developed we probably had a global name counter. I think your patch to generate the backward trace without a name clash with the forward trace is a good one! Please submit a pull request.
🐛 Bug
When running the benchmark we get:
To Reproduce
Please use: 8 nodes, each with 8 GPUs. Image "INTERNAL_IMAGE:pjnl-20240830"
Training script:
python /opt/pytorch/lightning-thunder/thunder/benchmarks/benchmark_litgpt.py \
  --model_name Mixtral-8x7B-v0.1 \
  --distributed_mode fsdp \
  --shard_mode zero3 \
  --compile dynamo_thunder \
  --checkpoint_activations True \
  --low_precision_mode fp8-delayed-te \
  --micro_batch_size 1
Expected behavior
We should not get this error :)
Environment
system.device_product_name DGXH100
system.gpu_driver_version 535.129.03
libraries.cuda 12.6.1.005
libraries.pip.lightning 2.4.0.dev20240728
libraries.pip.lightning-thunder 0.2.0.dev0
libraries.pip.lightning-utilities 0.11.6
libraries.pip.litgpt 0.4.11
libraries.pip.nvfuser 0.2.10+git58dfdc1
libraries.pip.pytorch-lightning 2.4.0
libraries.pip.torch 2.5.0a0+git578b8d7
libraries.pip.torchmetrics 1.4.1
libraries.pip.torchvision 0.19.0a0+d23a6e1