Are you attempting to generate a trace similar to this one?
def f(a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
c = torch.exp(a)
d = torch.tanh(b)
e = c.flatten()
e += d.flatten() # new_e = e + d.flatten(); new_c = reshape(new_e, c.shape);
d.div_(a) # new_d = d.div(a)
return c, d # new_c, new_d
yes
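For reference, a minimal repro along these lines (a sketch reconstructed from the example above and the traceback below; the actual snippet.py evidently also returns a third value):

import torch
import thunder

def f(a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    c = torch.exp(a)
    d = torch.tanh(b)
    e = c.flatten()
    e += d.flatten()
    d.div_(a)
    return c, d

a = torch.randn(2, 2, device="cuda")  # shapes and dtype taken from the trace below
b = torch.randn(2, 2, device="cuda")
jit_f = thunder.jit(f)
c, d = jit_f(a, b)  # fails the inplace-copy sanity check with the nvfuser executor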
nvfuser executor doesn't seem to love this logic though (as of ad1f2f9)
Traceback (most recent call last):
File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/snippet.py", line 50, in <module>
main()
File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/snippet.py", line 21, in main
c, d, e = jit_f(a, b)
File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/__init__.py", line 667, in fn_
cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/__init__.py", line 223, in cache_info_wrapper
res = fn(*args, **kwargs)
File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/__init__.py", line 610, in get_computation_and_inputs
thunder.core.transform_common._inplace_copy_sanity_check(computation_trc)
File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/core/transform_common.py", line 84, in _inplace_copy_sanity_check
check(copy_to_arg, "'copy_to' argument")
File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/core/transform_common.py", line 79, in check
raise NotImplementedError(
NotImplementedError: t8 = prims.reshape(e, (2, 2)) # t8: "cuda:0 f32[2, 2]" trying to use e (the 'copy_to' argument of 'prims.copy_') as input, which is not safe. There is a risk of accessing the wrong memory. If you are sure you don't want to use this check, it can be disabled by setting `disable_inplace_copy_check=True` in `thunder.jit`.
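For anyone who wants to bypass the check while this is being sorted out, the error message itself points at the escape hatch (note that this only silences the safety check; it does not fix the potential ordering hazard):

jit_f = thunder.jit(f, disable_inplace_copy_check=True)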
Let's prioritize a design review for in-place operations, and before the review let's share a document with some cases and examples. I'm not really sure how this would address the challenges of in-place operations.
The warning was added explicitly (by @kiya00, IIRC).
nvfuser's support for in-place updates is limited at the moment. We don't guarantee the order of operations in the generated kernel, which means that if you try to use a tensor that has been updated inside a fusion, you might end up with a silently wrong result. Hence we decided to throw a warning here.
We can explore a workaround (WAR) in thunder for this, i.e., we can break the fusion region, which ensures that the use-after-update happens in a follow-up kernel, so there's no chance for nvfuser to change the memory access sequence. But this also requires thunder to ensure the order of the regions, which I don't think we have any restriction on as of today.
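To illustrate, a sketch of how the region from the traceback could be split (the nvFusion names and boundaries here are hypothetical, just to show the intended kernel ordering):

# Kernel 1: compute the update and write it back into e.
# [t5] = nvFusion0(e, t3)
#   t4 = prims.add(e, t3)
#   t5 = prims.copy_(t4, e)
# --- fusion break: the copy_ must complete before the next kernel launches ---
# Kernel 2: the read of the updated e lands in a later kernel, so nvfuser
# cannot reorder it before the write.
# [t8] = nvFusion1(e)
#   t8 = prims.reshape(e, (2, 2))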
I think we should clear up our target first. In the example here (as @IvanYashchuk pointed out), the original program could be functionalized as a whole.
Hence the resulting trace should NOT show any copy_ within. Since we are already calling out-of-place math ops, the memory savings from doing in-place are gone.
Blindly following the semantics here with copy_ seems to meaninglessly add complexity to follow-up mutation passes.
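To make that concrete, a sketch of the fully functionalized program (following the rewrite already spelled out in the comments of the example above; every mutated value gets a fresh name and no copy_ remains):

import torch

def f_functionalized(a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    c = torch.exp(a)
    d = torch.tanh(b)
    e = c.flatten()
    new_e = e + d.flatten()            # was: e += d.flatten()
    new_c = new_e.reshape(c.shape)     # the view update propagated back to c
    new_d = d.div(a)                   # was: d.div_(a)
    return new_c, new_d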
I take it that the trace we are showing in the PR description is just a side effect of our in-place pass not yet functionalizing all the corner cases?!
import torch
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast
def computation(a, b):
# a: "cuda:0 f32[2, 2]"
# b: "cuda:0 f32[2, 2]"
c = torch.exp(a) # c: "cuda:0 f32[2, 2]"
# c = ltorch.exp(a) # c: "cuda:0 f32[2, 2]"
# c = prims.exp(a) # c: "cuda:0 f32[2, 2]"
d = torch.tanh(b) # d: "cuda:0 f32[2, 2]"
# d = ltorch.tanh(b) # d: "cuda:0 f32[2, 2]"
# d = prims.tanh(b) # d: "cuda:0 f32[2, 2]"
del b
e = torch.flatten(c, 0, -1) # e: "cuda:0 f32[4]"
# e = ltorch.flatten(c, 0, -1) # e: "cuda:0 f32[4]"
# e = prims.reshape(c, (4,)) # e: "cuda:0 f32[4]"
t3 = torch.flatten(d, 0, -1) # t3: "cuda:0 f32[4]"
# t3 = ltorch.flatten(d, 0, -1) # t3: "cuda:0 f32[4]"
# t3 = prims.reshape(d, (4,)) # t3: "cuda:0 f32[4]"
t4 = torch.add(e, t3) # t4: "cuda:0 f32[4]"
# t4 = ltorch.add(e, t3, alpha=None) # t4: "cuda:0 f32[4]"
# t4 = prims.add(e, t3) # t4: "cuda:0 f32[4]"
del t3
t5 = copy_(t4, e) # t5: "cuda:0 f32[4]"
del t4
t8 = torch.reshape(e, (2, 2)) # t8: "cuda:0 f32[2, 2]"
# t8 = ltorch.reshape(e, (2, 2)) # t8: "cuda:0 f32[2, 2]"
# t8 = prims.reshape(e, (2, 2)) # t8: "cuda:0 f32[2, 2]"
del e
t9 = copy_(t8, c) # t9: "cuda:0 f32[2, 2]"
del t8, c
t6 = torch.div(d, a) # t6: "cuda:0 f32[2, 2]"
# t6 = ltorch.div(d, a, rounding_mode=None, out=None) # t6: "cuda:0 f32[2, 2]"
# t6 = ltorch.true_divide(d, a) # t6: "cuda:0 f32[2, 2]"
# t6 = prims.div(d, a) # t6: "cuda:0 f32[2, 2]"
del d, a
# /home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/snippet.py:12: d.div_(a)
return (t9, t6, t5)
I take it that the trace we are showing in the PR description is just a side effect of our in-place pass not yet functionalizing all the corner cases?!
What do you think would be a better approach for propagating the in-place op on a view? The trace tries to propagate the in-place update on a view back into the original tensor.
I'm not saying that what you are doing here for view is wrong.
I'm just pointing out that this is a case showing our in-place pass is missing optimization opportunities: an in-place update on an intermediate tensor isn't being optimized away.
Let's move additions to a follow-up.
When the updated originals appear in neither trace.args nor trace.kwargs, then replace the use of the originals in the later bsyms with the reshaped views that have been in-place updated, as in the sketch below.
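A minimal sketch of that replacement (from_bsym_swap_proxies does exist on Thunder's BoundSymbol, but the keying of the swap map, the variableify helper usage, and the input check here are illustrative, not the actual pass):

from thunder.core import prims
from thunder.core.proxies import variableify  # assumed helper for swap-map keys

def redirect_uses_of_updated_views(trace):
    # Function inputs must keep their user-visible identity, so only
    # intermediates are eligible for redirection.
    input_names = {p.name for p in trace.args} | {p.name for p in trace.kwargs.values()}
    swap_map = {}
    new_bsyms = []
    for bsym in trace.bound_symbols:
        # Rewrite reads of tensors that an earlier copy_ already updated.
        new_bsyms.append(bsym.from_bsym_swap_proxies(swap_map))
        if bsym.sym.id == prims.PrimIDs.COPY_:
            copy_from, copy_to = bsym.args[:2]
            if copy_to.name not in input_names:
                swap_map[variableify(copy_to)] = bsym.output
    trace.bound_symbols = new_bsyms
    return trace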