pytorch / TensorRT

PyTorch/TorchScript/FX compiler for NVIDIA GPUs using TensorRT
https://pytorch.org/TensorRT
BSD 3-Clause "New" or "Revised" License

🐛 [Bug] `aten.scatter.value` converter error #2865

Closed HolyWu closed 5 months ago

HolyWu commented 6 months ago

Bug Description

INFO:torch_tensorrt.dynamo.utils:Using Default Torch-TRT Runtime (as requested by user)
INFO:torch_tensorrt.dynamo.utils:Compilation Settings: CompilationSettings(enabled_precisions={<dtype.f16: 6>}, debug=True, workspace_size=0, min_block_size=1, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, sparse_weights=False, refit=False, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False)

DEBUG:torch_tensorrt.dynamo.backend.backends:Pre-AOT Autograd graph:
graph():
    %l_x_ : torch.Tensor [num_users=1] = placeholder[target=L_x_]
    %l_index_ : torch.Tensor [num_users=1] = placeholder[target=L_index_]
    %scatter_ : [num_users=1] = call_method[target=scatter_](args = (%l_x_, 0, %l_index_, 2), kwargs = {})
    return (scatter_,)
DEBUG:torch_tensorrt.dynamo.lowering._repair_input_aliasing:Inserted auxiliary clone nodes for placeholders:
graph():
    %l_x_ : torch.Tensor [num_users=1] = placeholder[target=L_x_]
    %l_index_ : torch.Tensor [num_users=1] = placeholder[target=L_index_]
    %clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%l_index_,), kwargs = {})
    %clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%l_x_,), kwargs = {})
    %scatter_ : [num_users=1] = call_method[target=scatter_](args = (%clone_default, 0, %clone_default_1, 2), kwargs = {})
    return (scatter_,)
DEBUG:torch_tensorrt.dynamo.lowering._remove_sym_nodes:Removed SymInt placeholders:
graph():
    %l_x_ : torch.Tensor [num_users=1] = placeholder[target=L_x_]
    %l_index_ : torch.Tensor [num_users=1] = placeholder[target=L_index_]
    %clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%l_index_,), kwargs = {})
    %clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%l_x_,), kwargs = {})
    %scatter_ : [num_users=1] = call_method[target=scatter_](args = (%clone_default, 0, %clone_default_1, 2), kwargs = {})
    return (scatter_,)
DEBUG:torch_tensorrt.dynamo.backend.backends:Post-AOT Autograd graph:
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%arg1_1,), kwargs = {})
    %clone_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%arg0_1,), kwargs = {})
    %scatter : [num_users=1] = call_function[target=torch.ops.aten.scatter.value](args = (%clone_1, 0, %clone, 2), kwargs = {})
    return (scatter,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_input_alias_fixing_clones:Removing node clone_1 from graph, since it is a clone node which is the only user of placeholder arg0_1 and was inserted by the compiler.
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_input_alias_fixing_clones:Removing node clone from graph, since it is a clone node which is the only user of placeholder arg1_1 and was inserted by the compiler.
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_input_alias_fixing_clones:Removed auxiliary clone nodes for placeholders:
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %scatter : [num_users=1] = call_function[target=torch.ops.aten.scatter.value](args = (%arg0_1, 0, %arg1_1, 2), kwargs = {})
    return (scatter,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.constant_folding:Graph after constant folding:
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %scatter : [num_users=1] = call_function[target=torch.ops.aten.scatter.value](args = (%arg0_1, 0, %arg1_1, 2), kwargs = {})
    return (scatter,)
DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Supported Nodes:
- torch.ops.aten.scatter.value + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
All Nodes Supported

DEBUG:torch_tensorrt.dynamo._compiler:Detected support for 1 operators out of 1 in subgraph.
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Number of TensorRT-Accelerated Engines Generated: 1
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Supported Nodes:
- torch.ops.aten.scatter.value + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
All Nodes Supported

DEBUG:torch_tensorrt.dynamo._compiler:Submodule name: _run_on_acc_0
 Input shapes: [(3, 5), (1, 2)]
 graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %scatter : [num_users=1] = call_function[target=torch.ops.aten.scatter.value](args = (%arg0_1, 0, %arg1_1, 2), kwargs = {})
    return scatter
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageChange] Init CUDA: CPU +2, GPU +0, now: CPU 141, GPU 1009 (MiB)
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageChange] Init builder kernel library: CPU +1765, GPU +310, now: CPU 2041, GPU 1319 (MiB)
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Graph to be compiled to TensorRT: graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %scatter : [num_users=1] = call_function[target=torch.ops.aten.scatter.value](args = (%arg0_1, 0, %arg1_1, 2), kwargs = {})
    return scatter
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Adding input to in-progress INetwork: arg0_1 [shape=[3, 5], dtype=DataType.HALF]
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Adding input to in-progress INetwork: arg1_1 [shape=[1, 2], dtype=DataType.INT64]
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node scatter (kind: aten.scatter.value, args: ('arg0_1 <tensorrt.ITensor [shape=(3, 5), dtype=DataType.HALF]>', 0, 'arg1_1 <tensorrt.ITensor [shape=(1, 2), dtype=DataType.INT64]>', 2))
[TRT] [E] Could not implicitly convert NumPy data type: f64 to TensorRT.
Traceback (most recent call last):
  File "/home/holy/test.py", line 32, in <module>
    print(optimized_model(*inputs))
  File "/home/holy/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/holy/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/holy/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 421, in _fn
    return fn(*args, **kwargs)
  File "/home/holy/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/holy/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/holy/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1077, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state, skip=1)
  File "/home/holy/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 918, in _convert_frame
    result = inner_convert(
  File "/home/holy/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 456, in _convert_frame_assert
    return _compile(
  File "/home/holy/.local/lib/python3.10/site-packages/torch/_utils_internal.py", line 82, in wrapper_function
    return StrobelightCompileTimeProfiler.profile_compile_time(
  File "/home/holy/.local/lib/python3.10/site-packages/torch/_strobelight/compile_time_profiler.py", line 128, in profile_compile_time
    return func(*args, **kwargs)
  File "/usr/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/holy/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 799, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/holy/.local/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 210, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/holy/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 618, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/holy/.local/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1167, in transform_code_object
    transformations(instructions, code_options)
  File "/home/holy/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 177, in _fn
    return fn(*args, **kwargs)
  File "/home/holy/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 564, in transform
    tracer.run()
  File "/home/holy/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2244, in run
    super().run()
  File "/home/holy/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 886, in run
    while self.step():
  File "/home/holy/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 801, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/holy/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2435, in RETURN_VALUE
    self._return(inst)
  File "/home/holy/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2420, in _return
    self.output.compile_subgraph(
  File "/home/holy/.local/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1095, in compile_subgraph
    self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
  File "/usr/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/holy/.local/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1312, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/home/holy/.local/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 210, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/holy/.local/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1403, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
  File "/home/holy/.local/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1384, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
  File "/home/holy/.local/lib/python3.10/site-packages/torch/_dynamo/repro/after_dynamo.py", line 127, in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/home/holy/.local/lib/python3.10/site-packages/torch/__init__.py", line 1786, in __call__
    return self.compiler_fn(model_, inputs_, **self.kwargs)
  File "/home/holy/.local/lib/python3.10/site-packages/torch_tensorrt/dynamo/backend/backends.py", line 43, in torch_tensorrt_backend
    return DEFAULT_BACKEND(gm, sample_inputs, **kwargs)
  File "/home/holy/.local/lib/python3.10/site-packages/torch_tensorrt/dynamo/backend/backends.py", line 51, in aot_torch_tensorrt_aten_backend
    return _pretraced_backend(gm, sample_inputs, settings)
  File "/home/holy/.local/lib/python3.10/site-packages/torch_tensorrt/dynamo/backend/backends.py", line 102, in _pretraced_backend
    trt_compiled = compile_module(
  File "/home/holy/.local/lib/python3.10/site-packages/torch_tensorrt/dynamo/_compiler.py", line 402, in compile_module
    trt_module = convert_module(
  File "/home/holy/.local/lib/python3.10/site-packages/torch_tensorrt/dynamo/conversion/_conversion.py", line 106, in convert_module
    interpreter_result = interpret_module_to_result(module, inputs, settings)
  File "/home/holy/.local/lib/python3.10/site-packages/torch_tensorrt/dynamo/conversion/_conversion.py", line 87, in interpret_module_to_result
    interpreter_result = interpreter.run()
  File "/home/holy/.local/lib/python3.10/site-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 308, in run
    super().run()
  File "/home/holy/.local/lib/python3.10/site-packages/torch/fx/interpreter.py", line 145, in run
    self.env[node] = self.run_node(node)
  File "/home/holy/.local/lib/python3.10/site-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 347, in run_node
    trt_node: torch.fx.Node = super().run_node(n)
  File "/home/holy/.local/lib/python3.10/site-packages/torch/fx/interpreter.py", line 202, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/home/holy/.local/lib/python3.10/site-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 443, in call_function
    return converter(self.ctx, target, args, kwargs, self._cur_node_name)
  File "/home/holy/.local/lib/python3.10/site-packages/torch_tensorrt/dynamo/conversion/converter_utils.py", line 519, in convert_with_type_enforcement
    return func(ctx, target, new_args, new_kwargs, name)
  File "/home/holy/.local/lib/python3.10/site-packages/torch_tensorrt/dynamo/conversion/aten_ops_converters.py", line 710, in aten_ops_scatter
    return impl.select.scatter(
  File "/home/holy/.local/lib/python3.10/site-packages/torch_tensorrt/dynamo/conversion/impl/select.py", line 414, in scatter
    src_tensor = get_trt_tensor(
  File "/home/holy/.local/lib/python3.10/site-packages/torch_tensorrt/dynamo/conversion/converter_utils.py", line 378, in get_trt_tensor
    return create_constant(ctx, input_val, name, dtype)
  File "/home/holy/.local/lib/python3.10/site-packages/torch_tensorrt/dynamo/conversion/converter_utils.py", line 341, in create_constant
    constant = ctx.net.add_constant(
torch._dynamo.exc.BackendCompilerFailed: backend='torch_tensorrt_backend' raised:
TypeError: add_constant(): incompatible function arguments. The following argument types are supported:
    1. (self: tensorrt_bindings.tensorrt.INetworkDefinition, shape: tensorrt_bindings.tensorrt.Dims, weights: tensorrt_bindings.tensorrt.Weights) -> tensorrt_bindings.tensorrt.IConstantLayer

Invoked with: <tensorrt_bindings.tensorrt.INetworkDefinition object at 0x7f35e0a247b0>, (1, 2), array([[2., 2.]])

While executing %scatter : [num_users=1] = call_function[target=torch.ops.aten.scatter.value](args = (%arg0_1, 0, %arg1_1, 2), kwargs = {_itensor_to_tensor_meta: {<tensorrt_bindings.tensorrt.ITensor object at 0x7f35e0baef30>: ((3, 5), torch.float16, False, (5, 1), torch.contiguous_format, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7f35e0a25870>: ((1, 2), torch.int64, False, (2, 1), torch.contiguous_format, False, {})}})
Original traceback:
None

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

To Reproduce

import torch
import torch.nn as nn
import torch_tensorrt

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, index):
        return x.scatter_(0, index, 2)

device = torch.device("cuda", 0)
model = MyModule().eval().to(device).half()

inputs = [
    torch.zeros((3, 5), dtype=torch.half, device=device),
    torch.tensor([[0, 1]], device=device),
]

optimized_model = torch_tensorrt.compile(
    model,
    ir="torch_compile",
    inputs=inputs,
    enabled_precisions={torch.half},
    debug=True,
    min_block_size=1,
    device=device,
)

print(optimized_model(*inputs))


apbose commented 5 months ago

Hi @HolyWu, can you run the above with truncate_double=True added to the compile call:

optimized_model = torch_tensorrt.compile(
    model,
    ir="torch_compile",
    inputs=inputs,
    enabled_precisions={torch.half},
    debug=True,
    truncate_double=True,
    min_block_size=1,
    device=device,
)

It should pass then. The scalar value 2 is materialized as a float64 constant (the array([[2., 2.]]) in the error above), and TensorRT does not support float64.
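The dtype behaviour described above can be reproduced with NumPy alone. This is a minimal sketch, not the converter's actual code: it assumes the scalar is broadcast to the index shape with a helper such as np.ones, which defaults to float64, and it shows the float64-to-float32 downcast that truncate_double is meant to permit.

```python
import numpy as np

# Shape of the index tensor reported in the log: (1, 2).
index_shape = (1, 2)

# Assumption: the scalar is broadcast with a float64-by-default helper.
# np.ones() returns float64 unless a dtype is given, so the integer 2
# ends up as the float64 array [[2., 2.]] seen in the error message.
src = np.ones(index_shape) * 2
print(src, src.dtype)

# Downcasting to float32 yields a dtype TensorRT can accept as weights.
src32 = src.astype(np.float32)
print(src32.dtype)
```

The values are unchanged by the downcast here, since 2.0 is exactly representable in float32.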

HolyWu commented 5 months ago

Hi @apbose. Adding truncate_double=True does resolve it. I'm just wondering why an integer value (2) would be taken as float64 (2.0) in the first place, which then requires the user to specify truncate_double=True. Couldn't the converter force the argument to be an integer?
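One way to read this suggestion is that the converter could build the constant with an explicit dtype instead of relying on a default. The sketch below uses NumPy only, and broadcast_scalar_like is a hypothetical helper, not a Torch-TensorRT function. It matches the input tensor's dtype rather than forcing an integer, on the assumption that the scattered values should have the same dtype as the tensor they are written into.

```python
import numpy as np

def broadcast_scalar_like(value, shape, input_dtype):
    """Hypothetical helper: broadcast a Python scalar to `shape`,
    using the input tensor's dtype instead of a float64 default."""
    return np.full(shape, value, dtype=input_dtype)

# The failing case from the report: a float16 input and an index of shape (1, 2).
src = broadcast_scalar_like(2, (1, 2), np.float16)
print(src, src.dtype)
```

With the dtype fixed at construction time, no float64 array is ever created, so nothing would need to be truncated afterwards.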

apbose commented 5 months ago

We generally do not make truncate_double=True the default, because we do not want Torch-TensorRT to modify the model on its own unless the user explicitly asks for it.