pytorch / torchdynamo

A Python-level JIT compiler designed to make unmodified PyTorch programs faster.
BSD 3-Clause "New" or "Revised" License

Fix support for **kwargs in torch._dynamo.export #1997

Closed: thiagocrepaldi closed this issue 1 year ago

thiagocrepaldi commented 1 year ago

šŸ› Describe the bug

Although torch._dynamo.export(f, *args, aten_graph=False, decomposition_table=None, tracing_mode='real', **kwargs) accepts **kwargs in its public API, the following snippet fails:

import logging

import torch
import torch._dynamo as dynamo

dynamo.config.verbose = True
# torch._dynamo.config.suppress_errors = True
# dynamo.config.dynamic_shapes = False
dynamo.config.log_level = logging.DEBUG
aten_graph = False

def fn_with_kwargs(**kwargs):
    input0 = kwargs["input0"]
    input1 = kwargs["input1"]
    output0 = input0 * input1
    return output0

kwargs = {"input0": torch.randn(4),
          "input1": torch.randn(4)}

# All inputs are passed as keyword arguments, so export receives no *args.
graph, _ = torch._dynamo.export(fn_with_kwargs,
                                backend="eager",
                                aten_graph=aten_graph,
                                **kwargs)
result_true = graph(**kwargs)
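Until kwargs are handled, one possible workaround (my assumption, not something proposed in this thread) is to adapt a kwargs-only function into a positional one with a fixed key order, so that export only ever sees *args. A minimal pure-Python sketch; `positionalize` is a hypothetical helper name and plain floats stand in for tensors:

```python
from typing import Any, Callable, Sequence


def positionalize(fn: Callable[..., Any], keys: Sequence[str]) -> Callable[..., Any]:
    """Hypothetical helper: wrap a kwargs-only callable so it takes positional args.

    The i-th positional argument is forwarded to fn as keyword keys[i].
    """
    key_order = tuple(keys)

    def wrapper(*args: Any) -> Any:
        if len(args) != len(key_order):
            raise TypeError(f"expected {len(key_order)} positional args, got {len(args)}")
        return fn(**dict(zip(key_order, args)))

    return wrapper


def fn_with_kwargs(**kwargs):
    # Same body as the snippet above, with floats instead of tensors.
    return kwargs["input0"] * kwargs["input1"]


kwargs = {"input0": 3.0, "input1": 4.0}
positional_fn = positionalize(fn_with_kwargs, list(kwargs))
print(positional_fn(*kwargs.values()))  # 12.0
```

In principle, `positional_fn` and `*kwargs.values()` could then be handed to torch._dynamo.export in place of `fn_with_kwargs` and `**kwargs`; the trade-off is that callers must keep the key order consistent.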

Looking at export's implementation, there is a TODO about the limitations of kwargs flattening:

    # TODO(voz): Handle kwargs properly?
    flat_args, in_spec = pytree.tree_flatten(args)

This is used to flatten the inputs passed to produce_matching, which is where the traceback below starts.
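The TODO flattens only `args`, so when every input is passed as a keyword (as with `fn_with_kwargs`), nothing reaches the flat input list. One way to picture a fix is to flatten `args` and `kwargs` together and keep enough structure to rebuild both. The pure-Python sketch below uses hypothetical `flatten_inputs`/`unflatten_inputs` helpers that handle a single level of nesting; it is a much simpler stand-in for `pytree.tree_flatten`, and not necessarily what the eventual fix does:

```python
from typing import Any, Dict, List, Tuple

# Spec needed to rebuild the inputs: number of positional args
# and the keyword names, in insertion order.
InputSpec = Tuple[int, Tuple[str, ...]]


def flatten_inputs(args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Tuple[List[Any], InputSpec]:
    """Flatten positional and keyword inputs into one flat list plus a spec."""
    flat = list(args) + list(kwargs.values())
    return flat, (len(args), tuple(kwargs.keys()))


def unflatten_inputs(flat: List[Any], spec: InputSpec) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    """Inverse of flatten_inputs."""
    num_args, keys = spec
    if len(flat) != num_args + len(keys):
        raise ValueError("flat input length does not match spec")
    return tuple(flat[:num_args]), dict(zip(keys, flat[num_args:]))


args: Tuple[Any, ...] = ()
kwargs = {"input0": [1.0, 2.0], "input1": [3.0, 4.0]}

# Flattening args alone (as in the TODO) loses every keyword input.
print(list(args))  # []

flat, spec = flatten_inputs(args, kwargs)
print(len(flat), spec)  # 2 (0, ('input0', 'input1'))

new_args, new_kwargs = unflatten_inputs(flat, spec)
print(new_args == args and new_kwargs == kwargs)  # True
```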

PS: I would love to contribute this fix myself, but opted to create an issue first to collect feedback and ideas.

Error logs

[2022-12-15 21:37:00,697] torch._dynamo.eval_frame: [DEBUG] skipping __init__ /opt/conda/lib/python3.9/contextlib.py
[2022-12-15 21:37:00,697] torch._dynamo.eval_frame: [DEBUG] skipping __enter__ /opt/conda/lib/python3.9/contextlib.py
[2022-12-15 21:37:00,697] torch._dynamo.eval_frame: [DEBUG] skipping __init__ /opt/conda/lib/python3.9/contextlib.py
[2022-12-15 21:37:00,697] torch._dynamo.eval_frame: [DEBUG] skipping __enter__ /opt/conda/lib/python3.9/contextlib.py
[2022-12-15 21:37:00,697] torch._dynamo.eval_frame: [DEBUG] skipping enable_dynamic /opt/conda/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py
[2022-12-15 21:37:00,723] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing fn_with_kwargs
[2022-12-15 21:37:00,724] torch._dynamo.symbolic_convert: [DEBUG] TRACE starts_line /workspace/dynamo_kwargs.py:38
[2022-12-15 21:37:00,724] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_FAST kwargs []
[2022-12-15 21:37:00,724] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_CONST input0 [ConstDictVariable()]
[2022-12-15 21:37:00,725] torch._dynamo.symbolic_convert: [DEBUG] TRACE BINARY_SUBSCR None [ConstDictVariable(), ConstantVariable(str)]
[2022-12-15 21:37:00,726] torch._dynamo.symbolic_convert: [DEBUG] TRACE STORE_FAST input0 [TensorVariable()]
[2022-12-15 21:37:00,726] torch._dynamo.symbolic_convert: [DEBUG] TRACE starts_line /workspace/dynamo_kwargs.py:39
[2022-12-15 21:37:00,726] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_GLOBAL torch []
[2022-12-15 21:37:00,726] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_ATTR randn_like [TorchVariable(<module 'torch' from '/opt/conda/lib/python3.9/site-packages/torch/__init__.py'>)]
[2022-12-15 21:37:00,727] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_FAST input0 [TorchVariable(<built-in method randn_like of type object at 0x7f0b8d786540>)]
[2022-12-15 21:37:00,727] torch._dynamo.symbolic_convert: [DEBUG] TRACE CALL_FUNCTION 1 [TorchVariable(<built-in method randn_like of type object at 0x7f0b8d786540>), TensorVariable()]
[2022-12-15 21:37:00,729] torch._dynamo.symbolic_convert: [DEBUG] TRACE STORE_FAST input1 [TensorVariable()]
[2022-12-15 21:37:00,730] torch._dynamo.symbolic_convert: [DEBUG] TRACE starts_line /workspace/dynamo_kwargs.py:40
[2022-12-15 21:37:00,730] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_FAST input0 []
[2022-12-15 21:37:00,730] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_FAST input1 [TensorVariable()]
[2022-12-15 21:37:00,730] torch._dynamo.symbolic_convert: [DEBUG] TRACE BINARY_MULTIPLY None [TensorVariable(), TensorVariable()]
[2022-12-15 21:37:00,732] torch._dynamo.symbolic_convert: [DEBUG] TRACE STORE_FAST output0 [TensorVariable()]
[2022-12-15 21:37:00,732] torch._dynamo.symbolic_convert: [DEBUG] TRACE starts_line /workspace/dynamo_kwargs.py:41
[2022-12-15 21:37:00,732] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_FAST output0 []
[2022-12-15 21:37:00,732] torch._dynamo.symbolic_convert: [DEBUG] TRACE RETURN_VALUE None [TensorVariable()]
[2022-12-15 21:37:00,732] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo done tracing fn_with_kwargs
[2022-12-15 21:37:00,734] torch._dynamo.output_graph: [INFO] Step 2: calling compiler function dynamo_normalization_capturing_compiler
Traceback (most recent call last):
  File "/opt/conda/lib/python3.9/site-packages/torch/fx/graph_module.py", line 269, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
  File "/opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1482, in _call_impl
    return forward_call(*args, **kwargs)
  File "<eval_with_key>.2", line 5, in forward
    randn_like = torch.randn_like(kwargs_input0_)
TypeError: randn_like(): argument 'input' (position 1) must be Tensor, not list

Call using an FX-traced Module, line 5 of the traced Module's generated forward function:
def forward(self, kwargs_input0_ : torch.Tensor):
    randn_like = torch.randn_like(kwargs_input0_)

~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    mul = kwargs_input0_ * randn_like;  kwargs_input0_ = randn_like = None

    return (mul,)

[2022-12-15 21:37:00,783] torch._dynamo.debug_utils: [WARNING] Compiled Fx GraphModule failed. Creating script to minify the error.
[2022-12-15 21:37:00,785] torch._dynamo.debug_utils: [WARNING] Writing minified repro to /workspace/torchdynamo_debug/run_2022_12_15_21_37_00_785300/minifier/minifier_launcher.py
Traceback (most recent call last):
  File "/opt/conda/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 624, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.fake_example_inputs())
  File "/opt/conda/lib/python3.9/site-packages/torch/_dynamo/debug_utils.py", line 894, in debug_wrapper
    run_fwd_maybe_bwd(compiled_gm, example_inputs)
  File "/opt/conda/lib/python3.9/site-packages/torch/_dynamo/debug_utils.py", line 526, in run_fwd_maybe_bwd
    out = gm(args)
  File "/opt/conda/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 600, in result_capturing_wrapper
    graph_captured_result = graph(*graph_inputs)
  File "/opt/conda/lib/python3.9/site-packages/torch/fx/graph_module.py", line 660, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
  File "/opt/conda/lib/python3.9/site-packages/torch/fx/graph_module.py", line 277, in __call__
    raise e.with_traceback(None)
TypeError: randn_like(): argument 'input' (position 1) must be Tensor, not list

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/workspace/dynamo_kwargs.py", line 54, in <module>
    test_kwargs()
  File "/workspace/dynamo_kwargs.py", line 47, in test_kwargs
    graph, _ = torch._dynamo.export(fn_with_kwargs,
  File "/opt/conda/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 616, in export
    result_traced = opt_f(*args, **kwargs)
  File "/opt/conda/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 211, in _fn
    return fn(*args, **kwargs)
  File "/opt/conda/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 332, in catch_errors
    return callback(frame, cache_size, hooks)
  File "/opt/conda/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 103, in _fn
    return fn(*args, **kwargs)
  File "/opt/conda/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 90, in time_wrapper
    r = func(*args, **kwargs)
  File "/opt/conda/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 339, in _convert_frame_assert
    return _compile(
  File "/opt/conda/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 398, in _compile
    out_code = transform_code_object(code, transform)
  File "/opt/conda/lib/python3.9/site-packages/torch/_dynamo/bytecode_transformation.py", line 341, in transform_code_object
    transformations(instructions, code_options)
  File "/opt/conda/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 385, in transform
    tracer.run()
  File "/opt/conda/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 1676, in run
    super().run()
  File "/opt/conda/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 528, in run
    and self.step()
  File "/opt/conda/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 496, in step
    getattr(self, inst.opname)(inst)
  File "/opt/conda/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 1738, in RETURN_VALUE
    self.output.compile_subgraph(self)
  File "/opt/conda/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 477, in compile_subgraph
    self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
  File "/opt/conda/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 548, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/opt/conda/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 629, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e) from e
torch._dynamo.exc.BackendCompilerFailed: dynamo_normalization_capturing_compiler raised TypeError: randn_like(): argument 'input' (position 1) must be Tensor, not list

Minifier script written to /workspace/torchdynamo_debug/run_2022_12_15_21_37_00_785300/minifier/minifier_launcher.py. Run this script to find the smallest traced graph which reproduces this error.

You can suppress this exception and fall back to eager by setting:
    torch._dynamo.config.suppress_errors = True
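The first traceback above shows run_fwd_maybe_bwd calling `gm(args)` with the whole input list, while result_capturing_wrapper forwards `*graph_inputs` to the graph, so the generated forward receives the list itself as `kwargs_input0_`. The pure-Python sketch below mirrors that call chain with hypothetical stand-ins (`generated_forward`, `capturing_wrapper`) and no torch dependency; it only illustrates the calling-convention mismatch visible in the traceback and is not a claimed root-cause analysis:

```python
from typing import Any, List


def generated_forward(kwargs_input0_: Any) -> Any:
    # Stand-in for the FX-generated forward: it expects one tensor-like
    # argument and, like torch.randn_like, rejects a list.
    if isinstance(kwargs_input0_, list):
        raise TypeError("argument 'input' (position 1) must be Tensor, not list")
    return kwargs_input0_ * 2


def capturing_wrapper(*graph_inputs: Any) -> Any:
    # Mirrors result_capturing_wrapper: unpacks its positional args into the graph.
    return generated_forward(*graph_inputs)


example_inputs: List[Any] = [3.0]

# Unpacked call: the graph receives the element itself.
print(capturing_wrapper(*example_inputs))  # 6.0

# Boxed call, as in `out = gm(args)`: the graph receives the whole list.
try:
    capturing_wrapper(example_inputs)
except TypeError as err:
    print(f"TypeError: {err}")
```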

Minified repro

import os
from math import inf
import torch
from torch import tensor, device
import torch.fx as fx
import functools
import torch._dynamo
from torch._dynamo.debug_utils import run_fwd_maybe_bwd
from torch._dynamo.optimizations.backends import BACKENDS
from torch._dynamo.testing import rand_strided

# REPLACEABLE COMMENT FOR TESTING PURPOSES

args = [((4,), (1,), torch.float32, 'cpu', False)]
args = [rand_strided(sh, st, dt, dev).requires_grad_(rg) for (sh, st, dt, dev, rg) in args]

from torch.nn import *
class Repro(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, kwargs_input0_ : torch.Tensor):
        randn_like = torch.randn_like(kwargs_input0_)
        mul = kwargs_input0_ * randn_like;  kwargs_input0_ = randn_like = None
        return (mul,)

mod = Repro()

# Setup debug minifier compiler
torch._dynamo.debug_utils.MINIFIER_SPAWNED = True
compiler_fn = BACKENDS["dynamo_minifier_backend"]
raise RuntimeError(
    'Compiler name is None - this likely means that a custom compiler '
    'was called by torchdynamo. Please remove this error, import your '
    'custom compiler function, and replace the compiler_name="None" '
    'line below to compiler_name=<my_imported_custom_function>'
)

dynamo_minifier_backend = functools.partial(
    compiler_fn,
    compiler_name="None",
)
opt_mod = torch._dynamo.optimize(dynamo_minifier_backend)(mod)

with torch.cuda.amp.autocast(enabled=False):
    opt_mod(*args)
desertfire commented 1 year ago

Please feel free to give it a try. cc @voznesenskym for any suggestion you may have.

thiagocrepaldi commented 1 year ago

PR with the fix: https://github.com/pytorch/pytorch/pull/92013