Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0

`fuse_bound_symbols` (toposort on bsyms) can put `return` before `copy_` on arguments #912

Closed shino16 closed 2 weeks ago

shino16 commented 1 month ago

πŸ› Bug

`fuse_bound_symbols` (toposort on bsyms) puts `return` before `copy_` on arguments.

This causes an `AssertionError` when the `torchcompile` or `cudagraphex` executor is applied to in-place operations.

Code sample

import torch
import thunder

def f(a):
    a += 1

jitted = thunder.jit(f, executors=[thunder.get_executor("torchcompile")])
jitted(torch.tensor(0))
...
  File "/opt/pytorch/lightning-thunder/thunder/common.py", line 873, in transform_to_torch_types
    assert last.sym.id == prims.PrimIDs.RETURN
AssertionError

Before reaching this line, `TorchCompileExecution` generates:

# Constructed by Transform for execution (took 3 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(a):
  # a: "cpu i64[]"

  # /opt/pytorch/lightning-thunder/mshinokawa/sandbox/debug.py:31:      a += 1
  return None
  TorchCompile0(a)
    # t0 = ltorch.add(a, 1, alpha=None)  # t0: "cpu i64[]"
      # t0 = prims.add(a, 1)  # t0: "cpu i64[]"
    # prims.copy_(t0, a)

Returning `a` at the end of `f` does not fix this either; the trace still places `return` before the fusion region.

# Constructed by Transform for execution (took 3 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(a):
  # a: "cpu i64[]"

  # /opt/pytorch/lightning-thunder/mshinokawa/sandbox/debug.py:5:       a += 1
  return a
  TorchCompile0(a)
    # t0 = ltorch.add(a, 1, alpha=None)  # t0: "cpu i64[]"
      # t0 = prims.add(a, 1)  # t0: "cpu i64[]"
    # prims.copy_(t0, a)
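
To see why returning `a` doesn't help, here is a minimal sketch (hypothetical names, not Thunder's actual data structures) of an output-based dependency scan: `a` is a function input, so no bound symbol produces it, and `copy_` declares no output at all, so `return a` still gains no edge to the copy.

```python
# Sketch with made-up names (not Thunder's data structures): dependencies
# are found by mapping produced values to their producers. `a` is a
# function input, so it has no producer; `copy_` declares no output, so
# `return a` still has no recorded dependency on it.
producer = {"t0": "add"}            # add produces t0; copy_ produces nothing
return_args = ("a",)                # returning `a` instead of None
deps = {x for x in return_args if x in producer}
print(deps)  # set() -- no edge from `return` to `copy_`
```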

Cause

Inside `fuse_bound_symbols`, `Graph.__init__` constructs a dependency graph and applies a topological sort to reorder bound symbols. Dependencies are discovered via `bsym.flat_args` and `bsym.flat_outs`, but `prims.copy_` has no output. Hence the algorithm never records a dependency between `copy_(t0, a)` and `return a`, and the sort is free to emit `return` first.
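
The behavior can be illustrated with a toy toposort (a sketch with made-up names, not Thunder's actual `Graph`): edges only run from a value's producer to its consumers, so a bound symbol with no outputs can never be a prerequisite of anything.

```python
from collections import namedtuple

# Toy sketch (made-up names, not Thunder's Graph): dependency edges come
# only from declared outputs, mirroring discovery via flat_args/flat_outs.
BSym = namedtuple("BSym", ["name", "args", "outs"])

def toposort(bsyms):
    # Map each produced value to its producer; a bsym with no outs never
    # appears here, so nothing can ever depend on it.
    producer = {out: b for b in bsyms for out in b.outs}
    ordered, visited = [], set()

    def visit(b):
        if b.name in visited:
            return
        visited.add(b.name)
        for arg in b.args:
            if arg in producer:
                visit(producer[arg])
        ordered.append(b)

    for b in bsyms:
        visit(b)
    return ordered

trace = [
    BSym("add", args=("a",), outs=("t0",)),
    BSym("copy_", args=("t0", "a"), outs=()),  # in-place: no declared output
    BSym("return", args=("a",), outs=()),
]

# Any visitation order yields a legal toposort; visiting `return` early
# places it before `copy_` because no edge connects them.
order = [b.name for b in toposort(list(reversed(trace)))]
print(order)  # ['return', 'add', 'copy_']
```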

nvFuser uses this function too, but it works around the problem by forcibly moving the `return` statement after all the other bound symbols (here).
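
That workaround amounts to a simple post-pass (assumed shape, not the actual nvFuser executor code): after sorting, partition out the RETURN bound symbol and append it last.

```python
# Sketch of the workaround (assumed shape, not the actual nvFuser code):
# after reordering, force the RETURN bound symbol to come last.
def move_return_last(bsym_names):
    non_returns = [n for n in bsym_names if n != "return"]
    returns = [n for n in bsym_names if n == "return"]
    return non_returns + returns

fixed = move_return_last(["return", "add", "copy_"])
print(fixed)  # ['add', 'copy_', 'return']
```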

Related issue

#229 suggested adding `TorchCompileExecutor` to the tests alongside `TorchExecutor` and `nvFuserExecutor`.

shino16 commented 1 month ago

`test_inplace_to_arg_return_value` in `thunder/tests/test_inplace_functionalization.py` currently fails if you give it `executors=(thunder.tests.framework.TorchCompileExecutor,)`.

IvanYashchuk commented 1 month ago

Having `prims.copy_(t0, a)` without any output in the trace should not be considered a valid program for execution. The `return` symbol could take fake inputs that are never actually returned during Python execution, solely to build a dependency in the DAG. I expect the computation trace to look something like:

def computation(a):
  # a: "cpu i64[]"
  t1 = TorchCompile0(a)
    # t0 = ltorch.add(a, 1, alpha=None)  # t0: "cpu i64[]"
      # t0 = prims.add(a, 1)  # t0: "cpu i64[]"
    # t1 = prims.copy_(t0, a)
  return None # t1 <--- "t1" is a hidden output used only for building dataflow graph
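
A sketch of that idea (hypothetical names, not Thunder's real trace machinery): once `copy_` declares an output `t1` and `return` lists `t1` among its flat args, an output-based dependency scan records the edge, so the sort can no longer hoist `return` above the copy.

```python
from collections import namedtuple

# Hypothetical sketch (not Thunder's real trace machinery): give copy_ an
# output t1 and thread it through RETURN's args, so an output-based
# dependency scan records the copy_ -> return edge.
BSym = namedtuple("BSym", ["name", "args", "outs"])

trace = [
    BSym("add", args=("a",), outs=("t0",)),
    BSym("copy_", args=("t0", "a"), outs=("t1",)),  # now declares an output
    BSym("return", args=("t1",), outs=()),          # t1 is a hidden input
]

producer = {out: b.name for b in trace for out in b.outs}
return_bsym = trace[-1]
deps = {producer[a] for a in return_bsym.args if a in producer}
print(deps)  # {'copy_'}
```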
shino16 commented 1 month ago

Hi! It would be possible for the functionalizer to return the outputs of `copy_`s as 'hidden' arguments of `return`. May I ask why you think having `copy_` with no output (and handling it as a special case) is a bad idea?

IvanYashchuk commented 1 month ago

Using `copy_` (and any other in-place operation) with no output is not a bad idea in general. In PyTorch eager mode, nothing can reorder operations, but in Thunder any trace transformation pass is free to do so. All in-place operations in the trace should keep the relative ordering prescribed in the original user script. Handling `copy_` as a special case in `data_dependent_partition.py` might not be ideal because:

What do you think? I'm open to further discussion of alternative ideas or specific use cases in mind that might benefit from a different approach.

shino16 commented 1 month ago

The first two reasons were especially convincing. If we're not relating copies and `return` in the sorter code, then a reasonable choice is to do so as soon as functionalization generates `prims.copy_`.

I'm working on implementing this!