Note that this test case already exists in the error_inputs_aminmax_amax_amin function. However, I think the test passes (even with PYTORCH_TEST_WITH_INDUCTOR set) only because of a graph break:
$ TORCH_LOGS=+dynamo PYTORCH_TEST_WITH_INDUCTOR=1 python test/test_ops.py -v -k test_errors_aminmax_cpu
...
torch/_dynamo/logging.py:57] [0/0] Step 1: torchdynamo start tracing inner pytorch/torch/_dynamo/external_utils.py:38
torch/fx/experimental/symbolic_shapes.py:2500] [0/0] create_env
torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source] TRACE starts_line pytorch/torch/_dynamo/external_utils.py:40 in inner (wrap_inline.inner)
torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source] return fn(*args, **kwargs)
torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_DEREF fn []
torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_FAST args [LazyVariableTracker()]
torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE BUILD_MAP 0 [LazyVariableTracker(), LazyVariableTracker()]
torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_FAST kwargs [LazyVariableTracker(), LazyVariableTracker(), ConstDictVariable()]
torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE DICT_MERGE 1 [LazyVariableTracker(), LazyVariableTracker(), ConstDictVariable(), LazyVariableTracker()]
torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE CALL_FUNCTION_EX 1 [LazyVariableTracker(), LazyVariableTracker(), ConstDictVariable()]
torch/_dynamo/symbolic_convert.py:617] [0/0] [__graph_breaks] Graph break: from user code at:
torch/_dynamo/symbolic_convert.py:617] [0/0] [__graph_breaks] File "pytorch/torch/_dynamo/external_utils.py", line 40, in inner
torch/_dynamo/symbolic_convert.py:617] [0/0] [__graph_breaks] return fn(*args, **kwargs)
torch/_dynamo/symbolic_convert.py:617] [0/0] [__graph_breaks] Traceback (most recent call last):
torch/_dynamo/symbolic_convert.py:617] [0/0] [__graph_breaks] File "pytorch/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
torch/_dynamo/symbolic_convert.py:617] [0/0] [__graph_breaks] return inner_fn(self, inst)
torch/_dynamo/symbolic_convert.py:617] [0/0] [__graph_breaks] File "pytorch/torch/_dynamo/symbolic_convert.py", line 1680, in CALL_FUNCTION_EX
torch/_dynamo/symbolic_convert.py:617] [0/0] [__graph_breaks] self.call_function(fn, argsvars.items, kwargsvars)
torch/_dynamo/symbolic_convert.py:617] [0/0] [__graph_breaks] File "pytorch/torch/_dynamo/symbolic_convert.py", line 830, in call_function
torch/_dynamo/symbolic_convert.py:617] [0/0] [__graph_breaks] self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
torch/_dynamo/symbolic_convert.py:617] [0/0] [__graph_breaks] File "pytorch/torch/_dynamo/variables/lazy.py", line 156, in realize_and_forward
torch/_dynamo/symbolic_convert.py:617] [0/0] [__graph_breaks] return getattr(self.realize(), name)(*args, **kwargs)
torch/_dynamo/symbolic_convert.py:617] [0/0] [__graph_breaks] File "pytorch/torch/_dynamo/variables/user_defined.py", line 887, in call_function
torch/_dynamo/symbolic_convert.py:617] [0/0] [__graph_breaks] return func_var.call_function(tx, [obj_var] + args, kwargs)
torch/_dynamo/symbolic_convert.py:617] [0/0] [__graph_breaks] File "pytorch/torch/_dynamo/variables/functions.py", line 730, in call_function
torch/_dynamo/symbolic_convert.py:617] [0/0] [__graph_breaks] unimplemented(msg)
torch/_dynamo/symbolic_convert.py:617] [0/0] [__graph_breaks] File "pytorch/torch/_dynamo/exc.py", line 289, in unimplemented
torch/_dynamo/symbolic_convert.py:617] [0/0] [__graph_breaks] raise Unsupported(msg, case_name=case_name)
torch/_dynamo/symbolic_convert.py:617] [0/0] [__graph_breaks] torch._dynamo.exc.Unsupported: 'skip function TestCase.run in file /usr/local/lib/python3.10/unittest/case.py'
torch/_dynamo/convert_frame.py:699] [0/0] Restarting analysis due to _dynamo/symbolic_convert.py:175 in fail_and_restart_analysis
torch/_dynamo/logging.py:57] [0/0_1] Step 1: torchdynamo start tracing inner pytorch/torch/_dynamo/external_utils.py:38
torch/fx/experimental/symbolic_shapes.py:2500] [0/0_1] create_env
torch/_dynamo/symbolic_convert.py:865] [0/0_1] [__trace_source] TRACE starts_line pytorch/torch/_dynamo/external_utils.py:40 in inner (wrap_inline.inner)
torch/_dynamo/symbolic_convert.py:865] [0/0_1] [__trace_source] return fn(*args, **kwargs)
torch/_dynamo/symbolic_convert.py:888] [0/0_1] [__trace_bytecode] TRACE LOAD_DEREF fn []
torch/_dynamo/symbolic_convert.py:888] [0/0_1] [__trace_bytecode] TRACE LOAD_FAST args [LazyVariableTracker()]
torch/_dynamo/symbolic_convert.py:888] [0/0_1] [__trace_bytecode] TRACE BUILD_MAP 0 [LazyVariableTracker(), LazyVariableTracker()]
torch/_dynamo/symbolic_convert.py:888] [0/0_1] [__trace_bytecode] TRACE LOAD_FAST kwargs [LazyVariableTracker(), LazyVariableTracker(), ConstDictVariable()]
torch/_dynamo/symbolic_convert.py:888] [0/0_1] [__trace_bytecode] TRACE DICT_MERGE 1 [LazyVariableTracker(), LazyVariableTracker(), ConstDictVariable(), LazyVariableTracker()]
torch/_dynamo/symbolic_convert.py:888] [0/0_1] [__trace_bytecode] TRACE CALL_FUNCTION_EX 1 [LazyVariableTracker(), LazyVariableTracker(), ConstDictVariable()]
torch/_dynamo/output_graph.py:991] [0/0_1] COMPILING GRAPH due to GraphCompileReason(reason="'skip function TestCase.run in file /usr/local/lib/python3.10/unittest/case.py'", user_stack=[<FrameSummary file pytorch/torch/_dynamo/external_utils.py, line 40 in inner>], graph_break=True)
@bdhirsh
Maybe AOTAutograd functionalization is to blame here. Basically, aminmax.out becomes aminmax.default, and the returned outputs are copied back into the out= arguments. This does not feel like an aminmax-specific problem.
def forward(self, arg0_1: "f64[10][1]cpu", arg1_1: "f64[10][1]cpu", arg2_1: "f32[10, 10][10, 1]cpu"):
    # File: torch/_dynamo/external_utils.py:40 in inner, code: return fn(*args, **kwargs)
    aminmax = torch.ops.aten.aminmax.default(arg2_1, dim = -1); arg2_1 = None
    getitem: "f32[10][1]cpu" = aminmax[0]
    getitem_1: "f32[10][1]cpu" = aminmax[1]; aminmax = None
    _to_copy: "f64[10][1]cpu" = torch.ops.aten._to_copy.default(getitem, dtype = torch.float64, layout = torch.strided); getitem = None
    _to_copy_1: "f64[10][1]cpu" = torch.ops.aten._to_copy.default(getitem_1, dtype = torch.float64, layout = torch.strided); getitem_1 = None
    copy_: "f64[10][1]cpu" = torch.ops.aten.copy_.default(arg0_1, _to_copy); arg0_1 = _to_copy = None
    copy__1: "f64[10][1]cpu" = torch.ops.aten.copy_.default(arg1_1, _to_copy_1); arg1_1 = _to_copy_1 = None
    return (copy_, copy__1)
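For reference, here is a minimal repro of the inconsistency (just a sketch; the shapes and dtypes mirror the graph above, and the eager error is the one the OpInfo error inputs expect):

import torch

def fn(x, out_min, out_max):
    # aminmax out= variant: eager checks that the out tensors' dtype matches x.
    torch.aminmax(x, dim=-1, out=(out_min, out_max))
    return out_min, out_max

x = torch.rand(10, 10)

# Eager: raises because the float64 out= tensors don't match the float32 input.
try:
    fn(x, torch.empty(10, dtype=torch.float64), torch.empty(10, dtype=torch.float64))
except RuntimeError as e:
    print("eager raised:", e)

# aot_eager/inductor: functionalization rewrites the out= write into
# aminmax.default + _to_copy + copy_ (as in the graph above), so no error is raised.
compiled = torch.compile(fn, backend="aot_eager")
compiled(x, torch.empty(10, dtype=torch.float64), torch.empty(10, dtype=torch.float64))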
As expected, this is not an aminmax-specific issue. I ran a test using OpInfo, comparing eager with the aot_eager backend:
# Imports added for context; assumes this lives in the usual OpInfo-based test harness.
import torch
import torch.utils._pytree as pytree
from torch.testing._internal.common_device_type import ops
from torch.testing._internal.common_methods_invocations import op_db

@ops([op for op in op_db if op.supports_out], allowed_dtypes=(torch.float32,))
def test_aot_out(self, device, dtype, op):
    samples = op.sample_inputs(device, dtype)
    for sample in samples:
        def op_out(out):
            return op(sample.input, *sample.args, **sample.kwargs, out=out)

        # Compute what the output looks like.
        expected = op(sample.input, *sample.args, **sample.kwargs)
        if isinstance(expected, tuple):
            # Turn tuple-like structures into actual tuples.
            # Otherwise, dynamo will error.
            expected = tuple(expected)

        # Convert the outputs into double dtype.
        eager_out = pytree.tree_map_only(torch.Tensor, lambda t: torch.empty_like(t, dtype=torch.float64), expected)
        eager_fail = False
        aot_eager_out = pytree.tree_map_only(torch.Tensor, lambda t: torch.empty_like(t, dtype=torch.float64), expected)
        aot_eager_fail = False

        # Run both eager and aot_eager with the out= argument.
        # They should agree on either: error or success.
        try:
            op_out(out=eager_out)
        except Exception:
            eager_fail = True
        try:
            torch.compile(op_out, backend="aot_eager")(aot_eager_out)
        except Exception:
            aot_eager_fail = True
        self.assertEqual(eager_fail, aot_eager_fail)

        if not eager_fail:
            # If they do agree, check whether the results are equal.
            self.assertEqual(eager_out, aot_eager_out)
Specifically, 4 operators fail the last check:
linalg_cond
linalg_lstsq
linalg_matrix_rank
linalg_matrix_rank_hermitian
After a bit of digging, I think I found the problem: FunctionalTensorWrapper::replace_ calls at::_to_copy if the dtype of the given out= tensor does not match the dtype of the output of the functional call.
According to the Dev FAQ, "For operations that do not participate in type promotion the device and dtype of the source and destination tensors must match. For operations that do participate in type promotion the copy can be to a different dtype, but the destination of the copy cannot be a lower "type kind" than the source".
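As a small illustration of why the mismatch then goes unnoticed: copy_ itself accepts cross-dtype copies, so once the out= write has been rewritten into _to_copy + copy_, there is nothing left that would complain:

import torch

src = torch.rand(10)                         # float32, like the aminmax.default outputs
dst = torch.empty(10, dtype=torch.float64)   # the user-provided out= tensor
dst.copy_(src)                               # no error: copy_ silently casts float32 -> float64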
Proposed solution: modify gen_functionalization_type.py so that we generate a dtype check for out= functions, depending on whether the operation participates in type promotion. Concretely, in the emit_inplace_functionalization_body function, generate the dtype check if the function being generated is in the list mentioned above.
@bdhirsh @eellison @zou3519 Let me know what you think.
We have something similar for the functionalization kernels for inplace -> functional (but not out= -> functional): https://github.com/pytorch/pytorch/blob/9b2e453e246ea44f2071dfdac86ec4b0037a51a5/torchgen/gen_functionalization_type.py#L719
There, for all of our inplace ops, the functionalization rule will first run the underlying inplace op with fresh meta tensors, to ensure that any inplace-specific error checks fire first. What do you think of doing something similar for out= ops?
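Roughly, the generated out= kernel would do something like this first (Python pseudocode only, to sketch the idea; the real thing would be generated C++ in the functionalization kernel, and run_out_op_error_checks is just an illustrative name):

import torch
import torch.utils._pytree as pytree

def run_out_op_error_checks(op, args, kwargs, out):
    # Illustrative sketch: re-run the out= op on fresh meta tensors so that the
    # op's own dtype/shape checks fire before functionalization rewrites the
    # out= write into <op>.default + copy_.
    to_meta = lambda t: torch.empty_like(t, device="meta") if isinstance(t, torch.Tensor) else t
    meta_args = pytree.tree_map(to_meta, args)
    meta_kwargs = pytree.tree_map(to_meta, kwargs)
    meta_out = pytree.tree_map(to_meta, out)
    # Should raise in (roughly) the cases where the eager out= op would raise.
    op(*meta_args, **meta_kwargs, out=meta_out)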
(Alternatively, your version works too and would probably have a bit lower overhead, although it comes at the cost of needing to maintain an accurate list of all ops that participate in type promotion, when we have existing meta tensor rules we could take advantage of instead.)
That does sound better. But I thought there was a reason why it wasn't done for out= functions. So, just out of curiosity: do you know why it was done only for inplace functions?
I don't think there was a fundamental reason I never did it for out= (PR here: https://github.com/pytorch/pytorch/pull/77125). From the description: I found a specific instance of an inplace op with different dtype promotion rules than its functional variant, but didn't immediately realize that out= ops would have the same problem.
On second thought, I don't think we need this. Dynamo should already be running the meta functions over the fake tensors, shouldn't it? Which means this should already raise an error at that point, before functionalization.
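i.e., if the meta impl did the same check eager does, something like this should already blow up during fake-tensor propagation (rough sketch; I'd expect the error here, not after functionalization):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    x = torch.rand(10, 10)
    out = (torch.empty(10, dtype=torch.float64), torch.empty(10, dtype=torch.float64))
    # If the meta/decomposition for aminmax checked the out= dtypes the way
    # eager does, this call would raise right here.
    torch.aminmax(x, dim=-1, out=out)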
@bdhirsh What do you think?
@ysiraichi yep, that's a good point. Our meta impl for aten.aminmax.out probably isn't doing the proper error checking: https://github.com/pytorch/pytorch/blob/main/torch/_decomp/decompositions.py#L4810
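Something along these lines in that decomposition would probably do it (illustrative only; check_aminmax_out_dtypes is a hypothetical helper, not the actual code or signature in decompositions.py):

import torch

def check_aminmax_out_dtypes(self, min_out, max_out):
    # Hypothetical helper: the kind of dtype check the aminmax out= meta/decomp
    # could perform so that fake-tensor tracing matches eager's error behavior.
    torch._check(
        min_out.dtype == self.dtype,
        lambda: f"Expected out tensor to have dtype {self.dtype}, but got {min_out.dtype} instead",
    )
    torch._check(
        max_out.dtype == self.dtype,
        lambda: f"Expected out tensor to have dtype {self.dtype}, but got {max_out.dtype} instead",
    )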
Yeah. I have opened an issue (#138399) listing the operations that are not consistently (w.r.t. eager CPU/CUDA) raising an error on dtype mismatch.
🐛 Describe the bug
There is an inconsistency when running the torch.aminmax out= variant on eager vs. inductor. The same behavior is also found when using the aot_eager backend.

Versions

PyTorch version: 2.5.0a0+git7128504
Is debug build: True
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A
cc @ezyang @chauhang @penguinwu @eellison @zou3519 @bdhirsh @yf225