Closed: thiagocrepaldi closed this issue 1 year ago.
Describe the bug

Although `torch._dynamo`'s `export(f, *args, aten_graph=False, decomposition_table=None, tracing_mode='real', **kwargs)` does allow `**kwargs` on its public API, the following snippet, which passes all of its inputs as keyword arguments, fails inside `torch._dynamo.export` (see the error logs below):
```python
import logging

import torch
import torch._dynamo as dynamo

dynamo.config.verbose = True
# torch._dynamo.config.suppress_errors = True
# dynamo.config.dynamic_shapes = False
dynamo.config.log_level = logging.DEBUG
aten_graph = False


def fn_with_kwargs(**kwargs):
    input0 = kwargs["input0"]
    input1 = kwargs["input1"]
    output0 = input0 * input1
    return output0


# Every input is passed to export as a keyword argument.
kwargs = {"input0": torch.randn(4), "input1": torch.randn(4)}
graph, _ = torch._dynamo.export(fn_with_kwargs, backend="eager", aten_graph=aten_graph, **kwargs)
result_true = graph(**kwargs)
```
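Since `export` accepts positional inputs through `*args`, one possible stopgap (an untested assumption, not something confirmed in this issue) is to wrap the kwargs-only function in one that takes the same inputs positionally. The sketch below uses a hypothetical helper, `positionalize`, and plain numbers in place of tensors so it runs without PyTorch.

```python
from typing import Any, Callable, Sequence


def positionalize(fn: Callable[..., Any], keys: Sequence[str]) -> Callable[..., Any]:
    """Hypothetical helper: expose a **kwargs-only function positionally.

    Positional inputs are mapped to keyword names in the fixed order given by `keys`.
    """
    keys = tuple(keys)

    def wrapper(*args: Any) -> Any:
        if len(args) != len(keys):
            raise TypeError(f"expected {len(keys)} positional inputs, got {len(args)}")
        # Rebuild the keyword dict in the agreed order and forward it unchanged.
        return fn(**dict(zip(keys, args)))

    return wrapper


def fn_with_kwargs(**kwargs):
    # Same arithmetic as the repro, with plain numbers standing in for tensors.
    input0 = kwargs["input0"]
    input1 = kwargs["input1"]
    return input0 * input1


fn_positional = positionalize(fn_with_kwargs, ["input0", "input1"])
print(fn_positional(3, 4))  # 12

# Assumed usage with PyTorch (not verified against this build):
# graph, _ = torch._dynamo.export(fn_positional, kwargs["input0"], kwargs["input1"])
```

An explicit wrapper such as `def fn_positional(input0, input1): return fn_with_kwargs(input0=input0, input1=input1)` does the same job with less machinery and may be easier for Dynamo to trace.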
Looking at `export`'s implementation, there is a TODO about the limitations of `kwargs` flattening:

```python
# TODO(voz): Handle kwargs properly?
flat_args, in_spec = pytree.tree_flatten(args)
```

The resulting `flat_args` is used as the input to `produce_matching`, which started the traceback in the error logs.
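The gap this TODO points at can be illustrated without PyTorch: when every input arrives as a keyword argument, flattening `args` alone yields no inputs at all. The shallow sketch below flattens both `args` and `kwargs` and can rebuild them afterwards. `flatten_inputs` and `unflatten_inputs` are hypothetical helpers and, unlike `torch.utils._pytree`, they do not recurse into nested containers.

```python
from typing import Any, Dict, List, Tuple

# Hypothetical spec: number of positional inputs plus keyword names, in order.
Spec = Tuple[int, Tuple[str, ...]]


def flatten_inputs(args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Tuple[List[Any], Spec]:
    """Hypothetical helper: flatten positional AND keyword inputs into one list."""
    keys = tuple(kwargs)  # keep the caller's insertion order
    flat = list(args) + [kwargs[k] for k in keys]
    return flat, (len(args), keys)


def unflatten_inputs(flat: List[Any], spec: Spec) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    """Inverse of flatten_inputs: rebuild (args, kwargs) from the flat list."""
    num_args, keys = spec
    return tuple(flat[:num_args]), dict(zip(keys, flat[num_args:]))


args, kwargs = (), {"input0": 1.5, "input1": 2.0}  # an all-kwargs call, as in the repro
print(list(args))  # flattening args alone, as the TODO line does: []
flat, spec = flatten_inputs(args, kwargs)
print(flat)  # [1.5, 2.0]
print(unflatten_inputs(flat, spec) == (args, kwargs))  # True
```

Within PyTorch, a comparable approach would be to flatten the pair `(args, kwargs)` with `pytree.tree_flatten` instead of `args` alone; this issue does not say whether the linked fix takes exactly that route.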
PS: I would love to contribute this fix myself, but opted to create an issue first to collect feedback and ideas.
Error logs

```
[2022-12-15 21:37:00,697] torch._dynamo.eval_frame: [DEBUG] skipping __init__ /opt/conda/lib/python3.9/contextlib.py
[2022-12-15 21:37:00,697] torch._dynamo.eval_frame: [DEBUG] skipping __enter__ /opt/conda/lib/python3.9/contextlib.py
[2022-12-15 21:37:00,697] torch._dynamo.eval_frame: [DEBUG] skipping __init__ /opt/conda/lib/python3.9/contextlib.py
[2022-12-15 21:37:00,697] torch._dynamo.eval_frame: [DEBUG] skipping __enter__ /opt/conda/lib/python3.9/contextlib.py
[2022-12-15 21:37:00,697] torch._dynamo.eval_frame: [DEBUG] skipping enable_dynamic /opt/conda/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py
[2022-12-15 21:37:00,723] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing fn_with_kwargs
[2022-12-15 21:37:00,724] torch._dynamo.symbolic_convert: [DEBUG] TRACE starts_line /workspace/dynamo_kwargs.py:38
[2022-12-15 21:37:00,724] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_FAST kwargs []
[2022-12-15 21:37:00,724] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_CONST input0 [ConstDictVariable()]
[2022-12-15 21:37:00,725] torch._dynamo.symbolic_convert: [DEBUG] TRACE BINARY_SUBSCR None [ConstDictVariable(), ConstantVariable(str)]
[2022-12-15 21:37:00,726] torch._dynamo.symbolic_convert: [DEBUG] TRACE STORE_FAST input0 [TensorVariable()]
[2022-12-15 21:37:00,726] torch._dynamo.symbolic_convert: [DEBUG] TRACE starts_line /workspace/dynamo_kwargs.py:39
[2022-12-15 21:37:00,726] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_GLOBAL torch []
[2022-12-15 21:37:00,726] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_ATTR randn_like [TorchVariable(<module 'torch' from '/opt/conda/lib/python3.9/site-packages/torch/__init__.py'>)]
[2022-12-15 21:37:00,727] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_FAST input0 [TorchVariable(<built-in method randn_like of type object at 0x7f0b8d786540>)]
[2022-12-15 21:37:00,727] torch._dynamo.symbolic_convert: [DEBUG] TRACE CALL_FUNCTION 1 [TorchVariable(<built-in method randn_like of type object at 0x7f0b8d786540>), TensorVariable()]
[2022-12-15 21:37:00,729] torch._dynamo.symbolic_convert: [DEBUG] TRACE STORE_FAST input1 [TensorVariable()]
[2022-12-15 21:37:00,730] torch._dynamo.symbolic_convert: [DEBUG] TRACE starts_line /workspace/dynamo_kwargs.py:40
[2022-12-15 21:37:00,730] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_FAST input0 []
[2022-12-15 21:37:00,730] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_FAST input1 [TensorVariable()]
[2022-12-15 21:37:00,730] torch._dynamo.symbolic_convert: [DEBUG] TRACE BINARY_MULTIPLY None [TensorVariable(), TensorVariable()]
[2022-12-15 21:37:00,732] torch._dynamo.symbolic_convert: [DEBUG] TRACE STORE_FAST output0 [TensorVariable()]
[2022-12-15 21:37:00,732] torch._dynamo.symbolic_convert: [DEBUG] TRACE starts_line /workspace/dynamo_kwargs.py:41
[2022-12-15 21:37:00,732] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_FAST output0 []
[2022-12-15 21:37:00,732] torch._dynamo.symbolic_convert: [DEBUG] TRACE RETURN_VALUE None [TensorVariable()]
[2022-12-15 21:37:00,732] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo done tracing fn_with_kwargs
[2022-12-15 21:37:00,734] torch._dynamo.output_graph: [INFO] Step 2: calling compiler function dynamo_normalization_capturing_compiler
Traceback (most recent call last):
  File "/opt/conda/lib/python3.9/site-packages/torch/fx/graph_module.py", line 269, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
  File "/opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1482, in _call_impl
    return forward_call(*args, **kwargs)
  File "<eval_with_key>.2", line 5, in forward
    randn_like = torch.randn_like(kwargs_input0_)
TypeError: randn_like(): argument 'input' (position 1) must be Tensor, not list

Call using an FX-traced Module, line 5 of the traced Module's generated forward function:
def forward(self, kwargs_input0_ : torch.Tensor):
    randn_like = torch.randn_like(kwargs_input0_)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    mul = kwargs_input0_ * randn_like;  kwargs_input0_ = randn_like = None
    return (mul,)

[2022-12-15 21:37:00,783] torch._dynamo.debug_utils: [WARNING] Compiled Fx GraphModule failed. Creating script to minify the error.
[2022-12-15 21:37:00,785] torch._dynamo.debug_utils: [WARNING] Writing minified repro to /workspace/torchdynamo_debug/run_2022_12_15_21_37_00_785300/minifier/minifier_launcher.py
Traceback (most recent call last):
  File "/opt/conda/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 624, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.fake_example_inputs())
  File "/opt/conda/lib/python3.9/site-packages/torch/_dynamo/debug_utils.py", line 894, in debug_wrapper
    run_fwd_maybe_bwd(compiled_gm, example_inputs)
  File "/opt/conda/lib/python3.9/site-packages/torch/_dynamo/debug_utils.py", line 526, in run_fwd_maybe_bwd
    out = gm(args)
  File "/opt/conda/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 600, in result_capturing_wrapper
    graph_captured_result = graph(*graph_inputs)
  File "/opt/conda/lib/python3.9/site-packages/torch/fx/graph_module.py", line 660, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
  File "/opt/conda/lib/python3.9/site-packages/torch/fx/graph_module.py", line 277, in __call__
    raise e.with_traceback(None)
TypeError: randn_like(): argument 'input' (position 1) must be Tensor, not list

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/workspace/dynamo_kwargs.py", line 54, in <module>
    test_kwargs()
  File "/workspace/dynamo_kwargs.py", line 47, in test_kwargs
    graph, _ = torch._dynamo.export(fn_with_kwargs,
  File "/opt/conda/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 616, in export
    result_traced = opt_f(*args, **kwargs)
  File "/opt/conda/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 211, in _fn
    return fn(*args, **kwargs)
  File "/opt/conda/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 332, in catch_errors
    return callback(frame, cache_size, hooks)
  File "/opt/conda/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 103, in _fn
    return fn(*args, **kwargs)
  File "/opt/conda/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 90, in time_wrapper
    r = func(*args, **kwargs)
  File "/opt/conda/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 339, in _convert_frame_assert
    return _compile(
  File "/opt/conda/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 398, in _compile
    out_code = transform_code_object(code, transform)
  File "/opt/conda/lib/python3.9/site-packages/torch/_dynamo/bytecode_transformation.py", line 341, in transform_code_object
    transformations(instructions, code_options)
  File "/opt/conda/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 385, in transform
    tracer.run()
  File "/opt/conda/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 1676, in run
    super().run()
  File "/opt/conda/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 528, in run
    and self.step()
  File "/opt/conda/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 496, in step
    getattr(self, inst.opname)(inst)
  File "/opt/conda/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 1738, in RETURN_VALUE
    self.output.compile_subgraph(self)
  File "/opt/conda/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 477, in compile_subgraph
    self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
  File "/opt/conda/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 548, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/opt/conda/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 629, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e) from e
torch._dynamo.exc.BackendCompilerFailed: dynamo_normalization_capturing_compiler raised TypeError: randn_like(): argument 'input' (position 1) must be Tensor, not list

Minifier script written to /workspace/torchdynamo_debug/run_2022_12_15_21_37_00_785300/minifier/minifier_launcher.py. Run this script to find the smallest traced graph which reproduces this error.

You can suppress this exception and fall back to eager by setting:
    torch._dynamo.config.suppress_errors = True
```
Minified repro

```python
import os
from math import inf
import torch
from torch import tensor, device
import torch.fx as fx
import functools
import torch._dynamo
from torch._dynamo.debug_utils import run_fwd_maybe_bwd
from torch._dynamo.optimizations.backends import BACKENDS
from torch._dynamo.testing import rand_strided

# REPLACEABLE COMMENT FOR TESTING PURPOSES

args = [((4,), (1,), torch.float32, 'cpu', False)]
args = [rand_strided(sh, st, dt, dev).requires_grad_(rg) for (sh, st, dt, dev, rg) in args]


from torch.nn import *
class Repro(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, kwargs_input0_ : torch.Tensor):
        randn_like = torch.randn_like(kwargs_input0_)
        mul = kwargs_input0_ * randn_like;  kwargs_input0_ = randn_like = None
        return (mul,)


mod = Repro()

# Setup debug minifier compiler
torch._dynamo.debug_utils.MINIFIER_SPAWNED = True
compiler_fn = BACKENDS["dynamo_minifier_backend"]
raise RuntimeError(
    'Compiler name is None - this likely means that a custom compiler '
    'was called by torchdynamo. Please remove this error, import your '
    'custom compiler function, and replace the compiler_name="None" '
    'line below to compiler_name=<my_imported_custom_function>'
)

dynamo_minifier_backend = functools.partial(
    compiler_fn,
    compiler_name="None",
)
opt_mod = torch._dynamo.optimize(dynamo_minifier_backend)(mod)

with torch.cuda.amp.autocast(enabled=False):
    opt_mod(*args)
```
Please feel free to give it a try. cc @voznesenskym for any suggestions you may have.
PR with the fix: https://github.com/pytorch/pytorch/pull/92013