pytorch / TensorRT

PyTorch/TorchScript/FX compiler for NVIDIA GPUs using TensorRT
https://pytorch.org/TensorRT
BSD 3-Clause "New" or "Revised" License

šŸ› [Bug] RuntimeError: Unhandled FakeTensor Device Propagation for aten.where.self, found two different devices cuda:0, cpu #3280

Closed: qingchu123 closed this issue 2 weeks ago

qingchu123 commented 3 weeks ago

Bug Description

When I use torch.compile with TensorRT as the backend for a DFN model, like this:

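import torch
import torch_tensorrt  # importing torch_tensorrt registers the "tensorrt" backend for torch.compile
from transformers import AutoModel, AutoProcessor

model = AutoModel.from_pretrained("~/models/hf_pfnmodels", device_map="cuda:0")
processor = AutoProcessor.from_pretrained("~/models/hf_pfnmodels")
optimized_model = torch.compile(
    model,
    backend="tensorrt",
)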

Then I call the processor like this:

from PIL import Image

device = "cuda:0"  # same device the model was loaded on
inputs = processor(
    text=["aaa"],
    images=Image.open(
        "/home/fuyx/master-degree/backend/muticlip/test/img/a.jpg"
    ),
    padding="max_length",
    return_tensors="pt",
).to(device)

But when I run the model with _ = optimized_model(**inputs), it raises this error:

INFO:torch_tensorrt.dynamo.utils:Using Default Torch-TRT Runtime (as requested by user)
INFO:torch_tensorrt.dynamo.utils:Device not specified, using Torch default current device - cuda:0. If this is incorrect, please specify an input device, via the device keyword.
INFO:torch_tensorrt.dynamo.utils:Compilation Settings: CompilationSettings(enabled_precisions={<dtype.f32: 7>}, debug=False, workspace_size=0, min_block_size=5, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, assume_dynamic_shape_support=False, sparse_weights=False, refit=False, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, timing_cache_path='/tmp/timing_cache.bin')

WARNING:torch_tensorrt.dynamo.backend.backends:TRT conversion failed on the subgraph. See trace above. Returning GraphModule forward instead.
Traceback (most recent call last):
  File "/home/fuyx/.local/lib/python3.8/site-packages/torch_tensorrt/dynamo/backend/backends.py", line 90, in _pretraced_backend
    gm = aot_export_joint_simple(
  File "/home/fuyx/.local/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 1235, in aot_export_joint_simple
    fx_g, metadata, in_spec, out_spec = _aot_export_function(
  File "/home/fuyx/.local/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 1350, in _aot_export_function
    fx_g, meta = create_aot_dispatcher_function(
  File "/home/fuyx/.local/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/fuyx/.local/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 687, in create_aot_dispatcher_function
    compiled_fn, fw_metadata = compiler_fn(
  File "/home/fuyx/.local/lib/python3.8/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 95, in aot_dispatch_export
    graph, _, _ = aot_dispatch_base_graph(
  File "/home/fuyx/.local/lib/python3.8/site-packages/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py", line 138, in aot_dispatch_base_graph
    fw_module = _create_graph(
  File "/home/fuyx/.local/lib/python3.8/site-packages/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py", line 46, in _create_graph
    fx_g = make_fx(
  File "/home/fuyx/.local/lib/python3.8/site-packages/torch/fx/experimental/proxy_tensor.py", line 1421, in wrapped
    return make_fx_tracer.trace(f, *args)
  File "/home/fuyx/.local/lib/python3.8/site-packages/torch/fx/experimental/proxy_tensor.py", line 1367, in trace
    return self._trace_inner(f, *args)
  File "/home/fuyx/.local/lib/python3.8/site-packages/torch/fx/experimental/proxy_tensor.py", line 1354, in _trace_inner
    t = dispatch_trace(
  File "/home/fuyx/.local/lib/python3.8/site-packages/torch/_compile.py", line 31, in inner
    return disable_fn(*args, **kwargs)
  File "/home/fuyx/.local/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 600, in _fn
    return fn(*args, **kwargs)
  File "/home/fuyx/.local/lib/python3.8/site-packages/torch/fx/experimental/proxy_tensor.py", line 642, in dispatch_trace
    graph = tracer.trace(root, concrete_args)
  File "/home/fuyx/.local/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 600, in _fn
    return fn(*args, **kwargs)
  File "/home/fuyx/.local/lib/python3.8/site-packages/torch/fx/_symbolic_trace.py", line 822, in trace
    (self.create_arg(fn(*args)),),
  File "/home/fuyx/.local/lib/python3.8/site-packages/torch/fx/experimental/proxy_tensor.py", line 660, in wrapped
    out = f(*tensors)
  File "<string>", line 1, in <lambda>
  File "/home/fuyx/.local/lib/python3.8/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 388, in _functionalized_f_helper
    f_outs = fn(*f_args)
  File "/home/fuyx/.local/lib/python3.8/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 72, in inner_fn
    outs = fn(*args)
  File "/home/fuyx/.local/lib/python3.8/site-packages/torch/_functorch/_aot_autograd/utils.py", line 178, in flat_fn
    tree_out = fn(*args, **kwargs)
  File "/home/fuyx/.local/lib/python3.8/site-packages/torch/fx/graph_module.py", line 738, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
  File "/home/fuyx/.local/lib/python3.8/site-packages/torch/fx/graph_module.py", line 316, in __call__
    raise e
  File "/home/fuyx/.local/lib/python3.8/site-packages/torch/fx/graph_module.py", line 303, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
  File "/home/fuyx/.local/lib/python3.8/site-packages/torch/fx/_symbolic_trace.py", line 800, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
  File "/home/fuyx/.local/lib/python3.8/site-packages/torch/fx/experimental/proxy_tensor.py", line 569, in call_module
    return forward(*args, **kwargs)
  File "/home/fuyx/.local/lib/python3.8/site-packages/torch/fx/_symbolic_trace.py", line 793, in forward
    return _orig_module_call(mod, *args, **kwargs)
  File "/home/fuyx/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/fuyx/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "<eval_with_key>.1", line 739, in forward
    masked_fill_ = mask.masked_fill_(lt, 0);  lt = None
  File "/home/fuyx/.local/lib/python3.8/site-packages/torch/fx/experimental/proxy_tensor.py", line 705, in __torch_function__
    return func(*args, **kwargs)
  File "/home/fuyx/.local/lib/python3.8/site-packages/torch/utils/_device.py", line 79, in __torch_function__
    return func(*args, **kwargs)
  File "/home/fuyx/.local/lib/python3.8/site-packages/torch/_subclasses/functional_tensor.py", line 468, in __torch_dispatch__
    outs_unwrapped = func._op_dk(
  File "/home/fuyx/.local/lib/python3.8/site-packages/torch/utils/_stats.py", line 21, in wrapper
    return fn(*args, **kwargs)
  File "/home/fuyx/.local/lib/python3.8/site-packages/torch/fx/experimental/proxy_tensor.py", line 755, in __torch_dispatch__
    return self.inner_torch_dispatch(func, types, args, kwargs)
  File "/home/fuyx/.local/lib/python3.8/site-packages/torch/fx/experimental/proxy_tensor.py", line 790, in inner_torch_dispatch
    return proxy_call(self, func, self.pre_dispatch, args, kwargs)
  File "/home/fuyx/.local/lib/python3.8/site-packages/torch/fx/experimental/proxy_tensor.py", line 329, in proxy_call
    r = maybe_handle_decomp(proxy_mode, func, args, kwargs)
  File "/home/fuyx/.local/lib/python3.8/site-packages/torch/fx/experimental/proxy_tensor.py", line 1441, in maybe_handle_decomp
    return CURRENT_DECOMPOSITION_TABLE[op](*args, **kwargs)
  File "/home/fuyx/.local/lib/python3.8/site-packages/torch/_prims_common/wrappers.py", line 266, in _fn
    result = fn(*args, **kwargs)
  File "/home/fuyx/.local/lib/python3.8/site-packages/torch/_refs/__init__.py", line 5600, in masked_fill
    r = torch.where(mask, value, a)  # type: ignore[arg-type]
  File "/home/fuyx/.local/lib/python3.8/site-packages/torch/utils/_stats.py", line 21, in wrapper
    return fn(*args, **kwargs)
  File "/home/fuyx/.local/lib/python3.8/site-packages/torch/fx/experimental/proxy_tensor.py", line 755, in __torch_dispatch__
    return self.inner_torch_dispatch(func, types, args, kwargs)
  File "/home/fuyx/.local/lib/python3.8/site-packages/torch/fx/experimental/proxy_tensor.py", line 790, in inner_torch_dispatch
    return proxy_call(self, func, self.pre_dispatch, args, kwargs)
  File "/home/fuyx/.local/lib/python3.8/site-packages/torch/fx/experimental/proxy_tensor.py", line 340, in proxy_call
    r = func.decompose(*args, **kwargs)
  File "/home/fuyx/.local/lib/python3.8/site-packages/torch/_ops.py", line 704, in decompose
    return self._op_dk(dk, *args, **kwargs)
  File "/home/fuyx/.local/lib/python3.8/site-packages/torch/utils/_stats.py", line 21, in wrapper
    return fn(*args, **kwargs)
  File "/home/fuyx/.local/lib/python3.8/site-packages/torch/fx/experimental/proxy_tensor.py", line 755, in __torch_dispatch__
    return self.inner_torch_dispatch(func, types, args, kwargs)
  File "/home/fuyx/.local/lib/python3.8/site-packages/torch/fx/experimental/proxy_tensor.py", line 790, in inner_torch_dispatch
    return proxy_call(self, func, self.pre_dispatch, args, kwargs)
  File "/home/fuyx/.local/lib/python3.8/site-packages/torch/fx/experimental/proxy_tensor.py", line 467, in proxy_call
    out = func(*args, **kwargs)
  File "/home/fuyx/.local/lib/python3.8/site-packages/torch/_ops.py", line 667, in __call__
    return self_._op(*args, **kwargs)
  File "/home/fuyx/.local/lib/python3.8/site-packages/torch/utils/_stats.py", line 21, in wrapper
    return fn(*args, **kwargs)
  File "/home/fuyx/.local/lib/python3.8/site-packages/torch/_subclasses/fake_tensor.py", line 1061, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/home/fuyx/.local/lib/python3.8/site-packages/torch/_subclasses/fake_tensor.py", line 1450, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
  File "/home/fuyx/.local/lib/python3.8/site-packages/torch/_subclasses/fake_tensor.py", line 1145, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
  File "/home/fuyx/.local/lib/python3.8/site-packages/torch/_subclasses/fake_tensor.py", line 1765, in _dispatch_impl
    self.wrap_meta_outputs_with_default_device_logic(
  File "/home/fuyx/.local/lib/python3.8/site-packages/torch/_subclasses/fake_tensor.py", line 1875, in wrap_meta_outputs_with_default_device_logic
    return tree_map(wrap, r)
  File "/home/fuyx/.local/lib/python3.8/site-packages/torch/utils/_pytree.py", line 948, in tree_map
    return treespec.unflatten(map(func, *flat_args))
  File "/home/fuyx/.local/lib/python3.8/site-packages/torch/utils/_pytree.py", line 787, in unflatten
    leaves = list(leaves)
  File "/home/fuyx/.local/lib/python3.8/site-packages/torch/_subclasses/fake_tensor.py", line 1853, in wrap
    ) = FakeTensor._find_common_device(func, flat_args)
  File "/home/fuyx/.local/lib/python3.8/site-packages/torch/_subclasses/fake_tensor.py", line 775, in _find_common_device
    merge_devices(arg)
  File "/home/fuyx/.local/lib/python3.8/site-packages/torch/_subclasses/fake_tensor.py", line 770, in merge_devices
    raise RuntimeError(
RuntimeError: Unhandled FakeTensor Device Propagation for aten.where.self, found two different devices cuda:0, cpu
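The last frames of the traceback show the path to the failure: the traced graph calls mask.masked_fill_(lt, 0), the masked_fill decomposition in torch/_refs/__init__.py rewrites it as torch.where(mask, value, a), and FakeTensor._find_common_device rejects the resulting aten.where.self call because two of its tensor arguments are on different devices. The sketch below is a simplified pure-Python model of that check, based on a reading of merge_devices in torch/_subclasses/fake_tensor.py rather than a copy of it; FakeArg, find_common_device and masked_fill_device are hypothetical names. It assumes the rule that a zero-dimensional CPU tensor may be combined with tensors on another device, while two non-scalar tensors on different devices may not.

```python
from dataclasses import dataclass


@dataclass
class FakeArg:
    """Hypothetical stand-in for a FakeTensor: only device and rank matter here."""
    device: str
    ndim: int


def find_common_device(op_name, args):
    """Simplified model of FakeTensor's device merging for one op call.

    A zero-dim CPU tensor defers to any other device; two non-scalar
    tensors on different devices raise, mirroring the message in the log.
    """
    common = None
    common_is_cpu_scalar = False
    for arg in args:
        if not isinstance(arg, FakeArg):
            continue  # Python scalars take no part in device merging
        is_cpu_scalar = arg.device == "cpu" and arg.ndim == 0
        if common is None:
            common, common_is_cpu_scalar = arg.device, is_cpu_scalar
        elif arg.device == common:
            # Still "CPU scalar only" if every tensor so far was one.
            common_is_cpu_scalar = common_is_cpu_scalar and is_cpu_scalar
        elif is_cpu_scalar:
            continue  # a CPU scalar keeps the existing device
        elif common_is_cpu_scalar:
            # The device so far came only from CPU scalars: adopt the new one.
            common, common_is_cpu_scalar = arg.device, False
        else:
            raise RuntimeError(
                f"Unhandled FakeTensor Device Propagation for {op_name}, "
                f"found two different devices {common}, {arg.device}"
            )
    return common


def masked_fill_device(a, mask, value):
    """Device check for masked_fill, modeled as where(mask, value, a)."""
    return find_common_device("aten.where.self", [mask, value, a])


# Condition on the GPU, Python scalar fill value, full-size tensor on the CPU.
try:
    masked_fill_device(FakeArg("cpu", 2), FakeArg("cuda:0", 2), 0)
except RuntimeError as err:
    print(err)
```

Under this model, the scalar 0 passed to masked_fill_ is not the culprit; the mismatch has to be between the boolean condition (lt) and the tensor being filled (mask).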

So how should I solve this problem?
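A first diagnostic step toward answering that question is to list which device each named tensor lives on, for both the processor outputs and the model's parameters and buffers, and look for anything on cpu. The sketch below assumes only that each value has a .device attribute, as torch.Tensor does; group_by_device is a hypothetical helper, not part of transformers or Torch-TensorRT, and the SimpleNamespace stand-ins (including their names) are invented so the snippet runs without torch.

```python
from collections import defaultdict
from types import SimpleNamespace


def group_by_device(named_values):
    """Group (name, value) pairs by str(value.device).

    Values without a .device attribute (ints, strings, None) are skipped.
    """
    groups = defaultdict(list)
    for name, value in named_values:
        device = getattr(value, "device", None)
        if device is not None:
            groups[str(device)].append(name)
    return dict(groups)


# With real objects (requires torch and transformers), for example:
#   group_by_device(inputs.items())
#   group_by_device(model.named_parameters())
#   group_by_device(model.named_buffers())

# Hypothetical stand-ins so the sketch runs without torch:
fake_inputs = {
    "input_ids": SimpleNamespace(device="cuda:0"),
    "pixel_values": SimpleNamespace(device="cuda:0"),
    "extra_mask": SimpleNamespace(device="cpu"),
}
print(group_by_device(fake_inputs.items()))
# → {'cuda:0': ['input_ids', 'pixel_values'], 'cpu': ['extra_mask']}
```

Tensors created inside forward() (for example, a mask built with a factory function that is not given a device) do not appear in either listing, so an empty cpu group does not rule them out.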

To Reproduce

Steps to reproduce the behavior:

1. Compile the model as shown above.
2. Process the input.
3. Invoke the model.

Expected behavior

The forward pass finishes successfully.

Environment

Build information about Torch-TensorRT can be found by turning on debug messages
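The INFO lines at the top of the log come from standard Python loggers under the torch_tensorrt namespace (for example torch_tensorrt.dynamo.utils), so one way to turn on debug messages is to lower that logger's level with the logging module before compiling. A minimal, standard-library-only sketch; enable_trt_debug_logging and package_versions are hypothetical helpers, it assumes Torch-TensorRT does not override the levels of its child loggers, and the distribution names passed to importlib.metadata are the usual ones but may need adjusting for a given install:

```python
import logging
from importlib import metadata


def enable_trt_debug_logging():
    """Let DEBUG records from all torch_tensorrt.* loggers reach stderr."""
    # No-op if the root logger already has handlers; otherwise uses the
    # same "LEVEL:logger:message" format seen in the log above.
    logging.basicConfig(format="%(levelname)s:%(name)s:%(message)s")
    logger = logging.getLogger("torch_tensorrt")
    logger.setLevel(logging.DEBUG)  # child loggers inherit this level
    return logger


def package_versions(names=("torch", "torch_tensorrt", "tensorrt", "transformers")):
    """Return {distribution name: installed version, or None if not found}."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    enable_trt_debug_logging()
    for name, version in package_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

The logged CompilationSettings also contain a debug=False field, so Torch-TensorRT appears to have its own debug switch as well; the Torch-TensorRT documentation describes how compilation settings are passed through torch.compile.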

Additional context