Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source to source compiler for PyTorch. It enables using different hardware executors at once; across one or thousands of GPUs.
Apache License 2.0
1.15k stars 77 forks source link

Copy to the original tensors #585

Closed crcrpar closed 3 months ago

crcrpar commented 3 months ago

Note that the nvFuser executor does not work with this implementation. The generated trace and the resulting error are:

import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(a, b):
  # NOTE(review): quoted execution trace for the failing nvFuser run; the
  # indented comments under the nvFusion0 call appear to list the prims that
  # were placed in that fusion region.
  # a: "cuda:0 f32[2, 2]"
  # b: "cuda:0 f32[2, 2]"
  [c, d] = nvFusion0(a, b)
    # c = prims.exp(a)  # c: "cuda:0 f32[2, 2]"
    # d = prims.tanh(b)  # d: "cuda:0 f32[2, 2]"
    # t2 = prims.add(c, d)  # t2: "cuda:0 f32[2, 2]"
    # t6 = prims.sub(t5, b)  # t6: "cuda:0 f32[2, 2]"
    # NOTE(review): t5 is not defined anywhere in this listing, which matches
    # the KeyError: 't5' raised from the `sub` translator in the traceback.
    # The torch-executor trace in this issue uses t4 = prims.div(d, a) as this
    # operand; no div appears in this fusion region.
    # prims.copy_(t2, c)
    # prims.copy_(t6, d)
  del a, b
  return (c, d)
Traceback (most recent call last):
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/snipet.py", line 28, in <module>
    main()
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/snipet.py", line 19, in main
    c, d = jit_f(a, b)
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/__init__.py", line 662, in fn_
    result = cache_entry.computation_fn(*inps)
  File "/home/mkozuki/ghq/github.com/crcrpar/pytorch/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/mkozuki/ghq/github.com/crcrpar/pytorch/torch/amp/autocast_mode.py", line 16, in decorate_autocast
    return func(*args, **kwargs)
  File "/home/mkozuki/ghq/github.com/crcrpar/pytorch/torch/amp/autocast_mode.py", line 16, in decorate_autocast
    return func(*args, **kwargs)
  File "thunder.computation_1", line 10, in computation
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/executors/nvfuserex_impl.py", line 402, in __call__
    fd = self.get_fd(to_descriptors(args))
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/executors/nvfuserex_impl.py", line 512, in get_fd
    return create_fd(bsyms, input_descriptors, sorted_unique_inputs, sorted_unique_outputs)
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/executors/nvfuserex_impl.py", line 274, in create_fd
    translate_bound_symbol(bsym)
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/executors/nvfuserex_impl.py", line 264, in translate_bound_symbol
    nvresults = translator(*bsym.args, **bsym.kwargs, fd=fd, lc_to_nv_map=lc_to_nv_map)
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/executors/nvfuserex_impl.py", line 1842, in sub
    nva = getnv(a, fd, lc_to_nv_map)
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/executors/nvfuserex_impl.py", line 116, in getnv
    return lc_to_nv_map[x]
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/core/utils.py", line 919, in __getitem__
    return self._dict[key_]
KeyError: 't5'

Using only the torch executor generates the following trace:

# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(a, b):
  # NOTE(review): quoted execution trace for the torch-executor run; each torch
  # call is followed by indented comments showing the ltorch/prims symbols it
  # was lowered from.
  # a: "cuda:0 f32[2, 2]"
  # b: "cuda:0 f32[2, 2]"
  c = torch.exp(a)  # c: "cuda:0 f32[2, 2]"
    # c = ltorch.exp(a)  # c: "cuda:0 f32[2, 2]"
      # c = prims.exp(a)  # c: "cuda:0 f32[2, 2]"
  d = torch.tanh(b)  # d: "cuda:0 f32[2, 2]"
    # d = ltorch.tanh(b)  # d: "cuda:0 f32[2, 2]"
      # d = prims.tanh(b)  # d: "cuda:0 f32[2, 2]"
  t2 = torch.add(c, d)  # t2: "cuda:0 f32[2, 2]"
    # t2 = ltorch.add(c, d, alpha=None)  # t2: "cuda:0 f32[2, 2]"
      # t2 = prims.add(c, d)  # t2: "cuda:0 f32[2, 2]"
  t4 = torch.div(d, a)  # t4: "cuda:0 f32[2, 2]"
    # t4 = ltorch.div(d, a, rounding_mode=None, out=None)  # t4: "cuda:0 f32[2, 2]"
      # t4 = ltorch.true_divide(d, a)  # t4: "cuda:0 f32[2, 2]"
        # t4 = prims.div(d, a)  # t4: "cuda:0 f32[2, 2]"
  del a
  t6 = torch.sub(t4, b)  # t6: "cuda:0 f32[2, 2]"
    # t6 = ltorch.sub(t4, b, alpha=None)  # t6: "cuda:0 f32[2, 2]"
      # t6 = prims.sub(t4, b)  # t6: "cuda:0 f32[2, 2]"
  del t4, b
  # NOTE(review): the two copy_ calls presumably write the out-of-place results
  # (t2, t6) back into c and d, mirroring the in-place updates in the original
  # function -- confirm against the copy_ implementation. copy_ is not brought
  # into scope by the imports shown with this trace.
  copy_(t2, c)
  del t2
  copy_(t6, d)
  del t6
  return (c, d)

The snippet used is:

import torch
import thunder

def f(a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Return ``(exp(a) + tanh(b), tanh(b) / a - b)`` built with in-place updates.

    The in-place operations are applied to the intermediates ``c`` and ``d``,
    not directly to the arguments ``a`` and ``b``.
    """
    c = torch.exp(a)
    d = torch.tanh(b)
    # In-place add: c uses d's value before d itself is modified below.
    c += d

    # Two chained in-place updates on d: divide by a, then subtract b.
    d.div_(a)
    d.sub_(b)
    return c, d

def main():
    """Run ``f`` through ``thunder.jit`` and eagerly, then compare the outputs.

    The last trace recorded for the jitted function is printed before the
    jitted results are checked against the eager ones.
    """
    # Two random 2x2 CUDA inputs, drawn in the order a, b.
    a = torch.randn((2, 2), device="cuda", requires_grad=False)
    b = torch.randn((2, 2), device="cuda", requires_grad=False)
    # Separate detached copies feed the eager reference call.
    ref_a = a.clone().detach()
    ref_b = b.clone().detach()

    # To restrict the run to the torch executor, pass
    # executors=[thunder.executors.get_torch_executor()] to thunder.jit.
    compiled = thunder.jit(f)
    c, d = compiled(a, b)
    expected_c, expected_d = f(ref_a, ref_b)
    print(thunder.last_traces(compiled)[-1])

    for actual, expected in ((c, expected_c), (d, expected_d)):
        torch.testing.assert_close(actual, expected)

# Run the comparison only when this file is executed as a script.
if __name__ == "__main__":
    main()