Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source to source compiler for PyTorch. It enables using different hardware executors at once; across one or thousands of GPUs.

`Tensor.copy_` tries to copy onto an intermediate tensor in a canonicalized trace #1192

Open shino16 opened 2 months ago

shino16 commented 2 months ago
import torch
import thunder

@thunder.jit
def f(x):
    x.add_(1)
    return x.copy_(x.sin())

f(torch.tensor(0.0, device='cuda'))

The above results in the following error from nvFuser.

Traceback (most recent call last):
  File "/opt/pytorch/nvfuser/nvfuser/__init__.py", line 62, in __exit__
    self._finalize_definition()
RuntimeError: input_expr->isA<UnaryOp>() INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/fusion.cpp":799, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. expected unary op for aliased input
Exception raised from aliasOutputToInput at /opt/pytorch/nvfuser/csrc/fusion.cpp:799 (most recent call first):
  ...
Fusion definition

```python
# CUDA devices:
#  0: NVIDIA RTX 6000 Ada Generation
# torch version: 2.5.0a0+gitda32021
# cuda version: 12.6
# nvfuser version: 0.2.10+gitc3f8037
import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id0(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[], contiguity=[], dtype=DataType.Float, is_cpu=False)
    S1 = fd.define_scalar(1.00000, dtype=DataType.Double)
    T2 = fd.ops.add(T0, S1)
    T3 = fd.ops.sin(T2)
    T4 = fd.ops.set(T2)
    fd.add_output(T4, T0)
    T5 = fd.ops.set(T3)
    fd.add_output(T5, T2)
    fd.add_output(T0)
    fd.add_output(T2)

with FusionDefinition() as fd:
    nvfuser_fusion_id0(fd)
```

I am not sure how to interpret nvFuser's error message, but the problem appears to be that the fusion tries to write the output of fd.ops.sin onto the output of fd.ops.add, which is an intermediate tensor rather than a fusion input. The execution trace is:

def computation(x):
  # x: "cuda:0 f32[]"
  [t1, t3] = nvFusion0(x)
    # t0 = prims.add(x, 1.0)  # t0: "cuda:0 f32[]"
    # t2 = prims.sin(t0)  # t2: "cuda:0 f32[]"
    # t1 = prims.copy_(t0, x)  # t1: "cuda:0 f32[]"
    # t3 = prims.copy_(t2, t0)  # t3: "cuda:0 f32[]"
  del x
  return t3
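To make the reading of the trace above concrete, here is a minimal sketch (not Thunder or nvFuser code) that scans a simplified trace and flags every `copy_` whose destination is not a trace input. The tuple-based trace format and the function name `copies_onto_intermediates` are hypothetical and exist only for illustration.

```python
# Hypothetical, simplified stand-in for the trace: (result, op, args).
# For copy_, args are (src, dst), matching prims.copy_(src, dst) above.
TRACE = [
    ("t0", "add", ("x", 1.0)),
    ("t2", "sin", ("t0",)),
    ("t1", "copy_", ("t0", "x")),   # writes t0 into the input x
    ("t3", "copy_", ("t2", "t0")),  # writes t2 into the intermediate t0
]


def copies_onto_intermediates(trace, inputs):
    """Return (result, src, dst) for each copy_ whose dst is not an input."""
    flagged = []
    for result, op, args in trace:
        if op == "copy_":
            src, dst = args
            if dst not in inputs:
                flagged.append((result, src, dst))
    return flagged


print(copies_onto_intermediates(TRACE, inputs={"x"}))
# [('t3', 't2', 't0')]
```

Only `t3` is flagged, which matches the `fd.add_output(T5, T2)` line in the fusion definition: `T2` is produced inside the fusion, so it is presumably not something nvFuser can alias an output to.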

When we use x.neg() instead of x.sin(), the nvFuser executor somehow orders the copy onto t0 before the copy from t0, and the resulting trace is flagged as unsafe by _inplace_copy_sanity_check.

NotImplementedError: t1 = prims.copy_(t0, x)  # t1: "cuda:0 f32[]" trying to use <TensorProxy(name="t0", dtype=thunder.dtypes.float32, shape=())> (the 'copy_to' argument of 'prims.copy_') as input, which is not safe. There is a risk of accessing the wrong memory. If you are sure you don't want to use this check, it can be disabled by setting `disable_inplace_copy_check=True` in `thunder.jit`.
Python source, Execution trace

```py
@thunder.jit
def f(x):
    x.add_(1)
    return x.copy_(x.neg())
```

```py
def computation(x):
  # x: "cuda:0 f32[]"
  [t3, t1] = nvFusion0(x)
    # t0 = prims.add(x, 1.0)  # t0: "cuda:0 f32[]"
    # t2 = prims.neg(t0)  # t2: "cuda:0 f32[]"
    # t3 = prims.copy_(t2, t0)  # t3: "cuda:0 f32[]"
    # t1 = prims.copy_(t0, x)  # t1: "cuda:0 f32[]"
  del x
  return {'output': t3, 'flat_args': [t1]}
```
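Why the order of the two copies matters can be seen in a small pure-Python model. This is only a sketch under the assumption that each copy really overwrites its destination buffer; one-element lists stand in for tensor storage, and the `run` helper and step labels are hypothetical.

```python
def run(order):
    """Apply the two copies in the given order and return (x, t0)."""
    x = [0.0]
    t0 = [x[0] + 1.0]  # t0 = prims.add(x, 1.0)
    t2 = [-t0[0]]      # t2 = prims.neg(t0)
    for step in order:
        if step == "t0->x":
            x[0] = t0[0]   # t1 = prims.copy_(t0, x)
        elif step == "t2->t0":
            t0[0] = t2[0]  # t3 = prims.copy_(t2, t0)
    return x[0], t0[0]


original = run(["t0->x", "t2->t0"])   # copy from t0 first, as in the sin trace
reordered = run(["t2->t0", "t0->x"])  # copy onto t0 first, as in the neg trace

print(original)   # (1.0, -1.0)
print(reordered)  # (-1.0, -1.0)
```

In this model the value that ends up in `x` depends on which copy runs first, because the second copy reads a buffer the first one may already have overwritten. For comparison, eager PyTorch leaves `x` at -1.0 for this function.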

I presume this would be fixed by functionalizing Tensor.copy_ like the other in-place ops, but doing so properly would require fairly large changes to thunder/core/functionalization.py.
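As a rough illustration of what functionalizing `copy_` could mean here, the sketch below retargets every copy to the trace input that its destination currently stands for, and keeps a single write-back per input. It is not how thunder/core/functionalization.py works; the trace format, `retarget_copies`, and the `x_out` naming are all hypothetical, and renaming later uses of the removed `copy_` results is left out.

```python
# Hypothetical, simplified trace: (result, op, args); copy_ args are (src, dst).
TRACE = [
    ("t0", "add", ("x", 1.0)),
    ("t2", "sin", ("t0",)),
    ("t1", "copy_", ("t0", "x")),
    ("t3", "copy_", ("t2", "t0")),
]


def retarget_copies(trace, inputs):
    """Return a trace where each input receives at most one final copy_."""
    alias_of = {}  # intermediate name -> input whose current value it holds
    last_src = {}  # input name -> latest value that should be written back
    body = []
    for result, op, args in trace:
        if op != "copy_":
            body.append((result, op, args))
            continue
        src, dst = args
        # A copy onto an intermediate is treated as a copy onto the input
        # that the intermediate was previously copied to.
        target = dst if dst in inputs else alias_of[dst]
        alias_of[src] = target
        alias_of[result] = target
        last_src[target] = src
    # Emit the write-backs after all pure computation.
    writes = [(f"{inp}_out", "copy_", (src, inp)) for inp, src in last_src.items()]
    return body + writes


functionalized = retarget_copies(TRACE, inputs={"x"})
for entry in functionalized:
    print(entry)
```

On the example trace this leaves `add` and `sin` untouched and ends with a single `copy_` of `t2` onto `x`, so no copy targets an intermediate. A copy onto an intermediate that was never tied to an input raises `KeyError` in this sketch, which a real pass would have to handle.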

nvMelissa commented 1 month ago

@crcrpar - Question from Ivan: what should be the correct behaviour?

crcrpar commented 1 month ago

Embarrassingly, I don't remember why, or for what purpose, I exposed copy_ to thunder.torch. I even opened https://github.com/Lightning-AI/lightning-thunder/pull/1209 last week.