
functionalization fails to replace in-place operands with new tensors #652

Closed crcrpar closed 3 days ago

crcrpar commented 3 days ago


šŸ› Bug

Functionalization fails to update the arguments of some downstream operations when replacing in-place operands with new tensors.

Code sample

assert_close(e, e_) in the snippet below fails with:

Traceback (most recent call last):
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/issue_foobarbaz.py", line 33, in <module>
    main()
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/issue_foobarbaz.py", line 29, in main
    torch.testing.assert_close(e, e_)
  File "/home/mkozuki/ghq/github.com/crcrpar/pytorch/torch/testing/_comparison.py", line 1524, in assert_close
    raise error_metas[0].to_error(msg)
AssertionError: Tensor-likes are not close!

Mismatched elements: 4 / 4 (100.0%)
Greatest absolute difference: 0.4451824426651001 at index (0,) (up to 1e-05 allowed)
Greatest relative difference: 0.4193967282772064 at index (0,) (up to 1.3e-06 allowed)
import torch
import thunder

def f(a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    c = torch.exp(a)
    d = torch.tanh(b)

    e = c.view(-1)
    e.add_(d.flatten())

    d.div_(a)
    return c, d, e / 2.0

def main():
    a, b = [torch.randn((2, 2), device="cuda", requires_grad=False) for _ in range(2)]
    a_, b_ = a.clone().detach(), b.clone().detach()
    jitted_f = thunder.jit(f)
    c, d, e = jitted_f(a, b)
    c_, d_, e_ = f(a_, b_)

    traces: list[thunder.core.trace.TraceCtx] = thunder.last_traces(jitted_f)
    print(traces[0])
    print(traces[2])

    torch.testing.assert_close(c, c_)
    torch.testing.assert_close(d, d_)
    torch.testing.assert_close(e, e_)

if __name__ == "__main__":
    main()
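
For reference, here is a sketch (not part of the original report; the names c0, e1, etc. are illustrative) of the out-of-place program functionalization is expected to produce for f. Because e is a view of c, the in-place add must propagate into both tensors, and the final e / 2.0 must read the updated value:

import torch

def f_functional(a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    c0 = torch.exp(a)
    d0 = torch.tanh(b)
    e0 = c0.view(-1)
    e1 = e0 + d0.flatten()   # out-of-place version of e.add_(d.flatten())
    c1 = e1.view(2, 2)       # e is a view of c, so c sees the update too
    d1 = d0 / a              # out-of-place version of d.div_(a)
    return c1, d1, e1 / 2.0  # the final use of e must read e1, not e0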

Traces before/after functionalization are as follows:

Before functionalization:

def computation(a, b):
  # a: "cuda:0 f32[2, 2]"
  # b: "cuda:0 f32[2, 2]"

  # /home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/issue_foobarbaz.py:6:       c = torch.exp(a)
  c = ltorch.exp(a)  # c: "cuda:0 f32[2, 2]"
    # c = prims.exp(a)  # c: "cuda:0 f32[2, 2]"

  # /home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/issue_foobarbaz.py:7:       d = torch.tanh(b)
  d = ltorch.tanh(b)  # d: "cuda:0 f32[2, 2]"
    # d = prims.tanh(b)  # d: "cuda:0 f32[2, 2]"

  # /home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/issue_foobarbaz.py:9:       e = c.view(-1)
  e = ltorch.view(c, -1)  # e: "cuda:0 f32[4]"
    # e = ltorch.reshape(c, (-1,))  # e: "cuda:0 f32[4]"
      # e = prims.reshape(c, (4,))  # e: "cuda:0 f32[4]"

  # /home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/issue_foobarbaz.py:10:      e.add_(d.flatten())
  t3 = ltorch.flatten(d, 0, -1)  # t3: "cuda:0 f32[4]"
    # t3 = prims.reshape(d, (4,))  # t3: "cuda:0 f32[4]"
  t5 = ltorch.add_(e, t3, alpha=None)  # t5: "cuda:0 f32[4]"
    # t4 = ltorch.add(e, t3, alpha=None)  # t4: "cuda:0 f32[4]"
      # t4 = prims.add(e, t3)  # t4: "cuda:0 f32[4]"
    # t5 = prims.copy_(t4, e)  # t5: "cuda:0 f32[4]"

  # /home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/issue_foobarbaz.py:12:      d.div_(a)
  t7 = ltorch.div_(d, a, rounding_mode=None)  # t7: "cuda:0 f32[2, 2]"
    # t6 = ltorch.div(d, a, rounding_mode=None, out=None)  # t6: "cuda:0 f32[2, 2]"
      # t6 = ltorch.true_divide(d, a)  # t6: "cuda:0 f32[2, 2]"
        # t6 = prims.div(d, a)  # t6: "cuda:0 f32[2, 2]"
    # t7 = prims.copy_(t6, d)  # t7: "cuda:0 f32[2, 2]"

  # /home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/issue_foobarbaz.py:13:      return c, d, e / 2.0
  result = ltorch.true_divide(e, 2.0)  # result: "cuda:0 f32[4]"
    # result = prims.div(e, 2.0)  # result: "cuda:0 f32[4]"
  return (c, d, result)

After functionalization:

def computation(a, b):
  # a: "cuda:0 f32[2, 2]"
  # b: "cuda:0 f32[2, 2]"

  # /home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/issue_foobarbaz.py:6:       c = torch.exp(a)
  c = ltorch.exp(a)  # c: "cuda:0 f32[2, 2]"
    # c = prims.exp(a)  # c: "cuda:0 f32[2, 2]"

  # /home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/issue_foobarbaz.py:7:       d = torch.tanh(b)
  d = ltorch.tanh(b)  # d: "cuda:0 f32[2, 2]"
    # d = prims.tanh(b)  # d: "cuda:0 f32[2, 2]"

  # /home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/issue_foobarbaz.py:9:       e = c.view(-1)
  e = ltorch.view(c, -1)  # e: "cuda:0 f32[4]"
    # e = ltorch.reshape(c, (-1,))  # e: "cuda:0 f32[4]"
      # e = prims.reshape(c, (4,))  # e: "cuda:0 f32[4]"

  # /home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/issue_foobarbaz.py:10:      e.add_(d.flatten())
  t3 = ltorch.flatten(d, 0, -1)  # t3: "cuda:0 f32[4]"
    # t3 = prims.reshape(d, (4,))  # t3: "cuda:0 f32[4]"
  t4 = ltorch.add(e, t3, alpha=None)  # t4: "cuda:0 f32[4]"
    # t4 = ltorch.add(e, t3, alpha=None)  # t4: "cuda:0 f32[4]"
      # t4 = prims.add(e, t3)  # t4: "cuda:0 f32[4]"
  t6 = ltorch.div(d, a, rounding_mode=None, out=None)  # t6: "cuda:0 f32[2, 2]"
    # t6 = ltorch.div(d, a, rounding_mode=None, out=None)  # t6: "cuda:0 f32[2, 2]"
      # t6 = ltorch.true_divide(d, a)  # t6: "cuda:0 f32[2, 2]"
        # t6 = prims.div(d, a)  # t6: "cuda:0 f32[2, 2]"

  # /home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/issue_foobarbaz.py:13:      return c, d, e / 2.0
  result = ltorch.true_divide(e, 2.0)  # result: "cuda:0 f32[4]"
    # result = prims.div(e, 2.0)  # result: "cuda:0 f32[4]"
  t9 = prims.reshape(t4, (2, 2))  # t9: "cuda:0 f32[2, 2]"

  # /home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/issue_foobarbaz.py:13:      return c, d, e / 2.0
  return (t9, t6, result)

In the functionalized trace, result = ltorch.true_divide(e, 2.0) should instead be result = ltorch.true_divide(t4, 2.0): e was updated in place by add_, and t4 is its functionalized replacement, so the later division must read t4.
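
A minimal, purely illustrative sketch of the missing step (this is not Thunder's actual API; proxies are modeled as plain strings): once e.add_(t3) is rewritten to t4 = ltorch.add(e, t3), every later use of e in the trace has to be remapped to t4:

# Illustrative only: record a swap when an in-place op is functionalized,
# then rewrite the arguments of every subsequent operation.
swap_map = {"e": "t4"}  # recorded when t4 = add(e, t3) replaces e.add_(t3)

trace_tail = [("true_divide", ("e", 2.0))]  # what the buggy trace emits for e / 2.0

fixed_tail = [(op, tuple(swap_map.get(arg, arg) for arg in args))
              for op, args in trace_tail]

assert fixed_tail == [("true_divide", ("t4", 2.0))]  # the expected rewrite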