Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0

Partially support in-place ops and tensor aliases #597

Closed · crcrpar closed 3 months ago

crcrpar commented 3 months ago

When a view of a tensor (e.g. a reshaped tensor) is updated in place, then replace the uses of the original tensor in the later bsyms with the reshaped views that have been in-place updated.

IvanYashchuk commented 3 months ago

Are you attempting to generate a trace similar to this one?

def f(a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    c = torch.exp(a)
    d = torch.tanh(b)
    e = c.flatten()
    e += d.flatten() # new_e = e + d.flatten(); new_c = reshape(new_e, c.shape);
    d.div_(a) # new_d = d.div(a)
    return c, d # new_c, new_d 
crcrpar commented 3 months ago

Are you attempting to generate a trace similar to this one?

def f(a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    c = torch.exp(a)
    d = torch.tanh(b)
    e = c.flatten()
    e += d.flatten() # new_e = e + d.flatten(); new_c = reshape(new_e, c.shape);
    d.div_(a) # new_d = d.div(a)
    return c, d # new_c, new_d 

yes

crcrpar commented 3 months ago

nvfuser executor doesn't seem to love this logic though (as of ad1f2f9)

Traceback (most recent call last):
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/snippet.py", line 50, in <module>
    main()
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/snippet.py", line 21, in main
    c, d, e = jit_f(a, b)
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/__init__.py", line 667, in fn_
    cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/__init__.py", line 223, in cache_info_wrapper
    res = fn(*args, **kwargs)
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/__init__.py", line 610, in get_computation_and_inputs
    thunder.core.transform_common._inplace_copy_sanity_check(computation_trc)
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/core/transform_common.py", line 84, in _inplace_copy_sanity_check
    check(copy_to_arg, "'copy_to' argument")
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/core/transform_common.py", line 79, in check
    raise NotImplementedError(
NotImplementedError: t8 = prims.reshape(e, (2, 2))  # t8: "cuda:0 f32[2, 2]" trying to use e (the 'copy_to' argument of 'prims.copy_') as input, which is not safe. There is a risk of accessing the wrong memory. If you are sure you don't want to use this check, it can be disabled by setting `disable_inplace_copy_check=True` in `thunder.jit`.
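
As the message notes, the check can be bypassed (at the user's own risk) through thunder.jit. A minimal sketch, assuming the two-tensor function f from the example above and CUDA inputs of the same shape as in the trace:

import torch
import thunder

a = torch.randn(2, 2, device="cuda")  # hypothetical inputs matching "cuda:0 f32[2, 2]"
b = torch.randn(2, 2, device="cuda")

# Disables the in-place copy sanity check quoted above; only do this if you are
# sure no fused kernel reads a tensor after prims.copy_ has written into it.
jit_f = thunder.jit(f, disable_inplace_copy_check=True)
c, d = jit_f(a, b)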
mruberry commented 3 months ago

Let's prioritize a design review for in-place operations, and before the review let's share a document with some cases and examples. I'm not really sure how this would address the challenges of in-place operations.

jjsjann123 commented 3 months ago

nvfuser executor doesn't seem to love this logic though (as of ad1f2f9)

Traceback (most recent call last):
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/snippet.py", line 50, in <module>
    main()
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/snippet.py", line 21, in main
    c, d, e = jit_f(a, b)
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/__init__.py", line 667, in fn_
    cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/__init__.py", line 223, in cache_info_wrapper
    res = fn(*args, **kwargs)
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/__init__.py", line 610, in get_computation_and_inputs
    thunder.core.transform_common._inplace_copy_sanity_check(computation_trc)
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/core/transform_common.py", line 84, in _inplace_copy_sanity_check
    check(copy_to_arg, "'copy_to' argument")
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/core/transform_common.py", line 79, in check
    raise NotImplementedError(
NotImplementedError: t8 = prims.reshape(e, (2, 2))  # t8: "cuda:0 f32[2, 2]" trying to use e (the 'copy_to' argument of 'prims.copy_') as input, which is not safe. There is a risk of accessing the wrong memory. If you are sure you don't want to use this check, it can be disabled by setting `disable_inplace_copy_check=True` in `thunder.jit`.

The warning is added explicitly (by @kiya00 IIRC).

nvfuser's support for in-place updates is limited at the moment. We don't guarantee the order of operations in the generated kernel, which means that if you try to use a tensor that has been updated inside a fusion, you might end up with a silently wrong result. Hence we decided to throw a warning here.
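
Concretely, the unsafe pattern the check flags is the pair of operations below (an annotated excerpt based on the trace in this thread, not new output):

t5 = prims.copy_(t4, e)        # writes the contents of t4 into e's storage
t8 = prims.reshape(e, (2, 2))  # reads e; within a single fusion it is unspecified
                               # whether this sees the pre- or post-copy values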

We can explore a WAR (workaround) in Thunder for this, i.e. we can break the fusion region, which ensures that the use after the update happens in a follow-up kernel, so there's no chance for nvfuser to change the memory-access sequence. But this also requires Thunder to ensure the order of the regions, which I don't think we place any restriction on as of today.
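
For illustration, a hand-written sketch of that split (not actual Thunder output; the region names are hypothetical):

# The fusion is broken at the in-place update, so the consumer of the updated
# tensor runs in a later kernel and the kernel boundary enforces the
# write-before-read ordering.
#
# nvFusion0: ..., t4 = prims.add(e, t3), t5 = prims.copy_(t4, e)
# nvFusion1: t8 = prims.reshape(e, (2, 2)), t9 = prims.copy_(t8, c), ...
#
# Thunder would additionally have to preserve the relative order of the two
# regions when scheduling them.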

jjsjann123 commented 3 months ago

Are you attempting to generate a trace similar to this one?

def f(a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    c = torch.exp(a)
    d = torch.tanh(b)
    e = c.flatten()
    e += d.flatten() # new_e = e + d.flatten(); new_c = reshape(new_e, c.shape);
    d.div_(a) # new_d = d.div(a)
    return c, d # new_c, new_d 

yes

I think we should clear up our target first. In the example here (as @IvanYashchuk pointed out), the original program could be functionalized as a whole.

Hence the resulting trace should NOT contain any copy_. Since we are already calling out-of-place math ops, the memory savings from doing the update in place are gone. Blindly following the in-place semantics with copy_ seems to add complexity to follow-up mutation passes for no benefit.

I take it that the trace we are showing in the PR description is just a side effect of our in-place pass not yet functionalizing all the corner cases?

import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(a, b):
  # a: "cuda:0 f32[2, 2]"
  # b: "cuda:0 f32[2, 2]"
  c = torch.exp(a)  # c: "cuda:0 f32[2, 2]"
    # c = ltorch.exp(a)  # c: "cuda:0 f32[2, 2]"
      # c = prims.exp(a)  # c: "cuda:0 f32[2, 2]"
  d = torch.tanh(b)  # d: "cuda:0 f32[2, 2]"
    # d = ltorch.tanh(b)  # d: "cuda:0 f32[2, 2]"
      # d = prims.tanh(b)  # d: "cuda:0 f32[2, 2]"
  del b
  e = torch.flatten(c, 0, -1)  # e: "cuda:0 f32[4]"
    # e = ltorch.flatten(c, 0, -1)  # e: "cuda:0 f32[4]"
      # e = prims.reshape(c, (4,))  # e: "cuda:0 f32[4]"
  t3 = torch.flatten(d, 0, -1)  # t3: "cuda:0 f32[4]"
    # t3 = ltorch.flatten(d, 0, -1)  # t3: "cuda:0 f32[4]"
      # t3 = prims.reshape(d, (4,))  # t3: "cuda:0 f32[4]"
  t4 = torch.add(e, t3)  # t4: "cuda:0 f32[4]"
    # t4 = ltorch.add(e, t3, alpha=None)  # t4: "cuda:0 f32[4]"
      # t4 = prims.add(e, t3)  # t4: "cuda:0 f32[4]"
  del t3
  t5 = copy_(t4, e)  # t5: "cuda:0 f32[4]"
  del t4
  t8 = torch.reshape(e, (2, 2))  # t8: "cuda:0 f32[2, 2]"
    # t8 = ltorch.reshape(e, (2, 2))  # t8: "cuda:0 f32[2, 2]"
      # t8 = prims.reshape(e, (2, 2))  # t8: "cuda:0 f32[2, 2]"
  del e
  t9 = copy_(t8, c)  # t9: "cuda:0 f32[2, 2]"
  del t8, c
  t6 = torch.div(d, a)  # t6: "cuda:0 f32[2, 2]"
    # t6 = ltorch.div(d, a, rounding_mode=None, out=None)  # t6: "cuda:0 f32[2, 2]"
      # t6 = ltorch.true_divide(d, a)  # t6: "cuda:0 f32[2, 2]"
        # t6 = prims.div(d, a)  # t6: "cuda:0 f32[2, 2]"
  del d, a

  # /home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/snippet.py:12:      d.div_(a)
  return (t9, t6, t5)
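
For comparison, a fully functionalized version of the original program (a hand-written sketch following the annotations in @IvanYashchuk's example above, not something Thunder generates today) would need no copy_ at all:

import torch

def functionalized_f(a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    c = torch.exp(a)
    d = torch.tanh(b)
    new_e = c.flatten() + d.flatten()  # e += d.flatten(), done out of place
    new_c = new_e.reshape(c.shape)     # propagate the view update back to c
    new_d = d.div(a)                   # d.div_(a), done out of place
    return new_c, new_d                # same values as (c, d) after the in-place ops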
crcrpar commented 3 months ago

I take it that the trace we are showing in the PR description is just a side effect of our in-place pass not yet functionalizing all the corner cases?

What do you think would be a better approach for propagating an in-place op on a view? The trace tries to propagate the in-place update on a view back into the original tensor.

jjsjann123 commented 3 months ago

I take it that the trace we are showing in the PR description is just a side effect of our in-place pass not yet functionalizing all the corner cases?

What do you think would be a better approach for propagating an in-place op on a view? The trace tries to propagate the in-place update on a view back into the original tensor.

I'm not saying that what you are doing here for views is wrong.

I'm just pointing out that this is a case where our in-place pass is missing an optimization opportunity: an in-place update on an intermediate tensor isn't being optimized away.

t-vi commented 3 months ago

Let's move additions to a follow-up.