pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

`aot_eager` does not error on mismatched dtypes in out= tensors #137213

Open ysiraichi opened 2 weeks ago

ysiraichi commented 2 weeks ago

🐛 Describe the bug

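There is an inconsistency between eager and inductor when running the out= variant of torch.aminmax with out= tensors whose dtype differs from the result's dtype: eager raises an error, while the compiled function silently succeeds. The same behavior occurs with the aot_eager backend.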

import torch

def run(op):
    # float32 input, but float64 out= tensors.
    inp = torch.rand(10, 10)
    out1 = torch.empty(10, dtype=torch.float64)
    out2 = torch.empty(10, dtype=torch.float64)
    op(inp, dim=-1, out=(out1, out2))

>>> run(torch.aminmax)
Traceback (most recent call last):
  File "examples/inductor.py", line 26, in <module>
    run(torch.aminmax)
  File "examples/inductor.py", line 24, in run
    op(inp, dim=-1, out=(out1, out2))
RuntimeError: Expected out tensor to have dtype float, but got double instead

>>> run(torch.compile(torch.aminmax))
torch.return_types.aminmax_out(min=..., max=...)

Versions

PyTorch version: 2.5.0a0+git7128504
Is debug build: True
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

cc @ezyang @chauhang @penguinwu @eellison @zou3519 @bdhirsh @yf225

ysiraichi commented 2 weeks ago

Note that this test case does exist in the error_inputs_aminmax_amax_amin function. However, I think the test passes (even with PYTORCH_TEST_WITH_INDUCTOR set) due to a graph break:

$ TORCH_LOGS=+dynamo PYTORCH_TEST_WITH_INDUCTOR=1 python test/test_ops.py -v -k test_errors_aminmax_cpu
...
torch/_dynamo/logging.py:57] [0/0] Step 1: torchdynamo start tracing inner pytorch/torch/_dynamo/external_utils.py:38
torch/fx/experimental/symbolic_shapes.py:2500] [0/0] create_env
torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source] TRACE starts_line pytorch/torch/_dynamo/external_utils.py:40 in inner (wrap_inline.inner)
torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source]             return fn(*args, **kwargs)
torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_DEREF fn []
torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_FAST args [LazyVariableTracker()]
torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE BUILD_MAP 0 [LazyVariableTracker(), LazyVariableTracker()]
torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_FAST kwargs [LazyVariableTracker(), LazyVariableTracker(), ConstDictVariable()]
torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE DICT_MERGE 1 [LazyVariableTracker(), LazyVariableTracker(), ConstDictVariable(), LazyVariableTracker()]
torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE CALL_FUNCTION_EX 1 [LazyVariableTracker(), LazyVariableTracker(), ConstDictVariable()]
torch/_dynamo/symbolic_convert.py:617] [0/0] [__graph_breaks] Graph break: from user code at:
torch/_dynamo/symbolic_convert.py:617] [0/0] [__graph_breaks]   File "pytorch/torch/_dynamo/external_utils.py", line 40, in inner
torch/_dynamo/symbolic_convert.py:617] [0/0] [__graph_breaks]     return fn(*args, **kwargs)
torch/_dynamo/symbolic_convert.py:617] [0/0] [__graph_breaks] Traceback (most recent call last):
torch/_dynamo/symbolic_convert.py:617] [0/0] [__graph_breaks]   File "pytorch/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
torch/_dynamo/symbolic_convert.py:617] [0/0] [__graph_breaks]     return inner_fn(self, inst)
torch/_dynamo/symbolic_convert.py:617] [0/0] [__graph_breaks]   File "pytorch/torch/_dynamo/symbolic_convert.py", line 1680, in CALL_FUNCTION_EX
torch/_dynamo/symbolic_convert.py:617] [0/0] [__graph_breaks]     self.call_function(fn, argsvars.items, kwargsvars)
torch/_dynamo/symbolic_convert.py:617] [0/0] [__graph_breaks]   File "pytorch/torch/_dynamo/symbolic_convert.py", line 830, in call_function
torch/_dynamo/symbolic_convert.py:617] [0/0] [__graph_breaks]     self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
torch/_dynamo/symbolic_convert.py:617] [0/0] [__graph_breaks]   File "pytorch/torch/_dynamo/variables/lazy.py", line 156, in realize_and_forward
torch/_dynamo/symbolic_convert.py:617] [0/0] [__graph_breaks]     return getattr(self.realize(), name)(*args, **kwargs)
torch/_dynamo/symbolic_convert.py:617] [0/0] [__graph_breaks]   File "pytorch/torch/_dynamo/variables/user_defined.py", line 887, in call_function
torch/_dynamo/symbolic_convert.py:617] [0/0] [__graph_breaks]     return func_var.call_function(tx, [obj_var] + args, kwargs)
torch/_dynamo/symbolic_convert.py:617] [0/0] [__graph_breaks]   File "pytorch/torch/_dynamo/variables/functions.py", line 730, in call_function
torch/_dynamo/symbolic_convert.py:617] [0/0] [__graph_breaks]     unimplemented(msg)
torch/_dynamo/symbolic_convert.py:617] [0/0] [__graph_breaks]   File "pytorch/torch/_dynamo/exc.py", line 289, in unimplemented
torch/_dynamo/symbolic_convert.py:617] [0/0] [__graph_breaks]     raise Unsupported(msg, case_name=case_name)
torch/_dynamo/symbolic_convert.py:617] [0/0] [__graph_breaks] torch._dynamo.exc.Unsupported: 'skip function TestCase.run in file /usr/local/lib/python3.10/unittest/case.py'
torch/_dynamo/convert_frame.py:699] [0/0] Restarting analysis due to _dynamo/symbolic_convert.py:175 in fail_and_restart_analysis
torch/_dynamo/logging.py:57] [0/0_1] Step 1: torchdynamo start tracing inner pytorch/torch/_dynamo/external_utils.py:38
torch/fx/experimental/symbolic_shapes.py:2500] [0/0_1] create_env
torch/_dynamo/symbolic_convert.py:865] [0/0_1] [__trace_source] TRACE starts_line pytorch/torch/_dynamo/external_utils.py:40 in inner (wrap_inline.inner)
torch/_dynamo/symbolic_convert.py:865] [0/0_1] [__trace_source]             return fn(*args, **kwargs)
torch/_dynamo/symbolic_convert.py:888] [0/0_1] [__trace_bytecode] TRACE LOAD_DEREF fn []
torch/_dynamo/symbolic_convert.py:888] [0/0_1] [__trace_bytecode] TRACE LOAD_FAST args [LazyVariableTracker()]
torch/_dynamo/symbolic_convert.py:888] [0/0_1] [__trace_bytecode] TRACE BUILD_MAP 0 [LazyVariableTracker(), LazyVariableTracker()]
torch/_dynamo/symbolic_convert.py:888] [0/0_1] [__trace_bytecode] TRACE LOAD_FAST kwargs [LazyVariableTracker(), LazyVariableTracker(), ConstDictVariable()]
torch/_dynamo/symbolic_convert.py:888] [0/0_1] [__trace_bytecode] TRACE DICT_MERGE 1 [LazyVariableTracker(), LazyVariableTracker(), ConstDictVariable(), LazyVariableTracker()]
torch/_dynamo/symbolic_convert.py:888] [0/0_1] [__trace_bytecode] TRACE CALL_FUNCTION_EX 1 [LazyVariableTracker(), LazyVariableTracker(), ConstDictVariable()]
torch/_dynamo/output_graph.py:991] [0/0_1] COMPILING GRAPH due to GraphCompileReason(reason="'skip function TestCase.run in file /usr/local/lib/python3.10/unittest/case.py'", user_stack=[<FrameSummary file pytorch/torch/_dynamo/external_utils.py, line 40 in inner>], graph_break=True)
ysiraichi commented 2 weeks ago

@bdhirsh

Maybe AOTAutograd functionalization is to blame here. Basically, aminmax.out becomes aminmax.default, and the returned outputs are cast (via _to_copy) to the out= tensors' dtype and then copied back into the arguments. This feels like a problem that is not specific to aminmax.

def forward(self, arg0_1: "f64[10][1]cpu", arg1_1: "f64[10][1]cpu", arg2_1: "f32[10, 10][10, 1]cpu"):
     # File: torch/_dynamo/external_utils.py:40 in inner, code: return fn(*args, **kwargs)
    aminmax = torch.ops.aten.aminmax.default(arg2_1, dim = -1);  arg2_1 = None
    getitem: "f32[10][1]cpu" = aminmax[0]
    getitem_1: "f32[10][1]cpu" = aminmax[1];  aminmax = None
    _to_copy: "f64[10][1]cpu" = torch.ops.aten._to_copy.default(getitem, dtype = torch.float64, layout = torch.strided);  getitem = None
    _to_copy_1: "f64[10][1]cpu" = torch.ops.aten._to_copy.default(getitem_1, dtype = torch.float64, layout = torch.strided);  getitem_1 = None
    copy_: "f64[10][1]cpu" = torch.ops.aten.copy_.default(arg0_1, _to_copy);  arg0_1 = _to_copy = None
    copy__1: "f64[10][1]cpu" = torch.ops.aten.copy_.default(arg1_1, _to_copy_1);  arg1_1 = _to_copy_1 = None
    return (copy_, copy__1)
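To make the divergence concrete, here is a toy pure-Python model (not PyTorch code; `ToyTensor`, `toy_aminmax`, `eager_out_call`, and `functionalized_out_call` are hypothetical names). The eager path checks that the out= dtype matches the result, as the error in the repro suggests; the functionalized path mirrors the graph above by calling the functional variant, casting to the out= dtype (`_to_copy`), and copying back (`copy_`), with no dtype check anywhere.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ToyTensor:
    # Hypothetical stand-in for a tensor: a dtype name plus flat row-major values.
    dtype: str
    values: List[float]


def toy_aminmax(inp: ToyTensor, rows: int, cols: int):
    # Functional variant: row-wise min/max, returned in the input's dtype.
    mins, maxs = [], []
    for r in range(rows):
        row = inp.values[r * cols:(r + 1) * cols]
        mins.append(min(row))
        maxs.append(max(row))
    return ToyTensor(inp.dtype, mins), ToyTensor(inp.dtype, maxs)


def eager_out_call(inp, rows, cols, outs):
    # Eager out= kernel: every out tensor must have exactly the result dtype.
    results = toy_aminmax(inp, rows, cols)
    for res, out in zip(results, outs):
        if out.dtype != res.dtype:
            raise RuntimeError(
                f"Expected out tensor to have dtype {res.dtype}, "
                f"but got {out.dtype} instead"
            )
        out.values = list(res.values)
    return outs


def functionalized_out_call(inp, rows, cols, outs):
    # Functionalized path: functional call, then cast each result to the out
    # tensor's dtype and copy it back. No dtype check happens anywhere.
    results = toy_aminmax(inp, rows, cols)
    for res, out in zip(results, outs):
        casted = ToyTensor(out.dtype, list(res.values))  # models _to_copy
        out.values = casted.values  # models copy_
    return outs
```

With a float32 input and float64 out tensors, `eager_out_call` raises while `functionalized_out_call` silently fills the outputs, which is the mismatch reported in this issue.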
ysiraichi commented 2 weeks ago

As expected, this issue is not specific to aminmax. I ran a test over the OpInfo op_db, comparing eager with aot_eager:

import torch
import torch.utils._pytree as pytree
from torch.testing._internal.common_device_type import instantiate_device_type_tests, ops
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_utils import TestCase, run_tests


# The original snippet showed only the test method; this wrapper class name is hypothetical.
class TestAOTOut(TestCase):
    @ops([op for op in op_db if op.supports_out], allowed_dtypes=(torch.float32,))
    def test_aot_out(self, device, dtype, op):
        samples = op.sample_inputs(device, dtype)

        for sample in samples:

            def op_out(out):
                return op(sample.input, *sample.args, **sample.kwargs, out=out)

            # Compute what the output looks like.
            expected = op(sample.input, *sample.args, **sample.kwargs)

            if isinstance(expected, tuple):
                # Turn tuple-like structures (e.g. torch.return_types) into actual tuples.
                # Otherwise, dynamo will error.
                expected = tuple(expected)

            # Allocate out= tensors with the same shapes, but in double dtype.
            eager_out = pytree.tree_map_only(torch.Tensor, lambda t: torch.empty_like(t, dtype=torch.float64), expected)
            eager_fail = False

            aot_eager_out = pytree.tree_map_only(torch.Tensor, lambda t: torch.empty_like(t, dtype=torch.float64), expected)
            aot_eager_fail = False

            # Run both eager and aot_eager with the out= argument.
            # They should agree: either both error or both succeed.

            try:
                op_out(out=eager_out)
            except Exception:
                eager_fail = True

            try:
                torch.compile(op_out, backend="aot_eager")(aot_eager_out)
            except Exception:
                aot_eager_fail = True

            self.assertEqual(eager_fail, aot_eager_fail)

            if not eager_fail:
                # If both succeeded, check that the results are equal.
                self.assertEqual(eager_out, aot_eager_out)


instantiate_device_type_tests(TestAOTOut, globals())

if __name__ == "__main__":
    run_tests()
Full list of operations with the issue:

```
abs addbmm addmm addmm_decomposed addmv alias_copy all amax amin aminmax any
as_strided_copy baddbmm bucketize ceil conj_physical cross cummax cummin diag
diagonal_copy dot expand_copy floor frac frexp heaviside histc index_add index_copy
index_select isin isneginf isposinf kthvalue lerp linalg_cond linalg_cross linalg_eigh
linalg_eigvalsh linalg_ldl_factor linalg_ldl_factor_ex linalg_ldl_solve linalg_lstsq
linalg_lu linalg_lu_factor linalg_lu_factor_ex linalg_lu_solve linalg_matrix_power
linalg_matrix_rank linalg_matrix_rank_hermitian linalg_pinv linalg_pinv_hermitian
linalg_qr linalg_slogdet linalg_solve linalg_solve_ex linalg_solve_triangular
log_softmax logcumsumexp lu_solve lu_unpack matmul mean mm mode msort mv nan_to_num
nanmean narrow_copy native_batch_norm neg nn_functional_avg_pool3d nn_functional_gelu
nn_functional_hardshrink nn_functional_linear nn_functional_logsigmoid
nn_functional_softplus nn_functional_softshrink qr renorm round round_decimals_0
scatter_reduce_amax scatter_reduce_amin scatter_reduce_mean scatter_reduce_prod
scatter_reduce_sum searchsorted sgn sign signbit slice_scatter softmax sort square
t_copy take tril triu trunc unfold_copy unsqueeze_copy vdot view_copy where
```

Specifically, 4 operators fail the last check (the comparison of eager_out against aot_eager_out):

linalg_cond
linalg_lstsq
linalg_matrix_rank
linalg_matrix_rank_hermitian
ysiraichi commented 1 week ago

After a bit of digging, I think I found the problem: FunctionalTensorWrapper::replace_ calls at::_to_copy when the dtype of the given out= tensor does not match the dtype of the functional call's output.

According to the Dev FAQ, "For operations that do not participate in type promotion the device and dtype of the source and destination tensors must match. For operations that do participate in type promotion the copy can be to a different dtype, but the destination of the copy cannot be a lower "type kind" than the source".

Proposed solution: modify torchgen/gen_functionalization_type.py so that it generates a dtype check for out= functions, depending on whether the operation participates in type promotion.
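A minimal sketch of the rule quoted from the Dev FAQ, and of the kind of check the proposal would generate for out= functions, might look like this. It is pure Python over dtype names; the `TYPE_KIND` table, `check_out_dtype`, and the `participates_in_type_promotion` flag are hypothetical, and a real version would be C++ emitted by the codegen rather than Python.

```python
# Hypothetical ranking of dtype "kinds": a higher number is a higher kind.
TYPE_KIND = {
    "bool": 0,
    "uint8": 1, "int8": 1, "int16": 1, "int32": 1, "int64": 1,
    "float16": 2, "bfloat16": 2, "float32": 2, "float64": 2,
    "complex64": 3, "complex128": 3,
}


def check_out_dtype(result_dtype: str, out_dtype: str,
                    participates_in_type_promotion: bool) -> None:
    """Raise if writing a result of result_dtype into an out= tensor of
    out_dtype would violate the rule quoted from the Dev FAQ."""
    if not participates_in_type_promotion:
        # Ops without type promotion: the dtypes must match exactly.
        if out_dtype != result_dtype:
            raise RuntimeError(
                f"Expected out tensor to have dtype {result_dtype}, "
                f"but got {out_dtype} instead"
            )
        return
    # Ops with type promotion: the out dtype may differ, but it must not be
    # a lower "type kind" than the result (e.g. float -> int is rejected).
    if TYPE_KIND[out_dtype] < TYPE_KIND[result_dtype]:
        raise RuntimeError(
            f"result type {result_dtype} can't be cast to the desired "
            f"output type {out_dtype}"
        )
```

The aminmax repro corresponds to `check_out_dtype("float32", "float64", False)`, which raises. The `participates_in_type_promotion` flag is the part that would need an accurate per-op source of truth.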

@bdhirsh @eellison @zou3519 Let me know what you think.

bdhirsh commented 1 week ago

We have something similar for the Functionalization kernels for inplace -> functional (but not out= -> functional): https://github.com/pytorch/pytorch/blob/9b2e453e246ea44f2071dfdac86ec4b0037a51a5/torchgen/gen_functionalization_type.py#L719

That is, for all of our inplace ops, the functionalization rule first runs the underlying inplace op on fresh meta tensors, so that any inplace-specific error checks fire first. What do you think of doing something similar for out= ops?
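The meta-tensor approach can be sketched as a toy model: before calling the functional variant, run the out= variant on data-less placeholders so that the out= variant's own error checks fire. Everything here (`MetaLike`, `Dense`, `aminmax_out_meta`, `functionalized_aminmax_out`) is a hypothetical pure-Python stand-in, not the code torchgen actually emits.

```python
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class MetaLike:
    # Hypothetical "meta" placeholder: metadata only, no data.
    # Only the dtype is needed for the toy check below.
    dtype: str


@dataclass
class Dense:
    # Hypothetical real tensor: dtype, 2-D shape and flat row-major values.
    dtype: str
    shape: Tuple[int, ...]
    values: List[float] = field(default_factory=list)

    def to_meta(self) -> MetaLike:
        return MetaLike(self.dtype)


def aminmax_out_meta(inp: MetaLike, outs: Tuple[MetaLike, ...]) -> None:
    # Data-free version of the out= variant's dtype check: aminmax does not
    # promote, so each out tensor must have exactly the input dtype.
    for out in outs:
        if out.dtype != inp.dtype:
            raise RuntimeError(
                f"Expected out tensor to have dtype {inp.dtype}, "
                f"but got {out.dtype} instead"
            )


def aminmax_functional(inp: Dense) -> Tuple[Dense, Dense]:
    # Functional variant: row-wise min/max over the last dim of a 2-D input.
    rows, cols = inp.shape
    mins, maxs = [], []
    for r in range(rows):
        row = inp.values[r * cols:(r + 1) * cols]
        mins.append(min(row))
        maxs.append(max(row))
    return Dense(inp.dtype, (rows,), mins), Dense(inp.dtype, (rows,), maxs)


def functionalized_aminmax_out(inp: Dense, outs: Tuple[Dense, Dense]) -> Tuple[Dense, Dense]:
    # 1) Run the out= variant on fresh meta placeholders so its checks fire.
    aminmax_out_meta(inp.to_meta(), tuple(o.to_meta() for o in outs))
    # 2) Only then run the functional variant and copy results into the outs.
    results = aminmax_functional(inp)
    for res, out in zip(results, outs):
        out.values = list(res.values)
    return outs
```

The appeal of this design is that the checks come from the out= rule itself, so no separate list of type-promoting ops has to be kept in sync.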

(Alternatively, your version works too and would probably have slightly lower overhead, although it comes at the cost of maintaining an accurate list of all ops that participate in type promotion, when we already have meta tensor rules we could take advantage of.)

ysiraichi commented 1 week ago

That does sound better. But I thought there was a reason why it wasn't done for out= functions. So, just out of curiosity: do you know why it is done only for inplace functions?

bdhirsh commented 1 week ago

I don't think there was a fundamental reason I never did it for out= (PR here: https://github.com/pytorch/pytorch/pull/77125). As the description says, I had found a specific instance of an inplace op with different dtype promotion rules than its functional variant, but didn't immediately realize that out= ops would have the same problem.