pytorch / TensorRT

PyTorch/TorchScript/FX compiler for NVIDIA GPUs using TensorRT
https://pytorch.org/TensorRT
BSD 3-Clause "New" or "Revised" License

🐛 [Bug] Cannot convert simple torchscript containing two torch.nn.Upsample operations #1823

Closed gcuendet closed 1 year ago

gcuendet commented 1 year ago

Bug Description

Scripting a simple "network" containing two torch.nn.Upsample modules and trying to convert the resulting torchscript does not work.

To Reproduce

Steps to reproduce the behavior:

  1. Generate a TorchScript module from the Network below with torch.jit.script.
  2. Try to convert it to TensorRT with torch_tensorrt (I tried both in C++ and in Python, on Linux).
```python
import torch


class Network(torch.nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        # NOTE:
        # * Specifying a single float as scale_factor or the 2d tuple doesn't change the behavior
        self.upsample1 = torch.nn.Upsample(
            scale_factor=(2.0, 2.0), mode="bilinear", align_corners=False
        )
        self.upsample2 = torch.nn.Upsample(
            scale_factor=(2.0, 2.0), mode="bilinear", align_corners=False
        )

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        # NOTE:
        # * Using a single Upsample module for both calls doesn't work either
        # * Computing and returning out = out1 + out2 doesn't work either
        out1 = self.upsample1(X)
        out2 = self.upsample2(X)
        # out = out1 + out2
        return out1


# Step 1: generate the TorchScript module that is then passed to torch_tensorrt
scripted = torch.jit.script(Network())
```

Expected behavior

The conversion succeeds and a new, valid TorchScript module is obtained.

Environment

I reproduced the bug both with PyTorch 1.11 and Torch-TensorRT 1.1.0, and with PyTorch 1.13.1 and Torch-TensorRT main.

Torch-TensorRT 1.1.0

When using torch-tensorRT 1.1.0, I get the following error:

```
DEBUG: [Torch-TensorRT] - Registering input/output torch::jit::Value for segmented graphs
terminate called after throwing an instance of 'c10::Error'
  what():  Expected Tensor but got Uninitialized

Exception raised from reportToTensorTypeError at /home/ubuntu/buildAgent/temp/buildTmp/conan_home/.conan/data/libtorch/1.11.0-5/cognex/stable/build/b8fcba865a68ae644e8a8bfa868ef00d13dc8a17/source_subfolder/aten/src/ATen/core/ivalue.cpp:908 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x6b (0x7f66313e708b in /mnt/caches/conan/data/libtorch/1.11.0-5/cognex/stable/package/b8fcba865a68ae644e8a8bfa868ef00d13dc8a17/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xce (0x7f66313e2a5e in /mnt/caches/conan/data/libtorch/1.11.0-5/cognex/stable/package/b8fcba865a68ae644e8a8bfa868ef00d13dc8a17/lib/libc10.so)
frame #2: c10::IValue::reportToTensorTypeError() const + 0x64 (0x7f6634480064 in /mnt/caches/conan/data/libtorch/1.11.0-5/cognex/stable/package/b8fcba865a68ae644e8a8bfa868ef00d13dc8a17/lib/libtorch_cpu.so)
frame #3: torch_tensorrt::core::partitioning::getSegmentsOutputByRunning(torch_tensorrt::core::partitioning::SegmentedBlock&, std::unordered_map<torch::jit::Value const*, c10::IValue, std::hash<torch::jit::Value const*>, std::equal_to<torch::jit::Value const*>, std::allocator<std::pair<torch::jit::Value const* const, c10::IValue> > >&, torch_tensorrt::core::partitioning::PartitionInfo const&) + 0x15a7 (0x7f6639fd8797 in /mnt/caches/conan/data/torch-tensorrt/1.1.0-6/cognex/stable/package/c2bc9df050867ddb44d73bc998d3c8ceef4f763a/lib/libtorchtrt.so)
frame #4: torch_tensorrt::core::partitioning::runShapeAnalysis(std::vector<torch_tensorrt::core::partitioning::SegmentedBlock, std::allocator<torch_tensorrt::core::partitioning::SegmentedBlock> >&, std::unordered_map<torch::jit::Value const*, c10::IValue, std::hash<torch::jit::Value const*>, std::equal_to<torch::jit::Value const*>, std::allocator<std::pair<torch::jit::Value const* const, c10::IValue> > >&, torch_tensorrt::core::partitioning::PartitionInfo const&) + 0x81 (0x7f6639fda581 in /mnt/caches/conan/data/torch-tensorrt/1.1.0-6/cognex/stable/package/c2bc9df050867ddb44d73bc998d3c8ceef4f763a/lib/libtorchtrt.so)
frame #5: torch_tensorrt::core::partitioning::Partition(torch::jit::Block*, std::unordered_map<torch::jit::Value const*, c10::IValue, std::hash<torch::jit::Value const*>, std::equal_to<torch::jit::Value const*>, std::allocator<std::pair<torch::jit::Value const* const, c10::IValue> > >&, torch_tensorrt::core::partitioning::PartitionInfo const&) + 0x19b (0x7f6639fe587b in /mnt/caches/conan/data/torch-tensorrt/1.1.0-6/cognex/stable/package/c2bc9df050867ddb44d73bc998d3c8ceef4f763a/lib/libtorchtrt.so)
frame #6: torch_tensorrt::core::ConstructFallbackGraph(torch::jit::Module&, torch::jit::Block*, std::unordered_map<torch::jit::Value const*, c10::IValue, std::hash<torch::jit::Value const*>, std::equal_to<torch::jit::Value const*>, std::allocator<std::pair<torch::jit::Value const* const, c10::IValue> > >, torch_tensorrt::core::CompileSpec, std::map<torch::jit::Value*, c10::IValue, std::less<torch::jit::Value*>, std::allocator<std::pair<torch::jit::Value* const, c10::IValue> > >) + 0xff (0x7f663a00006f in /mnt/caches/conan/data/torch-tensorrt/1.1.0-6/cognex/stable/package/c2bc9df050867ddb44d73bc998d3c8ceef4f763a/lib/libtorchtrt.so)
frame #7: torch_tensorrt::core::CompileGraph(torch::jit::Module const&, torch_tensorrt::core::CompileSpec) + 0x980 (0x7f663a002a00 in /mnt/caches/conan/data/torch-tensorrt/1.1.0-6/cognex/stable/package/c2bc9df050867ddb44d73bc998d3c8ceef4f763a/lib/libtorchtrt.so)
frame #8: torch_tensorrt::torchscript::compile(torch::jit::Module const&, torch_tensorrt::torchscript::CompileSpec) + 0x5b7 (0x7f6639e9e7f7 in /mnt/caches/conan/data/torch-tensorrt/1.1.0-6/cognex/stable/package/c2bc9df050867ddb44d73bc998d3c8ceef4f763a/lib/libtorchtrt.so)
frame #9: <unknown function> + 0x5ec8 (0x56183fd18ec8 in ./build/src/interpolate_tensorrt)
frame #10: __libc_start_main + 0xf3 (0x7f65fe02f083 in /lib/x86_64-linux-gnu/libc.so.6)
frame #11: <unknown function> + 0x58ee (0x56183fd188ee in ./build/src/interpolate_tensorrt)
```

This looked similar to this issue, and patching Torch-TensorRT with this PR makes the behavior exactly the same as in the second case (i.e., PyTorch 1.13.1 with Torch-TensorRT main).

Torch-TensorRT main (commit 861edd03a510c600146575836b02c993ac386b00)

With Torch-TensorRT main, the conversion just hangs forever after printing:

```
GRAPH: [Torch-TensorRT] - Torch-TensorRT.TorchScript Graph Lowering
```

Additional context

Interestingly, when the TorchScript is generated with PyTorch's tracing mechanism instead, everything seems fine (I didn't check the results, but the conversion completes). Scripting with PyTorch 1.9 also works fine 🤯

One thing I noticed is that PyTorch slightly changed the torch.nn.functional.interpolate API, and I am wondering whether that could explain the problem, at least partially.

See the attached upsample.zip, which contains a Python script that generates the TorchScript.

Let me know if you need more details to reproduce the problem. Thanks!

gs-olive commented 1 year ago

I've reproduced the error on main. It occurs on this line, where the torch::jit::EliminateExceptions pass never completes: https://github.com/pytorch/TensorRT/blob/a245b861d75fe0cb007eca5d23b3a992113b268b/core/lowering/lowering.cpp#L106 (see also the source code for torch::jit::EliminateExceptions).

The graph at that point is shown below; it contains prim::RaiseException nodes inside prim::If blocks.

Graph

```python
graph(%X.1 : Tensor):
  %2 : bool = prim::Constant[value=0]()
  %3 : float[] = prim::Constant[value=[2., 2.]]()
  %4 : str = prim::Constant[value="bilinear"]() # /usr/local/lib/python3.8/dist-packages/torch/nn/modules/upsampling.py:156:66
  %5 : int = prim::Constant[value=5]() # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3956:22
  %6 : int = prim::Constant[value=3]() # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3952:22
  %7 : int = prim::Constant[value=4]() # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3949:76
  %8 : int = prim::Constant[value=2]() # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3880:24
  %9 : NoneType = prim::Constant() # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3871:32
  %10 : str = prim::Constant[value="builtins.ValueError"]() # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3872:18
  %11 : str = prim::Constant[value="Input and scale_factor must have the same number of spatial dimensions, but got input with spatial dimensions of {} and scale_factor of shape {}.
Please provide input tensor in (N, C, d1, d2, ...,dK) format and scale_factor in (s1, s2, ...,sK) format."]() # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3909:20
  %12 : str = prim::Constant[value="Got 3D input, but bilinear mode needs 4D input"]() # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3994:34
  %13 : str = prim::Constant[value="builtins.NotImplementedError"]() # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3994:14
  %14 : str = prim::Constant[value="Got 5D input, but bilinear mode needs 4D input"]() # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:4004:34
  %15 : str = prim::Constant[value="Input Error: Only 3D, 4D and 5D input Tensors supported (got {}D) for the modes: nearest | linear | bilinear | bicubic | trilinear | area | nearest-exact (got {})"]() # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:4007:8
  %16 : int = prim::Constant[value=1]()
  %48 : int = prim::Constant[value=2]()
  %49 : str = prim::Constant[value="builtins.ValueError"]()
  %50 : float[] = prim::Constant[value=[2., 2.]]()
  %51 : str = prim::Constant[value="Input and scale_factor must have the same number of spatial dimensions, but got input with spatial dimensions of {} and scale_factor of shape {}.
Please provide input tensor in (N, C, d1, d2, ...,dK) format and scale_factor in (s1, s2, ...,sK) format."]()
  %52 : int = prim::Constant[value=1]()
  %53 : NoneType = prim::Constant()
  %54 : int = prim::Constant[value=4]()
  %55 : str = prim::Constant[value="Input Error: Only 3D, 4D and 5D input Tensors supported (got {}D) for the modes: nearest | linear | bilinear | bicubic | trilinear | area | nearest-exact (got {})"]()
  %56 : str = prim::Constant[value="bilinear"]()
  %57 : str = prim::Constant[value="Got 5D input, but bilinear mode needs 4D input"]()
  %58 : int = prim::Constant[value=5]()
  %59 : str = prim::Constant[value="builtins.NotImplementedError"]()
  %60 : str = prim::Constant[value="Got 3D input, but bilinear mode needs 4D input"]()
  %61 : int = prim::Constant[value=3]()
  %62 : bool = prim::Constant[value=0]()
  %63 : Tensor = prim::Uninitialized()
  %64 : int = aten::dim(%X.1) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3880:10
  %dim.2 : int = aten::sub(%64, %48) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3880:10
  %66 : bool = aten::ne(%48, %dim.2) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3907:15
   = prim::If(%66) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3907:12
    block0():
      %67 : int[] = aten::size(%X.1) # :13:9
      %68 : int[] = aten::slice(%67, %48, %53, %52) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3909:26
      %69 : int[] = aten::list(%68) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3909:21
      %70 : str = aten::format(%51, %69, %50) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3909:20
       = prim::RaiseException(%70, %49) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3908:16
      -> ()
    block1():
      -> ()
  %71 : bool = aten::eq(%64, %54) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3979:7
  %out1.1 : Tensor = prim::If(%71) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3979:4
    block0():
      %73 : Tensor = aten::upsample_bilinear2d(%X.1, %53, %62, %50) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3983:15
      -> (%73)
    block1():
      %74 : bool = aten::eq(%64, %61) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3993:7
       = prim::If(%74) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3993:4
        block0():
           = prim::RaiseException(%60, %59) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3994:8
          -> ()
        block1():
          -> ()
      %75 : bool = aten::eq(%64, %58) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:4003:7
       = prim::If(%75) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:4003:4
        block0():
           = prim::RaiseException(%57, %59) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:4004:8
          -> ()
        block1():
          -> ()
      %76 : str = aten::format(%55, %64, %56) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:4007:8
       = prim::RaiseException(%76, %59) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:4006:4
      -> (%63)
   = prim::If(%66) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3907:12
    block0():
      %77 : int[] = aten::size(%X.1) # :13:9
      %78 : int[] = aten::slice(%77, %48, %53, %52) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3909:26
      %79 : int[] = aten::list(%78) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3909:21
      %80 : str = aten::format(%51, %79, %50) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3909:20
       = prim::RaiseException(%80, %49) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3908:16
      -> ()
    block1():
      -> ()
  %out2.1 : Tensor = prim::If(%71) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3979:4
    block0():
      %82 : Tensor = aten::upsample_bilinear2d(%X.1, %53, %62, %50) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3983:15
      -> (%82)
    block1():
      %83 : bool = aten::eq(%64, %61) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3993:7
       = prim::If(%83) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3993:4
        block0():
           = prim::RaiseException(%60, %59) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3994:8
          -> ()
        block1():
          -> ()
      %84 : bool = aten::eq(%64, %58) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:4003:7
       = prim::If(%84) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:4003:4
        block0():
           = prim::RaiseException(%57, %59) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:4004:8
          -> ()
        block1():
          -> ()
      %85 : str = aten::format(%55, %64, %56) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:4007:8
       = prim::RaiseException(%85, %59) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:4006:4
      -> (%63)
  %out.1 : Tensor = aten::add(%out1.1, %out2.1, %16) # unrelated.py:24:14
  return (%out.1)
```

@bowang007 - this seems related to your work with exceptions and control flow, do you have any suggestions on this?

gcuendet commented 1 year ago

Thanks @gs-olive for having taken a look at this issue so quickly! Nice that you could reproduce it! Following up on that, here is another simple Network (though slightly less trivial than the one above).

Network definition

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import argparse


class Block(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(Block, self).__init__()
        self.conv = nn.Conv2d(
            in_channel, out_channel, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.norm = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(
            kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False
        )

    def forward(self, x):
        out = self.conv(x)
        out = self.norm(out)
        out = self.relu(out)
        out = self.maxpool(out)
        return out


class Network(torch.nn.Module):
    def __init__(self, num_classes=2):
        super(Network, self).__init__()
        self.num_classes = num_classes
        self.block1 = Block(3, 32)
        self.block2 = Block(32, 64)
        self.upsample1 = nn.Upsample(
            scale_factor=2, mode="bilinear", align_corners=False
        )
        self.upsample2 = nn.Upsample(
            scale_factor=2, mode="bilinear", align_corners=False
        )
        self.conv = nn.Conv2d(64, num_classes, 1, bias=True)

    def forward(self, x):
        out = self.block1(x)
        out = self.block2(out)
        gclayer1 = self.upsample1(out)
        gclayer2 = self.upsample2(gclayer1)
        out = self.conv(gclayer2)
        return out
```

Interestingly, when converting a TorchScript module generated from that network with torch.jit.script on Linux, the behaviour is the same as with the trivial network shared earlier: with a recent Torch-TensorRT commit from main, the conversion hangs (i.e., torch::jit::EliminateExceptions never completes). However, when converting the TorchScript on Windows with the same Torch-TensorRT commit, it works! It works regardless of whether the TorchScript was generated on Windows or on Linux.

Both graphs, as printed by Torch-TensorRT when calling compile, are included in archive.zip, along with mini_net.py (the network definition and TorchScript generation script). In summary, here is a diff between the Linux and Windows graphs, as printed by Torch-TensorRT:

Graph diff

```
--- graph_lin.py	2023-04-19 13:28:59.000000000 +0200
+++ graph_win.py	2023-04-19 13:28:52.000000000 +0200
@@ -1,16 +1,13 @@
 GRAPH: [Torch-TensorRT] - After freeze: graph(%self.1 : __torch__.___torch_mangle_0.Network,
       %x.1 : Tensor):
-  %224 : str[] = prim::Constant[value=["nearest", "area", "nearest-exact"]]()
   %115 : str = prim::Constant[value="bilinear"]() # /home/gcuendet/.pyenv/versions/pytorch1.11/lib/python3.8/site-packages/torch/nn/modules/upsampling.py:153:66
   %114 : float = prim::Constant[value=2.]() # /home/gcuendet/.pyenv/versions/pytorch1.11/lib/python3.8/site-packages/torch/nn/modules/upsampling.py:153:47
   %113 : int = prim::Constant[value=5]() # /home/gcuendet/.pyenv/versions/pytorch1.11/lib/python3.8/site-packages/torch/nn/functional.py:3892:22
   %112 : int = prim::Constant[value=3]() # /home/gcuendet/.pyenv/versions/pytorch1.11/lib/python3.8/site-packages/torch/nn/functional.py:3888:22
   %111 : int = prim::Constant[value=4]() # /home/gcuendet/.pyenv/versions/pytorch1.11/lib/python3.8/site-packages/torch/nn/functional.py:3885:76
-  %109 : str = prim::Constant[value="align_corners option can only be set with the interpolating modes: linear | bilinear | bicubic | trilinear"]() # /home/gcuendet/.pyenv/versions/pytorch1.11/lib/python3.8/site-packages/torch/nn/functional.py:3815:16
   %108 : str = prim::Constant[value="Got 3D input, but bilinear mode needs 4D input"]() # /home/gcuendet/.pyenv/versions/pytorch1.11/lib/python3.8/site-packages/torch/nn/functional.py:3930:34
   %107 : str = prim::Constant[value="Got 5D input, but bilinear mode needs 4D input"]() # /home/gcuendet/.pyenv/versions/pytorch1.11/lib/python3.8/site-packages/torch/nn/functional.py:3940:34
   %106 : str = prim::Constant[value="Input Error: Only 3D, 4D and 5D input Tensors supported (got {}D) for the modes: nearest | linear | bilinear | bicubic | trilinear | area | nearest-exact (got {})"]() # /home/gcuendet/.pyenv/versions/pytorch1.11/lib/python3.8/site-packages/torch/nn/functional.py:3943:8
-  %101 : str = prim::Constant[value="AssertionError: "]() # :0:0
   %223 : int[] = prim::Constant[value=[0, 0]]()
   %222 : int[] = prim::Constant[value=[2, 2]]()
   %221 : int[] = prim::Constant[value=[1, 1]]()
@@ -38,15 +35,6 @@
   %out1.1 : Tensor = aten::relu_(%out0.1) # /home/gcuendet/.pyenv/versions/pytorch1.11/lib/python3.8/site-packages/torch/nn/functional.py:1440:17
   %out0.2 : Tensor = aten::max_pool2d(%out1.1, %222, %222, %223, %221, %self.block1.norm.training) # /home/gcuendet/.pyenv/versions/pytorch1.11/lib/python3.8/site-packages/torch/nn/functional.py:797:11
   %117 : Tensor = prim::Uninitialized() # :0:0
-  %118 : bool? = prim::Uninitialized() # :0:0
-  %119 : bool = prim::Uninitialized() # :0:0
-  %121 : bool = aten::__contains__(%224, %115) # /home/gcuendet/.pyenv/versions/pytorch1.11/lib/python3.8/site-packages/torch/nn/functional.py:3812:7
-  %align_corners0.1 : bool? = prim::If(%121) # /home/gcuendet/.pyenv/versions/pytorch1.11/lib/python3.8/site-packages/torch/nn/functional.py:3812:4
-    block0():
-       = prim::RaiseException(%109, %self.block1.conv.bias) # /home/gcuendet/.pyenv/versions/pytorch1.11/lib/python3.8/site-packages/torch/nn/functional.py:3814:12
-      -> (%118)
-    block1():
-      -> (%self.block1.norm.training)
   %123 : int = aten::dim(%out0.2) # /home/gcuendet/.pyenv/versions/pytorch1.11/lib/python3.8/site-packages/torch/nn/functional.py:3822:10
   %dim.2 : int = aten::sub(%123, %18) # /home/gcuendet/.pyenv/versions/pytorch1.11/lib/python3.8/site-packages/torch/nn/functional.py:3822:10
   %scale_factors2.2 : float[] = prim::ListConstruct()
@@ -57,15 +45,7 @@
   %129 : bool = aten::eq(%123, %111) # /home/gcuendet/.pyenv/versions/pytorch1.11/lib/python3.8/site-packages/torch/nn/functional.py:3915:7
   %gclayer1.1 : Tensor = prim::If(%129) # /home/gcuendet/.pyenv/versions/pytorch1.11/lib/python3.8/site-packages/torch/nn/functional.py:3915:4
     block0():
-      %132 : bool = aten::__isnot__(%align_corners0.1, %self.block1.conv.bias) # /home/gcuendet/.pyenv/versions/pytorch1.11/lib/python3.8/site-packages/torch/nn/functional.py:3916:15
-      %align_corners6.1 : bool = prim::If(%132) # /home/gcuendet/.pyenv/versions/pytorch1.11/lib/python3.8/site-packages/torch/nn/functional.py:3916:8
-        block0():
-          %align_corners7.2 : bool = prim::unchecked_cast(%align_corners0.1)
-          -> (%align_corners7.2)
-        block1():
-           = prim::RaiseException(%101, %self.block1.conv.bias) # /home/gcuendet/.pyenv/versions/pytorch1.11/lib/python3.8/site-packages/torch/nn/functional.py:3916:8
-          -> (%119)
-      %135 : Tensor = aten::upsample_bilinear2d(%out0.2, %self.block1.conv.bias, %align_corners6.1, %scale_factors2.2) # /home/gcuendet/.pyenv/versions/pytorch1.11/lib/python3.8/site-packages/torch/nn/functional.py:3919:15
+      %135 : Tensor = aten::upsample_bilinear2d(%out0.2, %self.block1.conv.bias, %self.block1.norm.training, %scale_factors2.2) # /home/gcuendet/.pyenv/versions/pytorch1.11/lib/python3.8/site-packages/torch/nn/functional.py:3919:15
       -> (%135)
     block1():
       %137 : bool = aten::eq(%123, %112) # /home/gcuendet/.pyenv/versions/pytorch1.11/lib/python3.8/site-packages/torch/nn/functional.py:3929:7
@@ -85,12 +65,6 @@
       %143 : str = aten::format(%106, %123, %115) # /home/gcuendet/.pyenv/versions/pytorch1.11/lib/python3.8/site-packages/torch/nn/functional.py:3943:8
        = prim::RaiseException(%143, %self.block1.conv.bias) # /home/gcuendet/.pyenv/versions/pytorch1.11/lib/python3.8/site-packages/torch/nn/functional.py:3942:4
       -> (%117)
-  %align_corners0 : bool? = prim::If(%121) # /home/gcuendet/.pyenv/versions/pytorch1.11/lib/python3.8/site-packages/torch/nn/functional.py:3812:4
-    block0():
-       = prim::RaiseException(%109, %self.block1.conv.bias) # /home/gcuendet/.pyenv/versions/pytorch1.11/lib/python3.8/site-packages/torch/nn/functional.py:3814:12
-      -> (%118)
-    block1():
-      -> (%self.block1.norm.training)
   %167 : int = aten::dim(%gclayer1.1) # /home/gcuendet/.pyenv/versions/pytorch1.11/lib/python3.8/site-packages/torch/nn/functional.py:3822:10
   %dim.1 : int = aten::sub(%167, %18) # /home/gcuendet/.pyenv/versions/pytorch1.11/lib/python3.8/site-packages/torch/nn/functional.py:3822:10
   %scale_factors2.1 : float[] = prim::ListConstruct()
@@ -101,15 +75,7 @@
   %173 : bool = aten::eq(%167, %111) # /home/gcuendet/.pyenv/versions/pytorch1.11/lib/python3.8/site-packages/torch/nn/functional.py:3915:7
   %gclayer2.1 : Tensor = prim::If(%173) # /home/gcuendet/.pyenv/versions/pytorch1.11/lib/python3.8/site-packages/torch/nn/functional.py:3915:4
     block0():
-      %176 : bool = aten::__isnot__(%align_corners0, %self.block1.conv.bias) # /home/gcuendet/.pyenv/versions/pytorch1.11/lib/python3.8/site-packages/torch/nn/functional.py:3916:15
-      %align_corners6 : bool = prim::If(%176) # /home/gcuendet/.pyenv/versions/pytorch1.11/lib/python3.8/site-packages/torch/nn/functional.py:3916:8
-        block0():
-          %align_corners7.1 : bool = prim::unchecked_cast(%align_corners0)
-          -> (%align_corners7.1)
-        block1():
-           = prim::RaiseException(%101, %self.block1.conv.bias) # /home/gcuendet/.pyenv/versions/pytorch1.11/lib/python3.8/site-packages/torch/nn/functional.py:3916:8
-          -> (%119)
-      %179 : Tensor = aten::upsample_bilinear2d(%gclayer1.1, %self.block1.conv.bias, %align_corners6, %scale_factors2.1) # /home/gcuendet/.pyenv/versions/pytorch1.11/lib/python3.8/site-packages/torch/nn/functional.py:3919:15
+      %179 : Tensor = aten::upsample_bilinear2d(%gclayer1.1, %self.block1.conv.bias, %self.block1.norm.training, %scale_factors2.1) # /home/gcuendet/.pyenv/versions/pytorch1.11/lib/python3.8/site-packages/torch/nn/functional.py:3919:15
       -> (%179)
     block1():
       %181 : bool = aten::eq(%167, %112) # /home/gcuendet/.pyenv/versions/pytorch1.11/lib/python3.8/site-packages/torch/nn/functional.py:3929:7
```

Please let me know if I can provide more details to help solve this issue!

gs-olive commented 1 year ago

Thank you for the additional details - this is very helpful!

gcuendet commented 1 year ago

Hi @gs-olive, I have been working on the idea you described in #1842; see this commit. Admittedly, this initial implementation is overly specific and only handles upsample_bilinear2d, but together with your proposed fix for torch::jit::EliminateExceptions and a small modification to Torch-TensorRT's custom exception elimination pass, it seems that I can successfully convert both networks in this issue (the simple upsample network as well as the slightly less trivial one).

I would appreciate some feedback on this. Is my approach going in the right direction with respect to what you had in mind when describing #1842?

gcuendet commented 1 year ago

Also, what do you (or maybe @bowang007?) think of the changes I made to TensorRT/core/lowering/passes/exception_elimination.cpp? The rationale is that, instead of catching only something like:

```
 = prim::If(%5958)
  block0():
    = prim::RaiseException(%45)
   -> ()
  block1():
   -> ()
```

or

```
 = prim::If(%5958)
  block0():
   -> ()
  block1():
    = prim::RaiseException(%45)
   -> ()
```

you could also catch more complex blocks, given that:

  1. the node is a prim::If, containing two blocks
  2. both blocks do not return anything
  3. the last node of at least one of the blocks is a prim::RaiseException (instead of the first node, as is currently the case)

What I have typically observed while investigating Upsample is something like:

```
 = prim::If(%5958)
  block0():
   -> ()
  block1():
    %191 : str = aten::format(%108, %171, %117)
   = prim::RaiseException(%191, %100)
   -> ()
```
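To make the difference between the current rule and the proposed relaxation concrete, here is a minimal sketch over a hypothetical toy IR. The `Block`/`IfNode` dataclasses, `strict_rule` and `relaxed_rule` are illustrative names, not the torch::jit API or the actual pass. The strict rule encodes only the two shapes shown above, the relaxed rule encodes the three listed conditions, and the examples include the aten::format case plus a simplified version of an if-node whose arms return values.

```python
from dataclasses import dataclass, field
from typing import List

RAISE = "prim::RaiseException"


@dataclass
class Block:
    """Toy block: node kinds in order (return node excluded) and block outputs."""
    nodes: List[str] = field(default_factory=list)
    outputs: List[str] = field(default_factory=list)


@dataclass
class IfNode:
    """Toy prim::If with exactly two arms."""
    then_block: Block
    else_block: Block


def only_raises(block: Block) -> bool:
    """True if the block consists of a single prim::RaiseException."""
    return block.nodes == [RAISE]


def strict_rule(node: IfNode) -> bool:
    """The two shapes shown above: one arm is exactly one raise, the other is empty."""
    a, b = node.then_block, node.else_block
    if a.outputs or b.outputs:
        return False
    return (only_raises(a) and not b.nodes) or (only_raises(b) and not a.nodes)


def relaxed_rule(node: IfNode) -> bool:
    """Proposed conditions: no arm has outputs, and the LAST node of at least
    one arm is a prim::RaiseException."""
    arms = (node.then_block, node.else_block)
    if any(arm.outputs for arm in arms):
        return False  # the prim::If produces Values that may be used elsewhere
    return any(bool(arm.nodes) and arm.nodes[-1] == RAISE for arm in arms)


# Toy versions of the shapes discussed in this thread.
simple = IfNode(Block([RAISE]), Block())
formatted = IfNode(Block(), Block(["aten::format", RAISE]))
with_outputs = IfNode(  # simplified: arms that return values, like %out2.1
    Block(["aten::upsample_bilinear2d"], ["%82"]),
    Block(["aten::format", RAISE], ["%63"]),
)
```

Only the relaxed rule accepts `formatted`, and both rules reject `with_outputs` before looking at any node, which matches the early `return false` on block outputs.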
gs-olive commented 1 year ago

Hi @gcuendet - thanks for the update! The commit linked here is definitely along the lines of what was intended for #1842. One thing I was wondering about for that commit - on line 83 - was there an issue with calling if_node->destroy()?

Regarding the changes made to TensorRT/core/lowering/passes/exception_elimination.cpp, I think this rationale/edit idea is a good one. I have a few comments:

Also tagging @narendasan for any comments on the proposed edits to TensorRT/core/lowering/passes/exception_elimination.cpp.

gcuendet commented 1 year ago

Thanks for the quick feedback! I am still fiddling with these changes and trying to make them work in more generic cases than the two overly simplified networks shared above, but regarding your questions:

  1. One thing I was wondering about for that commit - on line 83 - was there an issue with calling if_node->destroy()?

At some point I had the impression that destroying the node was unnecessary, because a dead code elimination pass would do it for you (I observed this in some cases, but it might not hold in all of them). It might still be good to destroy the node at that point: that way, if the outputs are not properly replaced, this would fail. I'll check that.

  2. I wonder if auto arm1_last = arm1->nodes().rbegin() is always a prim::Return node, since every block in a prim::If must have a return to be valid.

I don't think auto arm1_last = arm1->nodes().rbegin() is always a prim::Return node, no. It's true that every block in a prim::If must have a return to be valid, but I think that the iterator returned by generic_graph_node_list<Node>::rbegin() (which is what you get when you call Block::nodes().rbegin()) starts on the last node just before the return node.
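The behaviour described here is what one gets if a block stores its nodes in a circular doubly linked list closed by a sentinel return node, so that iteration in either direction stops upon reaching the sentinel. The sketch below is a hypothetical Python model of that layout (`ToyNode`, `ToyBlock` and `nodes_reversed` are illustrative names, not the torch::jit API), showing why a reverse iterator starts at the last real node rather than at the return node.

```python
class ToyNode:
    """Toy node with links to its neighbours; a lone node links to itself."""

    def __init__(self, kind):
        self.kind = kind
        self.next = self
        self.prev = self


class ToyBlock:
    """Toy block whose nodes form a circular list closed by a prim::Return sentinel."""

    def __init__(self):
        self.ret = ToyNode("prim::Return")  # sentinel: links to itself when empty

    def append(self, kind):
        node = ToyNode(kind)
        last = self.ret.prev
        # Splice the new node between the current last node and the sentinel.
        last.next = node
        node.prev = last
        node.next = self.ret
        self.ret.prev = node
        return node

    def nodes(self):
        """Forward iteration: starts after the sentinel, stops on reaching it."""
        cur = self.ret.next
        while cur is not self.ret:
            yield cur
            cur = cur.next

    def nodes_reversed(self):
        """Reverse iteration (the rbegin() analogue): starts just before the sentinel."""
        cur = self.ret.prev
        while cur is not self.ret:
            yield cur
            cur = cur.prev


block = ToyBlock()
block.append("aten::format")
block.append("prim::RaiseException")
last = next(block.nodes_reversed())  # the RaiseException, not the return node
```

In this model an empty block yields nothing in either direction, so code that dereferences the reverse iterator should first check that the block is not empty.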

  3. Additionally, in the case where the prim::RaiseException is the last node in the block, I am unsure if we can be certain that the block returns nothing.

You are right, but we also check earlier that

```cpp
if (arm1->outputs().size() != 0 || arm2->outputs().size() != 0) {
  // Make sure that the node doesn't actually produce any Values that are
  // used by other nodes
  return false;
}
```

So at that point we already know that neither arm returns a Value. In the example you point to (the prim::If producing %out2.1), both block0 and block1 have one output, so we never even reach the check that the last node is a prim::RaiseException: the function returns false earlier.

gcuendet commented 1 year ago

Hi! I have a small update regarding this work.

Upsample bilinear 2D exception elimination

A new commit implementing the custom, upsample_bilinear2d-specific exception removal is here.

A small note on that commit: even though the isUpsampleBlock method is very specific, I think copyAllNodes could potentially be reused for #1842. Once the prim::If node has been identified, along with which of its blocks performs the computation of interest, copyAllNodes copies that whole block. This includes renaming the inputs of nodes that consume outputs of earlier nodes in the same block, and replacing uses of the prim::If outputs in the rest of the graph with the corresponding outputs of the nodes inside that block.
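The copy-and-rewire step described here can be sketched on a hypothetical toy graph. The dict-based nodes, `inline_arm` and the fresh-name scheme are illustrative assumptions, not the actual copyAllNodes implementation. Each node of the chosen arm is cloned in place of the prim::If, a value map renames inputs that refer to earlier outputs in the same arm, and every later use of a prim::If output is redirected to the matching copied value. The final assertion plays the role discussed for if_node->destroy(): the prim::If should only disappear once none of its outputs is still used. The extra aten::relu node in the example exists only to exercise the renaming.

```python
from itertools import count

_fresh = count()


def fresh(name):
    """Return a new, unique value name derived from `name`."""
    return f"{name}.c{next(_fresh)}"


def inline_arm(graph, if_index, arm_index):
    """Replace graph[if_index] (a toy prim::If) with copies of one arm's nodes.

    Later nodes are rewired in place; nested blocks are not visited.
    """
    if_node = graph[if_index]
    arm = if_node["arms"][arm_index]
    value_map = {}  # arm-local value name -> name of the copied value
    copies = []
    for node in arm["nodes"]:
        new_outputs = [fresh(o) for o in node["outputs"]]
        # Rename inputs produced earlier in the same arm; outer values stay as-is.
        new_inputs = [value_map.get(i, i) for i in node["inputs"]]
        value_map.update(zip(node["outputs"], new_outputs))
        copies.append({"kind": node["kind"], "inputs": new_inputs, "outputs": new_outputs})
    # Map each prim::If output to the copied value that the arm returns.
    replace = {
        if_out: value_map.get(arm_out, arm_out)
        for if_out, arm_out in zip(if_node["outputs"], arm["outputs"])
    }
    rest = graph[if_index + 1:]
    for node in rest:
        node["inputs"] = [replace.get(i, i) for i in node["inputs"]]
    # Analogous to destroying the prim::If: its outputs must no longer be used.
    still_used = [i for n in rest for i in n["inputs"] if i in if_node["outputs"]]
    assert not still_used, f"prim::If outputs still in use: {still_used}"
    return graph[:if_index] + copies + rest


graph = [
    {"kind": "prim::If", "inputs": ["%cond"], "outputs": ["%out1"], "arms": [
        {"nodes": [
            {"kind": "aten::upsample_bilinear2d", "inputs": ["%X"], "outputs": ["%73"]},
            {"kind": "aten::relu", "inputs": ["%73"], "outputs": ["%74"]},
        ], "outputs": ["%74"]},
        {"nodes": [{"kind": "prim::RaiseException", "inputs": ["%msg"], "outputs": []}],
         "outputs": ["%u"]},
    ]},
    {"kind": "aten::add", "inputs": ["%out1", "%X"], "outputs": ["%out"]},
]
new_graph = inline_arm(graph, 0, 0)
```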

Regarding your comment 1 above:

One thing I was wondering about for that commit - on line 83 - was there an issue with calling if_node->destroy()?

I think my previous answer was not completely accurate:

Exception elimination

A new commit implementing the changes to TensorRT/core/lowering/passes/exception_elimination.cpp is here.

I changed the implementation only slightly, most importantly to verify that the block of the prim::If node that does not raise an exception also does not compute anything. I initially thought that checking that the prim::If node had zero outputs (or, equivalently, that both of its blocks had zero outputs) would be enough, but I saw cases where apparently that was not so. (I am basically saying that Values are not scoped to their blocks; I'm not sure whether that's plausible or completely crazy.)
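The refined condition can be expressed as one extra predicate on top of the zero-output check. This is a minimal sketch over a hypothetical toy representation (each arm is a pair of node kinds and an output count; `removable` and `raises_last` are illustrative names), not the code in the linked commit. A prim::If is only treated as removable when one arm ends in prim::RaiseException and the other arm contains no nodes at all.

```python
from typing import List, Tuple

# Toy arm: (node kinds in order, number of block outputs).
Arm = Tuple[List[str], int]


def raises_last(arm: Arm) -> bool:
    """True if the arm's last node is a prim::RaiseException."""
    kinds, _ = arm
    return bool(kinds) and kinds[-1] == "prim::RaiseException"


def removable(arm_a: Arm, arm_b: Arm) -> bool:
    """Neither arm has outputs, one arm ends by raising, and the other arm
    computes nothing (it contains no nodes)."""
    if arm_a[1] or arm_b[1]:
        return False
    for raising, other in ((arm_a, arm_b), (arm_b, arm_a)):
        if raises_last(raising) and not other[0]:
            return True
    return False
```

An if-node whose non-raising arm contains, say, an aten::mul would satisfy the earlier three conditions but is rejected here, which is the situation the refinement guards against.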

Let me know if that's of interest; I'll be happy to open a PR!

gcuendet commented 1 year ago

Just for completeness: using Upsample directly in the network definition still does not seem to work, even with the changes proposed above and the fix in torch::jit::EliminateExceptions. With the Network described in this comment, I now observe the following error in shape_analysis.cpp:

```
GRAPH: [Torch-TensorRT] - Running shape analysis on block Segment Block @1:
    Target: Torch

    Graph: graph(%1 : bool,
      %4 : bool?):
  %self.block1.norm.training.15 : bool = prim::Constant[value=0]()
  %3 : str = prim::Constant[value="builtins.ValueError"]()
  %2 : str = prim::Constant[value="align_corners option can only be set with the interpolating modes: linear | bilinear | bicubic | trilinear"]()
  %align_corners0.1 : bool? = prim::If(%1)
    block0():
       = prim::RaiseException(%2, %3)
      -> (%4)
    block1():
      -> (%self.block1.norm.training.15)
  return (%align_corners0.1)

terminate called after throwing an instance of 'torch_tensorrt::Error'
  what():  [Error thrown at core/partitioning/shape_analysis.cpp:187] Expected ivalues_maps.count(input) to be true but got false
Could not find torch::jit::Value* 474 produced from %474 : bool? = prim::Uninitialized() # :0:0 in lowering graph for mini graph input.
```

This seems linked to the check on the validity of the align_corners option (which verifies that the interpolation mode is one of linear, bilinear, bicubic or trilinear): when the option is invalid, the raising branch of that prim::If returns a bool? = prim::Uninitialized(), and this value seems to make the shape analysis fail. The exception elimination pass does not simplify this prim::If, because its blocks actually return a value.
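The pattern can be reproduced in isolation. The minimal sketch below scripts a function in which one branch raises and the other defines a value that is used afterwards; the name choose_align_corners is hypothetical, and the function only mimics the shape of the check rather than copying the F.interpolate source. TorchScript still needs an output from the raising branch, so it fills it with prim::Uninitialized, and the resulting prim::If has an output. That is why a pass that only removes output-free exception guards leaves it in place.

```python
import torch


def choose_align_corners(mode_is_nearest: bool, align_corners: bool) -> bool:
    # One branch raises, the other defines the value that is used after the `if`.
    if mode_is_nearest:
        raise ValueError("align_corners option can only be set with the interpolating modes")
    else:
        result = align_corners
    return result


scripted = torch.jit.script(choose_align_corners)
print(scripted.graph)

# Locate the top-level prim::If and inspect what its raising branch (block0) returns.
if_node = next(n for n in scripted.graph.nodes() if n.kind() == "prim::If")
raising_block = next(iter(if_node.blocks()))
producers = [value.node().kind() for value in raising_block.outputs()]
print(len(list(if_node.outputs())), producers)
```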

A workaround is to call upsample_bilinear2d directly instead of Upsample (or whichever function corresponds to the desired interpolation mode and input dimensions).
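A minimal sketch of that workaround, assuming a 4D input and bilinear mode. It calls torch._C._nn.upsample_bilinear2d, the private binding that F.interpolate itself uses in this case, with output_size=None and explicit scale factors; this is one way to do it, not necessarily the exact call used above, and the module names are illustrative. Both modules should produce the same values, but only the nn.Upsample version pulls the Python-level argument checks (and their prim::RaiseException nodes) into the inlined graph.

```python
import torch


class UpsampleViaModule(torch.nn.Module):
    # Baseline: the nn.Upsample path, which goes through F.interpolate and its checks.
    def __init__(self):
        super().__init__()
        self.up = torch.nn.Upsample(scale_factor=2.0, mode="bilinear", align_corners=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.up(x)


class UpsampleDirect(torch.nn.Module):
    # Workaround: call the scale-factor overload of aten::upsample_bilinear2d directly.
    # torch._C._nn is a private namespace and may change between PyTorch versions.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch._C._nn.upsample_bilinear2d(x, None, False, [2.0, 2.0])


x = torch.randn(1, 3, 8, 8)
via_module = torch.jit.script(UpsampleViaModule().eval())
direct = torch.jit.script(UpsampleDirect().eval())

print(torch.allclose(via_module(x), direct(x)))
print("prim::RaiseException" in str(via_module.inlined_graph))
print("prim::RaiseException" in str(direct.inlined_graph))
```

The trade-off is that the direct call skips the input-dimension and mode validation that F.interpolate performs, so passing a non-4D tensor fails inside the kernel rather than with a descriptive Python error.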

I also include the lowered graph below (lowered with all the changes described above):

Lowered graph

```python
INFO: [Torch-TensorRT] - Lowered Graph: graph(%x.1 : Tensor):
  %4 : float = prim::Constant[value=2.]()
  %14 : int[] = prim::Constant[value=[0, 0]]()
  %16 : int[] = prim::Constant[value=[1, 1]]()
  %17 : int = prim::Constant[value=1]()
  %18 : bool = prim::Constant[value=1]()
  %self.conv.bias : Float(2, strides=[1], requires_grad=0, device=cuda:0) = prim::Constant[value=0.01 * 7.4246 -8.4862 [ CUDAFloatType{2} ]]()
  %self.conv.weight : Float(2, 64, 1, 1, strides=[64, 1, 1, 1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
  %self.block1.conv.weight : Float(32, 3, 3, 3, strides=[27, 9, 3, 1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
  %self.block1.conv.bias.15 : NoneType = prim::Constant()
  %self.block1.norm.running_var : Float(32, strides=[1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
  %self.block1.norm.running_mean : Float(32, strides=[1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
  %self.block1.norm.training.15 : bool = prim::Constant[value=0]()
  %452 : float = prim::Constant[value=0.10000000000000001]()
  %453 : float = prim::Constant[value=1.0000000000000001e-05]()
  %455 : int[] = prim::Constant[value=[2, 2]]()
  %self.block2.conv.weight : Float(64, 32, 3, 3, strides=[288, 9, 3, 1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
  %self.block2.norm.running_var : Float(64, strides=[1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
  %self.block2.norm.running_mean : Float(64, strides=[1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
  %460 : str[] = prim::Constant[value=["nearest", "area", "nearest-exact"]]()
  %461 : str = prim::Constant[value="bilinear"]()
  %462 : str = prim::Constant[value="builtins.ValueError"]()
  %463 : str = prim::Constant[value="align_corners option can only be set with the interpolating modes: linear | bilinear | bicubic | trilinear"]()
  %464 : int = prim::Constant[value=2]()
  %579 : bool = prim::Constant[value=0]()
  %580 : int[] = prim::Constant[value=[0, 0]]()
  %581 : Tensor = aten::_convolution(%x.1, %self.block1.conv.weight, %self.block1.conv.bias.15, %16, %16, %16, %579, %580, %17, %579, %579, %579, %579)
  %out1.2 : Tensor = aten::batch_norm(%581, %self.block1.norm.running_var, %self.block1.norm.running_mean, %self.block1.norm.running_mean, %self.block1.norm.running_var, %self.block1.norm.training.15, %452, %453, %18)
  %467 : Tensor = aten::relu(%out1.2)
  %out.2 : Tensor = aten::max_pool2d(%467, %455, %455, %14, %16, %self.block1.norm.training.15)
  %582 : bool = prim::Constant[value=0]()
  %583 : int[] = prim::Constant[value=[0, 0]]()
  %584 : Tensor = aten::_convolution(%out.2, %self.block2.conv.weight, %self.block1.conv.bias.15, %16, %16, %16, %582, %583, %17, %582, %582, %582, %582)
  %out0.1 : Tensor = aten::batch_norm(%584, %self.block2.norm.running_var, %self.block2.norm.running_mean, %self.block2.norm.running_mean, %self.block2.norm.running_var, %self.block1.norm.training.15, %452, %453, %18)
  %471 : Tensor = aten::relu(%out0.1)
  %out0.2 : Tensor = aten::max_pool2d(%471, %455, %455, %14, %16, %self.block1.norm.training.15)
  %474 : bool? = prim::Uninitialized() # :0:0
  %475 : bool = prim::Uninitialized() # :0:0
  %476 : bool = aten::__contains__(%460, %461)
  %align_corners0.1 : bool? = prim::If(%476)
    block0():
       = prim::RaiseException(%463, %462)
      -> (%474)
    block1():
      -> (%self.block1.norm.training.15)
  %478 : int = aten::dim(%out0.2)
  %dim.2 : int = aten::sub(%478, %464)
  %scale_factors2.2 : float[] = prim::ListConstruct()
   = prim::Loop(%dim.2, %18)
    block0(%49 : int):
      %50 : float[] = aten::append(%scale_factors2.2, %4)
      -> (%18)
  %480 : int = prim::Constant[value=4]()
  %481 : str = prim::Constant[value="Input Error: Only 3D, 4D and 5D input Tensors supported (got {}D) for the modes: nearest | linear | bilinear | bicubic | trilinear | area | nearest-exact (got {})"]()
  %482 : str = prim::Constant[value="AssertionError: "]()
  %485 : str = prim::Constant[value="Got 3D input, but bilinear mode needs 4D input"]()
  %486 : int = prim::Constant[value=3]()
  %487 : str = prim::Constant[value="builtins.NotImplementedError"]()
  %488 : int = prim::Constant[value=5]()
  %489 : str = prim::Constant[value="Got 5D input, but bilinear mode needs 4D input"]()
  %494 : bool = aten::eq(%478, %480)
  %571 : bool = aten::__isnot__(%align_corners0.1, %self.block1.conv.bias.15)
  %align_corners6.2 : bool = prim::If(%571)
    block0():
      %align_corners7.3 : bool = prim::unchecked_cast(%align_corners0.1)
      -> (%align_corners7.3)
    block1():
       = prim::RaiseException(%482, %self.block1.conv.bias.15)
      -> (%475)
  %574 : Tensor = aten::upsample_bilinear2d(%out0.2, %self.block1.conv.bias.15, %align_corners6.2, %scale_factors2.2)
   = prim::If(%494)
    block0():
       = prim::If(%571)
        block0():
          -> ()
        block1():
           = prim::RaiseException(%482, %self.block1.conv.bias.15)
          -> ()
      -> ()
    block1():
      %500 : bool = aten::eq(%478, %486)
       = prim::If(%500)
        block0():
           = prim::RaiseException(%485, %487)
          -> ()
        block1():
          -> ()
      %501 : bool = aten::eq(%478, %488)
       = prim::If(%501)
        block0():
           = prim::RaiseException(%489, %487)
          -> ()
        block1():
          -> ()
      %502 : str = aten::format(%481, %478, %461)
       = prim::RaiseException(%502, %487)
      -> ()
  %align_corners0 : bool? = prim::If(%self.block1.norm.training.15)
    block0():
       = prim::RaiseException(%463, %462)
      -> (%474)
    block1():
      -> (%self.block1.norm.training.15)
  %504 : int = aten::dim(%574)
  %dim.1 : int = aten::sub(%504, %464)
  %scale_factors2.1 : float[] = prim::ListConstruct()
   = prim::Loop(%dim.1, %18)
    block0(%64 : int):
      %65 : float[] = aten::append(%scale_factors2.1, %4)
      -> (%18)
  %516 : bool = aten::eq(%504, %480)
  %575 : bool = aten::__isnot__(%align_corners0, %self.block1.conv.bias.15)
  %align_corners6.6 : bool = prim::If(%575)
    block0():
      %align_corners7.7 : bool = prim::unchecked_cast(%align_corners0)
      -> (%align_corners7.7)
    block1():
       = prim::RaiseException(%482, %self.block1.conv.bias.15)
      -> (%475)
  %578 : Tensor = aten::upsample_bilinear2d(%574, %self.block1.conv.bias.15, %align_corners6.6, %scale_factors2.1)
   = prim::If(%516)
    block0():
       = prim::If(%575)
        block0():
          -> ()
        block1():
           = prim::RaiseException(%482, %self.block1.conv.bias.15)
          -> ()
      -> ()
    block1():
      %522 : bool = aten::eq(%504, %486)
       = prim::If(%522)
        block0():
           = prim::RaiseException(%485, %487)
          -> ()
        block1():
          -> ()
      %523 : bool = aten::eq(%504, %488)
       = prim::If(%523)
        block0():
           = prim::RaiseException(%489, %487)
          -> ()
        block1():
          -> ()
      %524 : str = aten::format(%481, %504, %461)
       = prim::RaiseException(%524, %487)
      -> ()
  %585 : bool = prim::Constant[value=0]()
  %586 : int[] = prim::Constant[value=[0, 0]]()
  %587 : Tensor = aten::_convolution(%578, %self.conv.weight, %self.conv.bias, %16, %14, %16, %585, %586, %17, %585, %585, %585, %585)
  return (%587)
```
gs-olive commented 1 year ago

Hi @gcuendet - thank you very much for all of the work and detailed answers on this topic. I made a few comments on the upsample_bilinear2d exceptions removal and the new changes to TensorRT/core/lowering/passes/exception_elimination.cpp. I think both of these updates look good and would be welcome additions via a PR, though for the upsample_bilinear2d exceptions removal, we would also need to add some testing since it is a new feature.

@narendasan - do you have any input on the proposed changes to TensorRT/core/lowering/passes/exception_elimination.cpp and the custom upsample_bilinear2d exceptions removal pass?

Regarding the Could not find torch::jit::Value* issue, we are tracking and investigating this issue with @bowang007 across multiple reports, including #1834 and #1815. A common thread among all of these seems to be the presence of nested "If" blocks elsewhere in the code, though it's not yet clear if this is the root cause of the issue.
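To check whether a given model exhibits this pattern, a small diagnostic helper can report how deeply prim::If nodes are nested in a scripted graph. The helper max_if_depth below is hypothetical (it is not part of Torch-TensorRT), it relies only on the TorchScript graph bindings nodes(), kind() and blocks(), and it says nothing about whether nesting is the actual root cause.

```python
import torch


def max_if_depth(container, depth: int = 0) -> int:
    # `container` may be a torch._C.Graph or a torch._C.Block: both expose .nodes().
    # Returns how deeply prim::If nodes are nested (0 means no prim::If at all).
    deepest = depth
    for node in container.nodes():
        inner = depth + 1 if node.kind() == "prim::If" else depth
        deepest = max(deepest, inner)
        # Recurse into every sub-block (If branches, Loop bodies, ...).
        for sub_block in node.blocks():
            deepest = max(deepest, max_if_depth(sub_block, inner))
    return deepest


def nested(x: torch.Tensor, a: bool, b: bool) -> torch.Tensor:
    if a:
        if b:
            x = x + 1
    return x


def flat(x: torch.Tensor, a: bool) -> torch.Tensor:
    if a:
        x = x + 1
    return x


print(max_if_depth(torch.jit.script(nested).graph))
print(max_if_depth(torch.jit.script(flat).graph))
```

On a scripted model, running it on scripted_model.inlined_graph also covers the bodies of called functions such as F.interpolate.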

bowang007 commented 1 year ago

Hey @gcuendet, can you try this PR: https://github.com/pytorch/TensorRT/pull/1933? I reproduced the uninitialized-value error you reported, and I think this PR might help with it.

gcuendet commented 1 year ago

Thanks @bowang007. Did you test it on the small network described above? Is it working for you? The conversion is not working for me; I get the following error:

```
GRAPH: [Torch-TensorRT] - Running shape analysis on block Segment Block @3:
    Target: Torch

    Graph: graph(%out0.2 : Tensor,
      %align_corners0.1 : bool?):
  %self.block1.conv.bias.15 : NoneType = prim::Constant()
  %10 : int = prim::Constant[value=4]()
  %8 : float = prim::Constant[value=2.]()
  %5 : bool = prim::Constant[value=1]()
  %3 : int = prim::Constant[value=2]()
  %0 : int = aten::dim(%out0.2)
  %dim.2 : int = aten::sub(%0, %3)
  %scale_factors2.2 : float[] = prim::ListConstruct()
   = prim::Loop(%dim.2, %5)
    block0(%6 : int):
      %7 : float[] = aten::append(%scale_factors2.2, %8)
      -> (%5)
  %9 : bool = aten::eq(%0, %10)
  %11 : bool = aten::__isnot__(%align_corners0.1, %self.block1.conv.bias.15)
  return (%scale_factors2.2, %0, %9, %11)

terminate called after throwing an instance of 'torch_tensorrt::Error'
  what():  [Error thrown at core/partitioning/shape_analysis.cpp:212] Expected to find type bool? for value align_corners0.1 but get nothing.
```

For reference, the lowered graph is the following:

Lowered graph

```python
INFO: [Torch-TensorRT] - Lowered Graph: graph(%x.1 : Tensor):
  %4 : float = prim::Constant[value=2.]()
  %14 : int[] = prim::Constant[value=[0, 0]]()
  %16 : int[] = prim::Constant[value=[1, 1]]()
  %17 : int = prim::Constant[value=1]()
  %18 : bool = prim::Constant[value=1]()
  %self.conv.bias : Float(2, strides=[1], requires_grad=0, device=cuda:0) = prim::Constant[value=0.01 * 7.4246 -8.4862 [ CUDAFloatType{2} ]]()
  %self.conv.weight : Float(2, 64, 1, 1, strides=[64, 1, 1, 1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
  %self.block1.conv.weight : Float(32, 3, 3, 3, strides=[27, 9, 3, 1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
  %self.block1.conv.bias.15 : NoneType = prim::Constant()
  %self.block1.norm.running_var : Float(32, strides=[1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
  %self.block1.norm.running_mean : Float(32, strides=[1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
  %self.block1.norm.training.15 : bool = prim::Constant[value=0]()
  %452 : float = prim::Constant[value=0.10000000000000001]()
  %453 : float = prim::Constant[value=1.0000000000000001e-05]()
  %455 : int[] = prim::Constant[value=[2, 2]]()
  %self.block2.conv.weight : Float(64, 32, 3, 3, strides=[288, 9, 3, 1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
  %self.block2.norm.running_var : Float(64, strides=[1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
  %self.block2.norm.running_mean : Float(64, strides=[1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
  %460 : str[] = prim::Constant[value=["nearest", "area", "nearest-exact"]]()
  %461 : str = prim::Constant[value="bilinear"]()
  %462 : str = prim::Constant[value="builtins.ValueError"]()
  %463 : str = prim::Constant[value="align_corners option can only be set with the interpolating modes: linear | bilinear | bicubic | trilinear"]()
  %464 : int = prim::Constant[value=2]()
  %579 : bool = prim::Constant[value=0]()
  %580 : int[] = prim::Constant[value=[0, 0]]()
  %581 : Tensor = aten::_convolution(%x.1, %self.block1.conv.weight, %self.block1.conv.bias.15, %16, %16, %16, %579, %580, %17, %579, %579, %579, %579)
  %out1.2 : Tensor = aten::batch_norm(%581, %self.block1.norm.running_var, %self.block1.norm.running_mean, %self.block1.norm.running_mean, %self.block1.norm.running_var, %self.block1.norm.training.15, %452, %453, %18)
  %467 : Tensor = aten::relu(%out1.2)
  %out.2 : Tensor = aten::max_pool2d(%467, %455, %455, %14, %16, %self.block1.norm.training.15)
  %582 : bool = prim::Constant[value=0]()
  %583 : int[] = prim::Constant[value=[0, 0]]()
  %584 : Tensor = aten::_convolution(%out.2, %self.block2.conv.weight, %self.block1.conv.bias.15, %16, %16, %16, %582, %583, %17, %582, %582, %582, %582)
  %out0.1 : Tensor = aten::batch_norm(%584, %self.block2.norm.running_var, %self.block2.norm.running_mean, %self.block2.norm.running_mean, %self.block2.norm.running_var, %self.block1.norm.training.15, %452, %453, %18)
  %471 : Tensor = aten::relu(%out0.1)
  %out0.2 : Tensor = aten::max_pool2d(%471, %455, %455, %14, %16, %self.block1.norm.training.15)
  %474 : bool? = prim::Uninitialized() # :0:0
  %475 : bool = prim::Uninitialized() # :0:0
  %476 : bool = aten::__contains__(%460, %461)
  %align_corners0.1 : bool? = prim::If(%476)
    block0():
       = prim::RaiseException(%463, %462)
      -> (%474)
    block1():
      -> (%self.block1.norm.training.15)
  %478 : int = aten::dim(%out0.2)
  %dim.2 : int = aten::sub(%478, %464)
  %scale_factors2.2 : float[] = prim::ListConstruct()
   = prim::Loop(%dim.2, %18)
    block0(%49 : int):
      %50 : float[] = aten::append(%scale_factors2.2, %4)
      -> (%18)
  %480 : int = prim::Constant[value=4]()
  %482 : str = prim::Constant[value="Input Error: Only 3D, 4D and 5D input Tensors supported (got {}D) for the modes: nearest | linear | bilinear | bicubic | trilinear | area | nearest-exact (got {})"]()
  %483 : int = prim::Constant[value=5]()
  %485 : str = prim::Constant[value="AssertionError: "]()
  %486 : str = prim::Constant[value="builtins.NotImplementedError"]()
  %487 : str = prim::Constant[value="Got 3D input, but bilinear mode needs 4D input"]()
  %488 : int = prim::Constant[value=3]()
  %489 : str = prim::Constant[value="Got 5D input, but bilinear mode needs 4D input"]()
  %494 : bool = aten::eq(%478, %480)
  %571 : bool = aten::__isnot__(%align_corners0.1, %self.block1.conv.bias.15)
  %align_corners6.2 : bool = prim::If(%571)
    block0():
      %align_corners7.3 : bool = prim::unchecked_cast(%align_corners0.1)
      -> (%align_corners7.3)
    block1():
       = prim::RaiseException(%485, %self.block1.conv.bias.15)
      -> (%475)
  %574 : Tensor = aten::upsample_bilinear2d(%out0.2, %self.block1.conv.bias.15, %align_corners6.2, %scale_factors2.2)
   = prim::If(%494)
    block0():
       = prim::If(%571)
        block0():
          -> ()
        block1():
           = prim::RaiseException(%485, %self.block1.conv.bias.15)
          -> ()
      -> ()
    block1():
      %500 : bool = aten::eq(%478, %488)
       = prim::If(%500)
        block0():
           = prim::RaiseException(%487, %486)
          -> ()
        block1():
          -> ()
      %501 : bool = aten::eq(%478, %483)
       = prim::If(%501)
        block0():
           = prim::RaiseException(%489, %486)
          -> ()
        block1():
          -> ()
      %502 : str = aten::format(%482, %478, %461)
       = prim::RaiseException(%502, %486)
      -> ()
  %align_corners0 : bool? = prim::If(%self.block1.norm.training.15)
    block0():
       = prim::RaiseException(%463, %462)
      -> (%474)
    block1():
      -> (%self.block1.norm.training.15)
  %504 : int = aten::dim(%574)
  %dim.1 : int = aten::sub(%504, %464)
  %scale_factors2.1 : float[] = prim::ListConstruct()
   = prim::Loop(%dim.1, %18)
    block0(%64 : int):
      %65 : float[] = aten::append(%scale_factors2.1, %4)
      -> (%18)
  %516 : bool = aten::eq(%504, %480)
  %575 : bool = aten::__isnot__(%align_corners0, %self.block1.conv.bias.15)
  %align_corners6.6 : bool = prim::If(%575)
    block0():
      %align_corners7.7 : bool = prim::unchecked_cast(%align_corners0)
      -> (%align_corners7.7)
    block1():
       = prim::RaiseException(%485, %self.block1.conv.bias.15)
      -> (%475)
  %578 : Tensor = aten::upsample_bilinear2d(%574, %self.block1.conv.bias.15, %align_corners6.6, %scale_factors2.1)
   = prim::If(%516)
    block0():
       = prim::If(%575)
        block0():
          -> ()
        block1():
           = prim::RaiseException(%485, %self.block1.conv.bias.15)
          -> ()
      -> ()
    block1():
      %522 : bool = aten::eq(%504, %488)
       = prim::If(%522)
        block0():
           = prim::RaiseException(%487, %486)
          -> ()
        block1():
          -> ()
      %523 : bool = aten::eq(%504, %483)
       = prim::If(%523)
        block0():
           = prim::RaiseException(%489, %486)
          -> ()
        block1():
          -> ()
      %524 : str = aten::format(%482, %504, %461)
       = prim::RaiseException(%524, %486)
      -> ()
  %585 : bool = prim::Constant[value=0]()
  %586 : int[] = prim::Constant[value=[0, 0]]()
  %587 : Tensor = aten::_convolution(%578, %self.conv.weight, %self.conv.bias, %16, %14, %16, %585, %586, %17, %585, %585, %585, %585)
  return (%587)
```
bowang007 commented 1 year ago

Hey @gcuendet, I was using the model provided:

```python
import torch
import torch.nn as nn
import torch_tensorrt


class Block(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(Block, self).__init__()
        self.conv = nn.Conv2d(
            in_channel, out_channel, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.norm = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(
            kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False
        )

    def forward(self, x):
        out = self.conv(x)
        out = self.norm(out)
        out = self.relu(out)
        out = self.maxpool(out)
        return out


class Network(torch.nn.Module):
    def __init__(self, num_classes=2):
        super(Network, self).__init__()
        self.num_classes = num_classes

        self.block1 = Block(3, 32)
        self.block2 = Block(32, 64)

        self.upsample1 = nn.Upsample(
            scale_factor=2, mode="bilinear", align_corners=False
        )

        self.upsample2 = nn.Upsample(
            scale_factor=2, mode="bilinear", align_corners=False
        )

        self.conv = nn.Conv2d(64, num_classes, 1, bias=True)

    def forward(self, x):
        out = self.block1(x)
        out = self.block2(out)

        gclayer1 = self.upsample1(out)
        gclayer2 = self.upsample2(gclayer1)

        out = self.conv(gclayer2)
        return out


# Example input: a batch of 3 RGB images of size 224x224.
x = torch.randn([3, 3, 224, 224]).cuda()

model = Network()
model = model.eval().cuda()
model = torch.jit.script(model)
torch_tensorrt.logging.set_reportable_log_level(torch_tensorrt.logging.Level.Graph)

compile_settings = {
    # The input spec must match the tensor actually fed to the model (shape and dtype).
    "inputs": [
        torch_tensorrt.Input([3, 3, 224, 224], dtype=torch.float32),
    ],
    "min_block_size": 1,
    # "truncate_long_and_double": True,
    # "enabled_precisions": {torch.int64},
    # "torch_executed_ops": ['aten::conv2d']
}

trt_mod = torch_tensorrt.ts.compile(model, **compile_settings)
# The module takes a single tensor, so pass it directly instead of unpacking it.
output = trt_mod(x)
```

Are there any details I might have missed?

github-actions[bot] commented 1 year ago

This issue has not seen activity for 90 days. Remove the stale label or comment, or this will be closed in 10 days.