pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org
Other
83.94k stars 22.62k forks source link

[FakeTensor] Error: SymIntArrayRef expected to contain only concrete integers in UpSample2D block #134303

Open KossBoii opened 2 months ago

KossBoii commented 2 months ago

🐛 Describe the bug

I was trying to export the DiffSTE model from this repo to ONNX using dynamo_export. The model trains and runs inference normally. However, when I export it, the export fails with RuntimeError: aten/src/ATen/RegisterCompositeExplicitAutograd.cpp:2305: SymIntArrayRef expected to contain only concrete integers inside UpBlock2D, which in turn calls Upsample2D. I have also attached the Dynamo export report below.

        # diffusers/models/unets/unet_2d_blocks.py", line 2685
        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)        <--------------- this part

        return hidden_states
       # diffusers/models/upsampling.py", line 171, in forward
       hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") <--------------- this part

Export Code


def onnx_export(
    model,
    model_args: tuple,
    output_path: Path,
    ordered_input_names,
    output_names,
    dynamic_axes,
    opset,
    use_external_data_format=False,
):
    """Export ``model`` to ONNX via ``torch.onnx.dynamo_export`` and save it.

    Args:
        model: Module to export; it is put into eval mode and exported under
            ``torch.no_grad()``.
        model_args: Example inputs. A tuple is unpacked into positional
            arguments; any other value is passed as the single positional
            argument.
        output_path: Destination file; its parent directory is created if it
            does not already exist.
        ordered_input_names, output_names, dynamic_axes, opset,
        use_external_data_format: Accepted for call-site compatibility but NOT
            used: ``dynamo_export`` is called only with the example inputs and
            ``ExportOptions(dynamic_shapes=True)``.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with torch.no_grad():
        model.eval()
        options = torch.onnx.ExportOptions(dynamic_shapes=True)
        # Normalise to a tuple so a single exporter call covers both the
        # "tuple of inputs" and the "single input" cases.
        if isinstance(model_args, tuple):
            example_inputs = model_args
        else:
            example_inputs = (model_args,)
        exported = torch.onnx.dynamo_export(
            model,
            *example_inputs,
            export_options=options,
        )
        exported.save(output_path.as_posix())

# UNET
# Export the UNet sub-module to ONNX, then re-save it so that all external
# tensor data lives in a single "weights.pb" file next to model.onnx.
# NOTE(review): `model`, `output_path`, `img_size`, `device`, `dtype`,
# `num_tokens`, `config` and `opset` are defined outside this snippet — confirm
# their values in the full script.
unet_in_channels = model.unet.config.in_channels
unet_path = output_path / "unet" / "model.onnx"
# NOTE(review): `onnx_export` as defined above does not forward
# ordered_input_names, output_names, dynamic_axes, opset or
# use_external_data_format to torch.onnx.dynamo_export, so those arguments
# currently have no effect on the exported graph.
onnx_export(
    model.unet,
    model_args=(
        # sample: batch of 2 latents, spatial size img_size // 8.
        torch.randn(2, unet_in_channels, img_size // 8, img_size // 8).to(device=device, dtype=dtype),
        # timestep: 0-d float32 tensor; unlike the other inputs it is not
        # moved to `device` / cast to `dtype`.
        torch.tensor(100, dtype=torch.float32),
        # Two conditioning streams, sized by the "text" and "char" entries of
        # cross_attention_dim.
        torch.randn(2, num_tokens, config.model.unet.cross_attention_dim["text"]).to(device=device, dtype=dtype),
        torch.randn(2, num_tokens, config.model.unet.cross_attention_dim["char"]).to(device=device, dtype=dtype),  
    ),
    output_path=unet_path,
    ordered_input_names=["sample", "timestep", "text_hidden_states", "char_hidden_states"],
    output_names=["out_sample"],  # has to be different from "sample" for correct tracing
    dynamic_axes={
        "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
        "timestep": {},
        "text_hidden_states": {0: "batch", 1: "sequence"},
        "char_hidden_states": {0: "batch", 1: "sequence"},
    },
    opset=opset,
    use_external_data_format=True,  # UNet is > 2GB, so the weights need to be split
)
unet_model_path = str(unet_path.absolute().as_posix())
unet_dir = os.path.dirname(unet_model_path)
# Load the exported model into memory. This relies on onnx.load pulling
# external tensor data into memory (its default behaviour), because the
# directory holding those files is deleted on the next line.
unet = onnx.load(unet_model_path)
# clean up existing tensor files
shutil.rmtree(unet_dir)
os.mkdir(unet_dir)
# collate external tensor files into one
onnx.save_model(
    unet,
    unet_model_path,
    save_as_external_data=True,
    all_tensors_to_one_file=True,
    location="weights.pb",
    convert_attribute=False,
)
# Drop the reference to the UNet sub-module (presumably to free memory before
# exporting other components — confirm against the full script).
del model.unet

Some stack traces

E0822 23:10:34.280457 139634357630784 torch/_subclasses/fake_tensor.py:1761] failed while attempting to run meta for aten._unsafe_index.Tensor
E0822 23:10:34.280457 139634357630784 torch/_subclasses/fake_tensor.py:1761] Traceback (most recent call last):
E0822 23:10:34.280457 139634357630784 torch/_subclasses/fake_tensor.py:1761]   File "/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/_subclasses/fake_tensor.py", line 1757, in _dispatch_impl
E0822 23:10:34.280457 139634357630784 torch/_subclasses/fake_tensor.py:1761]     r = func(*args, **kwargs)
E0822 23:10:34.280457 139634357630784 torch/_subclasses/fake_tensor.py:1761]   File "/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/_ops.py", line 667, in __call__
E0822 23:10:34.280457 139634357630784 torch/_subclasses/fake_tensor.py:1761]     return self_._op(*args, **kwargs)
E0822 23:10:34.280457 139634357630784 torch/_subclasses/fake_tensor.py:1761]   File "/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/_meta_registrations.py", line 3019, in meta_index_Tensor
E0822 23:10:34.280457 139634357630784 torch/_subclasses/fake_tensor.py:1761]     indices = list(refs._maybe_broadcast(*indices))
E0822 23:10:34.280457 139634357630784 torch/_subclasses/fake_tensor.py:1761]   File "/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/_refs/__init__.py", line 439, in _maybe_broadcast
E0822 23:10:34.280457 139634357630784 torch/_subclasses/fake_tensor.py:1761]     return tuple(__maybe_broadcast(x, common_shape) for x in args)
E0822 23:10:34.280457 139634357630784 torch/_subclasses/fake_tensor.py:1761]   File "/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/_refs/__init__.py", line 439, in <genexpr>
E0822 23:10:34.280457 139634357630784 torch/_subclasses/fake_tensor.py:1761]     return tuple(__maybe_broadcast(x, common_shape) for x in args)
E0822 23:10:34.280457 139634357630784 torch/_subclasses/fake_tensor.py:1761]   File "/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/_refs/__init__.py", line 431, in __maybe_broadcast
E0822 23:10:34.280457 139634357630784 torch/_subclasses/fake_tensor.py:1761]     return x.expand(common_shape)
E0822 23:10:34.280457 139634357630784 torch/_subclasses/fake_tensor.py:1761] RuntimeError: aten/src/ATen/RegisterCompositeExplicitAutograd.cpp:2305: SymIntArrayRef expected to contain only concrete integers
Traceback (most recent call last):
  File "/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/onnx/_internal/exporter.py", line 1504, in dynamo_export
    return Exporter(
  File "/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/onnx/_internal/exporter.py", line 1236, in export
    graph_module = self.options.fx_tracer.generate_fx(
  File "/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/onnx/_internal/fx/dynamo_graph_extractor.py", line 233, in generate_fx
    return self.pre_export_passes(options, model, graph_module, updated_model_args)  # type: ignore[return-value]
  File "/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/onnx/_internal/fx/dynamo_graph_extractor.py", line 243, in pre_export_passes
    return exporter.common_pre_export_passes(
  File "/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/onnx/_internal/exporter.py", line 1548, in common_pre_export_passes
    module = passes.Functionalize(
  File "/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/onnx/_internal/diagnostics/infra/decorator.py", line 152, in wrapper
    ctx.log_and_raise_if_error(diag)
  File "/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/onnx/_internal/diagnostics/infra/context.py", line 369, in log_and_raise_if_error
    raise diagnostic.source_exception
  File "/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/onnx/_internal/diagnostics/infra/decorator.py", line 136, in wrapper
    return_values = fn(*args, **kwargs)
  File "/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/onnx/_internal/fx/_pass.py", line 278, in run
    module = self._run(*args, **kwargs)
  File "/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/onnx/_internal/fx/passes/functionalization.py", line 124, in _run
    graph_module = proxy_tensor.make_fx(
  File "/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/fx/experimental/proxy_tensor.py", line 1421, in wrapped
    return make_fx_tracer.trace(f, *args)
  File "/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/fx/experimental/proxy_tensor.py", line 1367, in trace
    return self._trace_inner(f, *args)
  File "/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/fx/experimental/proxy_tensor.py", line 1354, in _trace_inner
    t = dispatch_trace(
  File "/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/_compile.py", line 31, in inner
    return disable_fn(*args, **kwargs)
  File "/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 600, in _fn
    return fn(*args, **kwargs)
  File "/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/fx/experimental/proxy_tensor.py", line 642, in dispatch_trace
    graph = tracer.trace(root, concrete_args)
  File "/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 600, in _fn
    return fn(*args, **kwargs)
  File "/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/fx/_symbolic_trace.py", line 822, in trace
    (self.create_arg(fn(*args)),),
  File "/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/fx/experimental/proxy_tensor.py", line 660, in wrapped
    out = f(*tensors)
  File "<string>", line 1, in <lambda>
  File "/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/onnx/_internal/fx/passes/functionalization.py", line 87, in wrapped
    out = function(*inputs_functional)
  File "/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/onnx/_internal/fx/passes/_utils.py", line 31, in wrapped
    return torch.fx.Interpreter(graph_module).run(*args)
  File "/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/fx/interpreter.py", line 146, in run
    self.env[node] = self.run_node(node)
  File "/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/fx/interpreter.py", line 203, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/fx/interpreter.py", line 275, in call_function
    return target(*args, **kwargs)
  File "/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/_ops.py", line 667, in __call__
    return self_._op(*args, **kwargs)
  File "/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/fx/experimental/proxy_tensor.py", line 705, in __torch_function__
    return func(*args, **kwargs)
  File "/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/_ops.py", line 667, in __call__
    return self_._op(*args, **kwargs)
  File "/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/utils/_stats.py", line 21, in wrapper
    return fn(*args, **kwargs)
  File "/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/fx/experimental/proxy_tensor.py", line 755, in __torch_dispatch__
    return self.inner_torch_dispatch(func, types, args, kwargs)
  File "/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/fx/experimental/proxy_tensor.py", line 790, in inner_torch_dispatch
    return proxy_call(self, func, self.pre_dispatch, args, kwargs)
  File "/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/fx/experimental/proxy_tensor.py", line 467, in proxy_call
    out = func(*args, **kwargs)
  File "/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/_ops.py", line 667, in __call__
    return self_._op(*args, **kwargs)
  File "/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/utils/_stats.py", line 21, in wrapper
    return fn(*args, **kwargs)
  File "/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/_subclasses/fake_tensor.py", line 1061, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/_subclasses/fake_tensor.py", line 1450, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
  File "/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/_subclasses/fake_tensor.py", line 1153, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
  File "/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/_subclasses/fake_tensor.py", line 1757, in _dispatch_impl
    r = func(*args, **kwargs)
  File "/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/_ops.py", line 667, in __call__
    return self_._op(*args, **kwargs)
  File "/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/_meta_registrations.py", line 3019, in meta_index_Tensor
    indices = list(refs._maybe_broadcast(*indices))
  File "/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/_refs/__init__.py", line 439, in _maybe_broadcast
    return tuple(__maybe_broadcast(x, common_shape) for x in args)
  File "/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/_refs/__init__.py", line 439, in <genexpr>
    return tuple(__maybe_broadcast(x, common_shape) for x in args)
  File "/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/_refs/__init__.py", line 431, in __maybe_broadcast
    return x.expand(common_shape)
RuntimeError: aten/src/ATen/RegisterCompositeExplicitAutograd.cpp:2305: SymIntArrayRef expected to contain only concrete integers

While executing %_unsafe_index : [num_users=1] = call_function[target=torch.ops.aten._unsafe_index.Tensor](args = (%div_41, [None, None, %unsqueeze_29, %_to_copy_2]), kwargs = {})
Original traceback:
  File "/home/ltruong/synthetic_license_plate/models/diffste/unet_2d_multicondition.py", line 387, in forward
    sample = upsample_block(
  File "/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/diffusers/models/unets/unet_2d_blocks.py", line 2685, in forward
    hidden_states = upsampler(hidden_states, upsample_size)
  File "/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/diffusers/models/upsampling.py", line 171, in forward
    hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")

cc @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4 @penguinwu @jbschlosser @bhosmer @cpuhrsch @erichan1 @drisspg @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @chenyang78 @aakhundov @kadeng @eellison

Versions

[pip3] numpy==1.24.4
[pip3] onnx==1.15.0
[pip3] onnx-graphsurgeon==0.5.2
[pip3] onnxoptimizer==0.3.13
[pip3] onnxruntime==1.19.0
[pip3] onnxscript==0.1.0.dev20240822
[pip3] pytorch-lightning==2.4.0
[pip3] torch==2.4.0+cu124
[pip3] torch-fidelity==0.3.0
[pip3] torchaudio==2.4.0+cu124
[pip3] torchmetrics==1.4.0.post0
[pip3] torchtext==0.18.0
[pip3] torchvision==0.19.0+cu124
[pip3] triton==3.0.0
[pip3] tritonclient==2.48.0
[conda] numpy                     1.24.4                   pypi_0    pypi
[conda] pytorch-lightning         2.4.0                    pypi_0    pypi
[conda] torch                     2.4.0+cu124              pypi_0    pypi
[conda] torch-fidelity            0.3.0                    pypi_0    pypi
[conda] torchaudio                2.4.0+cu124              pypi_0    pypi
[conda] torchmetrics              1.4.0.post0              pypi_0    pypi
[conda] torchtext                 0.18.0                   pypi_0    pypi
[conda] torchvision               0.19.0+cu124             pypi_0    pypi
[conda] triton                    3.0.0                    pypi_0    pypi
[conda] tritonclient              2.48.0                   pypi_0    pypi
KossBoii commented 2 months ago

report_dynamo_export.sarif file content

{
 "runs":[
  {
   "tool":{
    "driver":{
     "name":"torch.onnx.dynamo_export",
     "contents":[
      "localizedData",
      "nonLocalizedData"
     ],
     "language":"en-US",
     "rules":[
      {
       "id":"FXE0010",
       "fullDescription":{
        "text":"FX graph transformation during ONNX export before converting from FX IR to ONNX IR.",
        "markdown":"This diagnostic tracks the FX passes executed during the ONNX export process prior\nto converting from FX IR (Intermediate Representation) to ONNX IR.\n\nUnder the scope of ONNX export, an FX pass refers to a specific transformation applied to the FX GraphModule.\nThe primary aim of these passes is to streamline the graph into a format that aligns more with the ONNX IR.\nMoreover, these passes work to substitute unsupported FX IR features with those recognized and endorsed by\nONNX IR. Common transformations include, but aren't limited to, decomposition, functionalization and\ntype promotion.\n\nFor those who are interested in a comprehensive log detailing the modifications made during these passes,\nthere are a couple of options:\n\n- Set DiagnosticOptions.verbosity_level to logging.DEBUG.\n- Activate the environment variable TORCH_LOGS='onnx_diagnostics'.\n\nHowever, it's noteworthy that by default, such detailed logging is turned off. The primary reason being\nits considerable impact on performance.\n\nFor an in-depth understanding of each specific pass, please refer to the directory: torch/onnx/_internal/fx/passes.\n"
       },
       "name":"fx-pass",
       "shortDescription":{
        "text":"FX graph transformation during ONNX export before converting from FX IR to ONNX IR."
       }
      }
     ],
     "version":"2.4.0+cu124"
    }
   },
   "language":"en-US",
   "newlineSequences":[
    "\r\n",
    "\n"
   ],
   "results":[
    {
     "message":{
      "markdown":"Running Decompose pass. \n\n## Additional Message:\n\n## Function Signature\n### Function Signature Transform.run\n- self: <class 'torch.onnx._internal.fx.passes.decomp.Decompose'>\n- args: Tuple[length=4](\nTensor(f32[2, 9, 28, 28]),\nTensor(f32[]),\nTensor(f32[2, 77, 768]),\nTensor(f32[2, 77, 32]),\n)\nFor detailed logging of graph modifications by this pass, either set `DiagnosticOptions.verbosity_level` to `logging.DEBUG` or use the environment variable `TORCH_LOGS='onnx_diagnostics'`.\n## Return values\ntorch.fx.GraphModule(<lambda>)",
      "text":"Running Decompose pass. "
     },
     "codeFlows":[
      {
       "threadFlows":[
        {
         "locations":[]
        }
       ]
      }
     ],
     "graphs":[],
     "kind":"informational",
     "level":"none",
     "locations":[
      {
       "message":{
        "text":"Transform.run"
       },
       "physicalLocation":{
        "artifactLocation":{
         "uri":"/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/onnx/_internal/fx/_pass.py"
        },
        "region":{
         "snippet":{
          "text":"@diagnostics.diagnose_call("
         },
         "startLine":243
        }
       }
      }
     ],
     "properties":{
      "tags":[]
     },
     "ruleId":"FXE0010",
     "stacks":[]
    },
    {
     "message":{
      "markdown":"Running Functionalize pass. \n\n## Additional Message:\n\n## Function Signature\n### Function Signature Transform.run\n- self: <class 'torch.onnx._internal.fx.passes.functionalization.Functionalize'>\n- args: Tuple[length=4](\nTensor(f32[2, 9, 28, 28]),\nTensor(f32[]),\nTensor(f32[2, 77, 768]),\nTensor(f32[2, 77, 32]),\n)\nFor detailed logging of graph modifications by this pass, either set `DiagnosticOptions.verbosity_level` to `logging.DEBUG` or use the environment variable `TORCH_LOGS='onnx_diagnostics'`.\n## Exception log\n```\nTraceback (most recent call last):\n\n  File \"/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/onnx/_internal/diagnostics/infra/decorator.py\", line 136, in wrapper\n    return_values = fn(*args, **kwargs)\n\n  File \"/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/onnx/_internal/fx/_pass.py\", line 278, in run\n    module = self._run(*args, **kwargs)\n\n  File \"/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/onnx/_internal/fx/passes/functionalization.py\", line 124, in _run\n    graph_module = proxy_tensor.make_fx(\n\n  File \"/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/fx/experimental/proxy_tensor.py\", line 1421, in wrapped\n    return make_fx_tracer.trace(f, *args)\n\n  File \"/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/fx/experimental/proxy_tensor.py\", line 1367, in trace\n    return self._trace_inner(f, *args)\n\n  File \"/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/fx/experimental/proxy_tensor.py\", line 1354, in _trace_inner\n    t = dispatch_trace(\n\n  File \"/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/_compile.py\", line 31, in inner\n    return disable_fn(*args, **kwargs)\n\n  File \"/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py\", line 600, in _fn\n    return fn(*args, **kwargs)\n\n  File 
\"/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/fx/experimental/proxy_tensor.py\", line 642, in dispatch_trace\n    graph = tracer.trace(root, concrete_args)\n\n  File \"/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py\", line 600, in _fn\n    return fn(*args, **kwargs)\n\n  File \"/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/fx/_symbolic_trace.py\", line 822, in trace\n    (self.create_arg(fn(*args)),),\n\n  File \"/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/fx/experimental/proxy_tensor.py\", line 660, in wrapped\n    out = f(*tensors)\n\n  File \"<string>\", line 1, in <lambda>\n\n  File \"/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/onnx/_internal/fx/passes/functionalization.py\", line 87, in wrapped\n    out = function(*inputs_functional)\n\n  File \"/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/onnx/_internal/fx/passes/_utils.py\", line 31, in wrapped\n    return torch.fx.Interpreter(graph_module).run(*args)\n\n  File \"/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/fx/interpreter.py\", line 146, in run\n    self.env[node] = self.run_node(node)\n\n  File \"/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/fx/interpreter.py\", line 203, in run_node\n    return getattr(self, n.op)(n.target, args, kwargs)\n\n  File \"/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/fx/interpreter.py\", line 275, in call_function\n    return target(*args, **kwargs)\n\n  File \"/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/_ops.py\", line 667, in __call__\n    return self_._op(*args, **kwargs)\n\n  File \"/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/fx/experimental/proxy_tensor.py\", line 705, in __torch_function__\n    return func(*args, **kwargs)\n\n  File 
\"/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/_ops.py\", line 667, in __call__\n    return self_._op(*args, **kwargs)\n\n  File \"/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/utils/_stats.py\", line 21, in wrapper\n    return fn(*args, **kwargs)\n\n  File \"/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/fx/experimental/proxy_tensor.py\", line 755, in __torch_dispatch__\n    return self.inner_torch_dispatch(func, types, args, kwargs)\n\n  File \"/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/fx/experimental/proxy_tensor.py\", line 790, in inner_torch_dispatch\n    return proxy_call(self, func, self.pre_dispatch, args, kwargs)\n\n  File \"/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/fx/experimental/proxy_tensor.py\", line 467, in proxy_call\n    out = func(*args, **kwargs)\n\n  File \"/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/_ops.py\", line 667, in __call__\n    return self_._op(*args, **kwargs)\n\n  File \"/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/utils/_stats.py\", line 21, in wrapper\n    return fn(*args, **kwargs)\n\n  File \"/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/_subclasses/fake_tensor.py\", line 1061, in __torch_dispatch__\n    return self.dispatch(func, types, args, kwargs)\n\n  File \"/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/_subclasses/fake_tensor.py\", line 1450, in dispatch\n    return self._cached_dispatch_impl(func, types, args, kwargs)\n\n  File \"/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/_subclasses/fake_tensor.py\", line 1153, in _cached_dispatch_impl\n    output = self._dispatch_impl(func, types, args, kwargs)\n\n  File \"/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/_subclasses/fake_tensor.py\", line 1757, in _dispatch_impl\n    r = func(*args, **kwargs)\n\n  
File \"/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/_ops.py\", line 667, in __call__\n    return self_._op(*args, **kwargs)\n\n  File \"/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/_meta_registrations.py\", line 3019, in meta_index_Tensor\n    indices = list(refs._maybe_broadcast(*indices))\n\n  File \"/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/_refs/__init__.py\", line 439, in _maybe_broadcast\n    return tuple(__maybe_broadcast(x, common_shape) for x in args)\n\n  File \"/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/_refs/__init__.py\", line 439, in <genexpr>\n    return tuple(__maybe_broadcast(x, common_shape) for x in args)\n\n  File \"/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/_refs/__init__.py\", line 431, in __maybe_broadcast\n    return x.expand(common_shape)\n\nRuntimeError: aten/src/ATen/RegisterCompositeExplicitAutograd.cpp:2305: SymIntArrayRef expected to contain only concrete integers\n\nWhile executing %_unsafe_index : [num_users=1] = call_function[target=torch.ops.aten._unsafe_index.Tensor](args = (%div_41, [None, None, %unsqueeze_29, %_to_copy_2]), kwargs = {})\nOriginal traceback:\n  File \"/home/ltruong/synthetic_license_plate/models/diffste/unet_2d_multicondition.py\", line 387, in forward\n    sample = upsample_block(\n  File \"/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/nn/modules/module.py\", line 1562, in _call_impl\n    return forward_call(*args, **kwargs)\n  File \"/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/diffusers/models/unets/unet_2d_blocks.py\", line 2685, in forward\n    hidden_states = upsampler(hidden_states, upsample_size)\n  File \"/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/nn/modules/module.py\", line 1562, in _call_impl\n    return forward_call(*args, **kwargs)\n  File 
\"/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/diffusers/models/upsampling.py\", line 171, in forward\n    hidden_states = F.interpolate(hidden_states, size=output_size, mode=\"nearest\")\n\n\n```",
      "text":"Running Functionalize pass. "
     },
     "codeFlows":[
      {
       "threadFlows":[
        {
         "locations":[]
        }
       ]
      }
     ],
     "graphs":[],
     "kind":"fail",
     "level":"error",
     "locations":[
      {
       "message":{
        "text":"Transform.run"
       },
       "physicalLocation":{
        "artifactLocation":{
         "uri":"/home/ltruong/miniconda3/envs/synth/lib/python3.8/site-packages/torch/onnx/_internal/fx/_pass.py"
        },
        "region":{
         "snippet":{
          "text":"@diagnostics.diagnose_call("
         },
         "startLine":243
        }
       }
      }
     ],
     "properties":{
      "tags":[]
     },
     "ruleId":"FXE0010",
     "stacks":[]
    }
   ]
  }
 ],
 "version":"2.1.0",
 "schemaUri":"https://docs.oasis-open.org/sarif/sarif/v2.1.0/cs01/schemas/sarif-schema-2.1.0.json"
}
KossBoii commented 2 months ago

The same issue appears to be reported in #124884.

tugsbayasgalan commented 2 months ago

What is the model type here? Is it torch.export.ExportedProgram?

KossBoii commented 2 months ago

@tugsbayasgalan Yes, it is. Any update?

justinchuby commented 2 months ago

Please test with torch.onnx.export(…, dynamo=True) and let us know the result.