Wondering if @IvanYashchuk has any thoughts/preferences on how we want to fix this?
cc'ing @wujingyue regarding the resharding issue that we just discussed offline.
Thanks for tagging me. This indeed looks like the same symptom I encountered.
For https://github.com/NVIDIA/Fuser/issues/2199, I used to be able to generate a single-nvFusion transformer-block backprop by enabling linear, disabling bookend, and disabling the cudnn and sdpa executors. It failed on me today, and here's a way to reproduce:
pytest thunder/benchmarks/targets.py -k test_nanogpt_block[backward-thunder] -s
The nvFusion is split into two by the following node, where i54 is a NumberProxy:
t711 = torch.sum(t688, i54, True, dtype=None)  # t711: "cuda:0 bf16[16, 25, 128, 1]"
# t711 = ltorch.sum(t688, i54, True, dtype=None) # t711: "cuda:0 bf16[16, 25, 128, 1]"
# b828 = prims.ge(i54, 0) # b828: "bool False"
# b829 = prims.lt(i54, 0) # b829: "bool True"
# i830 = prims.add(i54, 4) # i830: "int 3"
# b831 = prims.ge(i830, 0) # b831: "bool True"
# b832 = prims.lt(i830, 4) # b832: "bool True"
# t833 = ltorch.to(t688, dtypes.float32, None, device=None, dtype=None, copy=False, memory_format=None) # t833: "cuda:0 f32[16, 25, 128, 128]"
# t833 = prims.convert_element_type(t688, dtypes.float32) # t833: "cuda:0 f32[16, 25, 128, 128]"
# t840 = prims.sum(t833, (i830,)) # t840: "cuda:0 f32[16, 25, 128]"
# b841 = prims.eq(i830, 0) # b841: "bool False"
# b842 = prims.eq(i830, 1) # b842: "bool False"
# b843 = prims.eq(i830, 2) # b843: "bool False"
# b844 = prims.eq(i830, 3) # b844: "bool True"
# b845 = prims.eq(i830, 0) # b845: "bool False"
# b846 = prims.eq(i830, 1) # b846: "bool False"
# b847 = prims.eq(i830, 2) # b847: "bool False"
# b848 = prims.eq(i830, 3) # b848: "bool True"
# t849 = prims.broadcast_in_dim(t840, [16, 25, 128, 1], [0, 1, 2]) # t849: "cuda:0 f32[16, 25, 128, 1]"
# t711 = ltorch.to(t849, dtypes.bfloat16, None, device=None, dtype=None, copy=False, memory_format=None) # t711: "cuda:0 bf16[16, 25, 128, 1]"
# t711 = prims.convert_element_type(t849, dtypes.bfloat16) # t711: "cuda:0 bf16[16, 25, 128, 1]"
If a proper fix takes a long time, I'm happy to take a workaround to my local branch to unblock myself.
construct_trace used in the grad transform mistakenly converts all of saved_for_backward into proxies. We have saved_for_backward=(((t, 0.5), None, ([0.5],)),), which was later translated to bw_flat_saved_for_backward=[t, [FloatProxy name=f0, value=0.5], None, [FloatProxy name=f1, value=0.5]] in the backward trace, and that is not right. The backward trace is treating the static number as a proxy f1, instead of baking it in.
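A self-contained toy sketch of that failure mode (FloatProxy here is a stand-in, not Thunder's class): indiscriminately proxifying every number leaf of the flattened saved_for_backward turns static constants into symbolic inputs:

class FloatProxy:
    """Toy stand-in for Thunder's FloatProxy (illustration only)."""
    def __init__(self, name, value):
        self.name, self.value = name, value
    def __repr__(self):
        return f"[FloatProxy name={self.name}, value={self.value}]"

def proxify_all(flat_saved):
    # The buggy behavior: EVERY float leaf becomes a proxy, so the static
    # 0.5 saved by forward turns into a symbolic input of backward.
    out, n = [], 0
    for leaf in flat_saved:
        if isinstance(leaf, float):
            out.append(FloatProxy(f"f{n}", leaf))
            n += 1
        else:
            out.append(leaf)
    return out

# saved_for_backward=(((t, 0.5), None, ([0.5],)),) flattened to leaves:
flat = ["t", 0.5, None, 0.5]
print(proxify_all(flat))
# ['t', [FloatProxy name=f0, value=0.5], None, [FloatProxy name=f1, value=0.5]]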
Why do you think static numbers should be used for backward trace? How does it make a difference today and what would you like to see when you're confident that Thunder supports NumberProxies properly?
For the example provided in the issue description, it's not this piece of code's fault that ab does not even appear in the function signature. This code was written to be independent of the "frontend" that acquires the trace. The current frontend for some reason chooses to drop these numbers from function inputs. Let's go back in time to 955e45397c5757ef8d4f6e94f70d410d7f8ebe4d and we would see ab there:
import torch
import thunder
def foo(t, ab):
    return t * ab
jfoo = thunder.compile(foo)
dtype = torch.float32
t = torch.randn(5, 3, device="cuda").to(dtype=dtype)
t.requires_grad_()
ab = 0.5
out = jfoo(t, ab)
print(thunder.last_traces(jfoo)[0][0])
# Constructed by Augmented forward pass
import thunder
import thunder.core.prims as prims
import torch
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast()
def augmented_forward_fn(t, ab):
    # t: "cuda:0 f32[5, 3]"
    # ab: "float 0.5"
    t0 = prims.mul(t, ab)  # t0: "cuda:0 f32[5, 3]"
    return {'output': t0, 'flat_args': [t, ab], 'flat_output': (t0,)}, ((t,), (0.5,))
Let's put aside the implementation of forward_and_backward_from_trace. Is the problem that there are two nvFuser regions created in backward (https://github.com/Lightning-AI/lightning-thunder/issues/541#issuecomment-2155706083) instead of one? There's no runtime error, right?
I think my example is a bit misleading. I'm not arguing about the behavior of the initial trace (that is cache logic), nor am I arguing about the nvFuser logic (that is merely a side effect of our NumberProxy handling).
Why do you think static numbers should be used for backward trace? How does it make a difference today and what would you like to see when you're confident that Thunder supports NumberProxies properly?
Static numbers should be used for the backward trace here because it's a static number in the forward trace. Likewise, if it's indeed a proxy number on the forward trace and used in the backward trace, we should have a proxy in the bwd trace as well. (This is what @kiya00 needed in #481.)
It matters today since we are ramping up NumberProxy support and we need consistent behavior.
Let's put aside the implementation of forward_and_backward_from_trace.
In your posted trace, when we have ab show up on the forward trace, it actually looks more scary:
@torch.no_grad()
@no_autocast()
def augmented_forward_fn(t, ab):
    # t: "cuda:0 f32[5, 3]"
    # ab: "float 0.5"
    t0 = prims.mul(t, ab)  # t0: "cuda:0 f32[5, 3]"
    return {'output': t0, 'flat_args': [t, ab], 'flat_output': (t0,)}, ((t,), (0.5,))
we are saving 0.5 as a static number for backward, which means that if backward uses that number directly, it could generate a wrong result when the runtime input ab changes...
This leads to my original question: grad_transform should preserve the static/dynamic characteristics of Number/NumberProxy from the forward trace, otherwise our cached program can't even guarantee correctness.
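To make the hazard concrete, a plain-PyTorch sketch: for out = t * ab, the gradient w.r.t. t is ab, so a backward that baked in the first-seen 0.5 would silently produce stale gradients once ab changes:

import torch

def fwd(t, ab):
    return t * ab  # d(out)/dt = ab

t = torch.randn(5, 3, requires_grad=True)
g = torch.ones(5, 3)

baked_grad = g * 0.5  # what a backward with 0.5 baked in would return

out = fwd(t, 2.7)
out.backward(g)
print(torch.allclose(t.grad, g * 2.7))     # True: the gradient tracks runtime ab
print(torch.allclose(t.grad, baked_grad))  # False: the baked-in 0.5 is stale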
Let's put aside the implementation of forward_and_backward_from_trace.
Anyway, maybe forward_and_backward_from_trace is not the root cause.
That's what I'm trying to figure out: what function should I be looking at in order to patch the grad transform so that it handles NumberProxy consistently between fwd/bwd? Should I also be cautious about the impact on the rematerialization pass?
FYI, if we indeed want to have a NumberProxy show up in the inputs, we need to treat it as a NumberProxy, i.e. use the "symbolic values" cache mode in that case.
I'm using this repro for myself as the target for this issue.
import torch
import thunder
def foo(t, ab):
    return t * ab * 0.5
jfoo = thunder.jit(foo, cache="symbolic values")
dtype = torch.float32
t = torch.randn(5, 3, device="cuda").to(dtype=dtype)
t_ref = t.detach()
t.requires_grad_()
t_ref.requires_grad_()
out = jfoo(t, 1.5)
out_ref = foo(t_ref, 1.5)
print("\n\tprologue:\n", thunder.last_prologue_traces(jfoo)[0])
print("\n\tcompute:\n", thunder.last_traces(jfoo)[0])
print("\n\tcompute last trace:\n", thunder.last_traces(jfoo)[-1])
print("\n\tcompute last backward trace:\n", thunder.last_backward_traces(jfoo)[-1])
assert(out.allclose(out_ref))
out.sum().backward()
out_ref.sum().backward()
assert(t.grad.allclose(t_ref.grad))
t.grad = None
t_ref.grad = None
out = jfoo(t, 2.7)
out_ref = foo(t_ref, 2.7)
assert(out.allclose(out_ref))
out.sum().backward()
out_ref.sum().backward()
print(t.grad)
print(t_ref.grad)
assert(t.grad.allclose(t_ref.grad))
what function should I be looking at in order to patch grad transform to consistently handle numberproxy between fwd/bwd?
This is the correct line responsible for unwrapping values from proxies in the forward pass https://github.com/Lightning-AI/lightning-thunder/blob/0342223ac7851beb07d7df731389db777b58f1ac/thunder/core/transforms.py#L3449
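For reference, "unwrapping" here means replacing a NumberProxy with its concrete Python value so the number gets baked into the trace; a toy sketch (illustration only, not the actual code at that line):

class NumberProxy:
    """Toy stand-in for Thunder's proxy type (illustration only)."""
    def __init__(self, name, value):
        self.name, self.value = name, value

def unwrap(x):
    # Replace a proxy with its concrete value, baking the number in.
    return x.value if isinstance(x, NumberProxy) else x

saved = ["t", NumberProxy("f0", 0.5)]
print([unwrap(x) for x in saved])  # ['t', 0.5]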
If rematerialization breaks when this line is removed, it's a bug that needs to be fixed. I don't remember why exactly this line was necessary; probably something was adding the NumberProxies to the trace but never recording the operation that produced them. For the example from https://github.com/Lightning-AI/lightning-thunder/issues/541#issuecomment-2159734931 I see this problem:
File thunder.augmented_forward_fn_2:11, in augmented_forward_fn(t, ab)
5 @torch.no_grad()
6 @no_autocast
7 def augmented_forward_fn(t, ab):
8 # t: "cuda:0 f32[5, 3]"
9 # ab: "float 1.5"
10 [t1] = nvFusion0(ab, t)
---> 11 return {'output': t1, 'flat_args': [t, ab], 'flat_output': (t1,)}, ((), (ab, f1))
NameError: name 'f1' is not defined
Yan's fix seems like a good workaround (https://github.com/Lightning-AI/lightning-thunder/pull/244) and it also fixes the assert. What do you think about it, Jie? Should we merge that PR?
Should I also be cautious about impact on rematerialization pass?
No, rematerialization should be working with constant numbers as well as symbolic ones. Jie, could you please provide a failing example? I'll help fix the problem.
we are saving 0.5 as a static number for backward, which means if backward uses that number directly, it could generate wrong result when runtime input ab changes...
The hope is that current Thunder's caching doesn't allow this.
Static numbers should be used for backward trace here because it's a static number in forward trace.
Here's my reasoning: forward and backward functions are separate pure functions. Part of the forward result is passed to the backward as input. Any number input to Thunder functions should be proxified and that's why backward trace uses number proxies independent of whether it was static or symbolic value in forward.
Unwrapping number proxies for forward function output is bad. We/I will fix it.
Jie, do you think that having numberproxies in the backward trace is bad? Do you think we need to change anything?
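A sketch of that contract (simplified plain-Python signatures, not Thunder's actual generated functions): part of the forward result is handed to backward as a regular input, so every number backward consumes arrives as an argument rather than a re-baked literal:

def augmented_forward(t, ab):
    out = t * ab
    saved_for_backward = ((t,), (ab,))  # what backward will need, passed as values
    return out, saved_for_backward

def backward(saved_for_backward, grad_out):
    # A separate pure function: ab is a (proxified) input here, regardless
    # of whether it was a static number or a symbolic value in forward.
    (t,), (ab,) = saved_for_backward
    return (grad_out * ab,)  # d(t*ab)/dt = ab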
Running the benchmark following the instructions in https://github.com/Lightning-AI/lightning-thunder/issues/541#issuecomment-2155706083 I hit
Traceback (most recent call last):
File "/home/iyashchuk/dev/Fuser/nvfuser/__init__.py", line 146, in execute
result = self._execute(
RuntimeError: h.has_value() INTERNAL ASSERT FAILED at "/home/iyashchuk/dev/Fuser/csrc/fusion_segmenter.cpp":3671, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Can not find a scheduler to schedule fusion segment
I probably need to update the nvFuser installation. Or is this the error that we need to fix?
Here's my reasoning: forward and backward functions are separate pure functions. Part of the forward result is passed to the backward as input. Any number input to Thunder functions should be proxified and that's why backward trace uses number proxies independent of whether it was static or symbolic value in forward. Unwrapping number proxies for forward function output is bad. We/I will fix it. Jie, do you think that having numberproxies in the backward trace is bad? Do you think we need to change anything?
Glad to see that we are on the same page here.
Part of the forward result is passed to the backward as input. Any number input to Thunder functions should be proxified
Yeah, I think that's where we should have it fixed. I don't think having NumberProxies in the backward trace is bad; it's necessary. But they need to faithfully reflect what they are in the forward trace.
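A minimal sketch of that rule (toy types and helper names, not Thunder's API): only leaves that were NumberProxies in the forward trace get a proxy in the backward trace; plain Python numbers stay baked in:

class NumberProxy:
    """Toy stand-in for Thunder's NumberProxy."""
    def __init__(self, name, value):
        self.name, self.value = name, value

def bwd_saved_leaf(leaf, fresh_proxy):
    if isinstance(leaf, NumberProxy):
        # Symbolic in forward -> stays a symbolic input to backward.
        return fresh_proxy(leaf.value)
    # Static in forward -> baked into backward as a constant.
    return leaf

# Mixed saved_for_backward: f0 was traced symbolically, 0.5 was static.
saved = [NumberProxy("f0", 1.5), 0.5]
counter = iter(range(100))
flat = [bwd_saved_leaf(x, lambda v: NumberProxy(f"b{next(counter)}", v)) for x in saved]
print([type(x).__name__ for x in flat])  # ['NumberProxy', 'float']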
I don't think Yan's PR is enough as-is. But I think it's in the right direction. We can follow up and work together on that one.
Running the benchmark following the instructions in #541 (comment) I hit
Traceback (most recent call last):
  File "/home/iyashchuk/dev/Fuser/nvfuser/__init__.py", line 146, in execute
    result = self._execute(
RuntimeError: h.has_value() INTERNAL ASSERT FAILED at "/home/iyashchuk/dev/Fuser/csrc/fusion_segmenter.cpp":3671, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Can not find a scheduler to schedule fusion segment
I probably need to update the nvFuser installation. Or is this the error that we need to fix?
I forgot which nvFuser version I was using... The benchmark ran fine when I just resynced to https://github.com/NVIDIA/Fuser/commit/b56c3e7960f18e93c2d806d3167c961ebe7f2b20. Can you give it another try?
Linking the resnet issue regarding the grad transform: https://github.com/Lightning-AI/lightning-thunder/pull/451#issuecomment-2221056914
🐛 Bug
construct_trace used in the grad transform mistakenly converts all of saved_for_backward into proxies.
https://github.com/Lightning-AI/lightning-thunder/blob/0342223ac7851beb07d7df731389db777b58f1ac/thunder/core/transforms.py#L74-L77
https://github.com/Lightning-AI/lightning-thunder/blob/0342223ac7851beb07d7df731389db777b58f1ac/thunder/core/transforms.py#L3619-L3620
For a simple program like this:
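Presumably the foo example quoted in the comments above; a sketch:

import torch
import thunder

def foo(t, ab):
    return t * ab

jfoo = thunder.jit(foo)
t = torch.randn(5, 3, device="cuda", requires_grad=True)
out = jfoo(t, 0.5)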
The transform gives us:
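Presumably the augmented forward trace quoted in the comments above:

@torch.no_grad()
@no_autocast()
def augmented_forward_fn(t, ab):
    # t: "cuda:0 f32[5, 3]"
    # ab: "float 0.5"
    t0 = prims.mul(t, ab)  # t0: "cuda:0 f32[5, 3]"
    return {'output': t0, 'flat_args': [t, ab], 'flat_output': (t0,)}, ((t,), (0.5,))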
We have saved_for_backward=(((t, 0.5), None, ([0.5],)),), which was later translated to bw_flat_saved_for_backward=[t, [FloatProxy name=f0, value=0.5], None, [FloatProxy name=f1, value=0.5]] in the backward trace, and that is not right. The backward trace is treating that as a proxy f1, instead of baking static numbers in.
Context
We need to support NumberProxy passed from fwd to bwd. Currently all NumberProxies are baked in as constant here: https://github.com/Lightning-AI/lightning-thunder/blob/0342223ac7851beb07d7df731389db777b58f1ac/thunder/core/transforms.py#L3449
When we remove that line, it breaks the rematerialization pass, since it assumes saved_for_backward is consistent between forward and backward:
https://github.com/Lightning-AI/lightning-thunder/blob/0342223ac7851beb07d7df731389db777b58f1ac/thunder/core/rematerialization.py#L635
https://github.com/Lightning-AI/lightning-thunder/blob/0342223ac7851beb07d7df731389db777b58f1ac/thunder/core/rematerialization.py#L643
Note that it's using the same new_required_for_backward for both forward and backward. In the example above, the fwd trace gets transformed so that the new_fw_trace is trying to save an f1 NumberProxy, which was never a number proxy on the fwd graph.
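To illustrate the inconsistency, a simplified sketch (names assumed, not the actual rematerialization code): one set of required-for-backward names is computed and applied to both traces, so the new forward can be asked to save a proxy it never produced:

# Names actually defined in each trace for the example above (assumed):
fwd_defined = {"t", "t0"}        # forward never produces f0/f1
bwd_inputs = ["t", "f0", "f1"]   # flat saved_for_backward, proxified

# Rematerialization picks ONE list and uses it for both traces:
new_required_for_backward = bwd_inputs

# Applying it to the forward trace asks it to return undefined names:
missing = [n for n in new_required_for_backward if n not in fwd_defined]
print(missing)  # ['f0', 'f1'] -> new_fw_trace would save proxies fwd never made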