Closed: dav-ell closed this issue 1 year ago.
Can you provide a repro script or share the debug logs? We usually print the output tensor shapes of each op, which can help show what is going on.
I'm attempting to make a repro script, but I'm now stuck on a different issue:
# Import Torch-TensorRT first and raise its reportable log level to Graph so the
# TorchScript graph is dumped at each lowering stage (the "GRAPH:" log lines).
import torch_tensorrt
torch_tensorrt.logging.set_reportable_log_level(torch_tensorrt.logging.Level.Graph)
# Record library versions for the bug report.
print(torch_tensorrt.__version__)
import torch
import torch.nn as nn
from torch.nn import functional as F
print(torch.__version__)
import math
from typing import List
def _to_pair(value):
    """Return *value* as an ``(h, w)`` tuple; ints and 1-element sequences are repeated."""
    if isinstance(value, int):
        return (value, value)
    value = tuple(value)
    if len(value) == 1:
        return (value[0], value[0])
    return value


class Conv2dStaticSamePadding(nn.Module):
    """2-D convolution with TensorFlow-style ``SAME`` padding computed per call.

    The padding is derived from the actual input height/width on every forward
    pass, so the output spatial size is ``ceil(in / stride)``, mirroring
    TensorFlow's ``padding="SAME"``. When the total padding is odd, the extra
    pixel goes to the right/bottom edge.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Kernel size (int, 1-element or 2-element sequence).
        stride: Stride (int, 1-element or 2-element sequence).
        bias: Whether the wrapped ``nn.Conv2d`` has a bias term.
        groups: Number of blocked connections (``in_channels`` for depthwise).
        dilation: Kernel dilation (int, 1-element or 2-element sequence); it is
            forwarded to ``nn.Conv2d`` and included in the padding computation.
        **kwargs: Accepted for call-site compatibility and ignored.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True, groups=1, dilation=1, **kwargs):
        super().__init__()
        # Normalize before building the conv so nn.Conv2d and forward() agree
        # (nn.Conv2d would keep a 1-element sequence as a 1-tuple).
        kernel_size = _to_pair(kernel_size)
        stride = _to_pair(stride)
        dilation = _to_pair(dilation)
        # padding=0 here: padding is applied manually in forward().
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride,
                              bias=bias, groups=groups, dilation=dilation)
        self.stride = self.conv.stride
        self.kernel_size = self.conv.kernel_size
        self.dilation = self.conv.dilation

    def forward(self, x):
        """Pad *x* (NCHW) to SAME size for this conv, then convolve it."""
        h, w = x.shape[2], x.shape[3]
        s_h, s_w = self.stride[0], self.stride[1]
        # Effective extent of a dilated kernel: (k - 1) * d + 1.
        k_h = (self.kernel_size[0] - 1) * self.dilation[0] + 1
        k_w = (self.kernel_size[1] - 1) * self.dilation[1] + 1
        # Total padding needed to produce ceil(in / stride) outputs. Clamp at 0:
        # with stride > kernel the raw value can be negative, and F.pad would
        # then crop the input and shift the sampling grid away from TF SAME.
        pad_w = max((self.divide_ceil(w, s_w) - 1) * s_w - w + k_w, 0)
        pad_h = max((self.divide_ceil(h, s_h) - 1) * s_h - h + k_h, 0)
        left = pad_w // 2
        right = pad_w - left
        top = pad_h // 2
        bottom = pad_h - top
        x = F.pad(x, [left, right, top, bottom])
        return self.conv(x)

    def divide_ceil(self, a: int, b: int) -> int:
        """Integer ceil(a / b) via negated floor division (TorchScript-friendly)."""
        return -(-a // b)
class SeparableConvBlock(nn.Module):
    """Depthwise-separable 3x3 conv, then optional BatchNorm, then optional Swish.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels; defaults to ``in_channels``.
        norm: Apply ``BatchNorm2d`` after the pointwise conv.
        activation: Apply Swish (``x * sigmoid(x)``) after the optional norm.
        onnx_export: Accepted for API compatibility; not used by this block.
    """

    def __init__(self, in_channels, out_channels=None, norm=True, activation=False, onnx_export=False):
        super().__init__()
        if out_channels is None:
            out_channels = in_channels
        # Depthwise: one 3x3 filter per input channel (groups == in_channels), no bias.
        self.depthwise_conv = Conv2dStaticSamePadding(in_channels, in_channels,
                                                      kernel_size=3, stride=1, groups=in_channels, bias=False)
        # Pointwise: 1x1 conv that mixes channels and maps to out_channels.
        self.pointwise_conv = Conv2dStaticSamePadding(in_channels, out_channels, kernel_size=1, stride=1)
        self.norm = norm
        # Warning: pytorch momentum is different from tensorflow's, momentum_pytorch = 1 - momentum_tensorflow
        self.bn = nn.BatchNorm2d(num_features=out_channels, momentum=0.01, eps=1e-3)
        self.activation = activation

    def forward(self, x):
        x = self.depthwise_conv(x)
        x = self.pointwise_conv(x)
        if self.norm:
            x = self.bn(x)
        # Previously the activation flag was stored but never applied.
        if self.activation:
            x = x * torch.sigmoid(x)
        return x
class Swish(nn.Module):
    """Swish activation, ``x * sigmoid(x)``, applied elementwise."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate
class TestBrokenStuff(nn.Module):
    """Minimal two-node weighted feature-fusion repro (P7 -> P6 -> P5 top-down path).

    Each node fuses two inputs with learnable, ReLU-clamped weights normalized
    by their sum (plus ``epsilon``), applies Swish, then a SeparableConvBlock.
    """

    def __init__(self):
        super().__init__()
        # NOTE: creation order is kept as-is; it fixes parameter/state_dict
        # ordering and the RNG sequence used to initialize the conv blocks.
        self.epsilon = 1e-4
        self.p6_w1 = nn.Parameter(torch.ones(2), requires_grad=True)
        self.p6_w1_relu = nn.ReLU()
        self.p5_w1 = nn.Parameter(torch.ones(2), requires_grad=True)
        self.p5_w1_relu = nn.ReLU()
        self.conv5_up = SeparableConvBlock(64, onnx_export=True)
        self.conv6_up = SeparableConvBlock(64, onnx_export=True)
        self.swish = Swish()
        self.p6_upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.p5_upsample = nn.Upsample(scale_factor=2, mode='nearest')

    def forward(self, p5_in, p6_in, p7_in):
        # P6_1 from P6_0 and upsampled P7_0.
        w_p6 = self._normalize_weights(self.p6_w1_relu(self.p6_w1))
        fused_p6 = w_p6[0] * p6_in + w_p6[1] * self.p6_upsample(p7_in)
        p6_up = self.conv6_up(self.swish(fused_p6))

        # P5_1 from P5_0 and upsampled P6_1.
        w_p5 = self._normalize_weights(self.p5_w1_relu(self.p5_w1))
        fused_p5 = w_p5[0] * p5_in + w_p5[1] * self.p5_upsample(p6_up)
        p5_up = self.conv5_up(self.swish(fused_p5))

        return p6_up, p5_up

    def _normalize_weights(self, w: torch.Tensor) -> torch.Tensor:
        """Divide the fusion weights by their sum plus ``self.epsilon``."""
        return w / (torch.sum(w, dim=0) + self.epsilon)
# Put the model in inference mode before running, scripting and compiling it.
# In training mode the frozen graph still reads BatchNorm's bool `training`
# attribute and contains the in-place `num_batches_tracked` update (visible in
# the GRAPH log). That is the likely source of "Unknown type bool encountered in
# graph lowering". Eager calls in training mode also overwrite BN running stats.
model = TestBrokenStuff().eval()
dummy_input = (torch.ones(1, 64, 16, 16), torch.ones(1, 64, 8, 8), torch.ones(1, 64, 4, 4))
dummy_out = model(*dummy_input)
print("nonscripted", [out.shape for out in dummy_out])
scripted_model = torch.jit.script(model)
dummy_out = scripted_model(*dummy_input)
print("scripted", [out.shape for out in dummy_out])
# One static-shape Input spec per forward() argument, in order.
trt_model = torch_tensorrt.compile(
    scripted_model,
    inputs=[torch_tensorrt.Input(t.shape) for t in dummy_input],
)
Output:
INFO: [Torch-TensorRT] - ir was set to default, using TorchScript as ir
INFO: [Torch-TensorRT] - Module was provided as a torch.nn.Module, trying to script the module with torch.jit.script. In the event of a failure please preconvert your module to TorchScript
DEBUG: [Torch-TensorRT] - Settings requested for Lowering:
torch_executed_modules: [
]
GRAPH: [Torch-TensorRT] - Before lowering: graph(%self : __torch__.___torch_mangle_1638.TestBrokenStuff,
%p5_in.1 : Tensor,
%p6_in.1 : Tensor,
%p7_in.1 : Tensor):
%12 : NoneType = prim::Constant()
%11 : bool = prim::Constant[value=0]()
%9 : int = prim::Constant[value=0]() # /tmp/ipykernel_8946/3523899367.py:101:47
%26 : int = prim::Constant[value=1]() # /tmp/ipykernel_8946/3523899367.py:103:68
%p6_w1_relu : __torch__.torch.nn.modules.activation.ReLU = prim::GetAttr[name="p6_w1_relu"](%self)
%p6_w1.1 : Tensor = prim::GetAttr[name="p6_w1"](%self)
%p6_w1.3 : Tensor = prim::CallMethod[name="forward"](%p6_w1_relu, %p6_w1.1) # /tmp/ipykernel_8946/3523899367.py:100:16
%10 : int[] = prim::ListConstruct(%9)
%13 : Tensor = aten::sum(%p6_w1.3, %10, %11, %12) # /tmp/ipykernel_8946/3523899367.py:101:26
%epsilon.1 : float = prim::GetAttr[name="epsilon"](%self)
%16 : Tensor = aten::add(%13, %epsilon.1, %26) # /tmp/ipykernel_8946/3523899367.py:101:26
%weight.1 : Tensor = aten::div(%p6_w1.3, %16) # /tmp/ipykernel_8946/3523899367.py:101:17
%conv6_up : __torch__.___torch_mangle_1636.SeparableConvBlock = prim::GetAttr[name="conv6_up"](%self)
%swish.1 : __torch__.___torch_mangle_1637.Swish = prim::GetAttr[name="swish"](%self)
%22 : Tensor = aten::select(%weight.1, %9, %9) # /tmp/ipykernel_8946/3523899367.py:103:41
%24 : Tensor = aten::mul(%22, %p6_in.1) # /tmp/ipykernel_8946/3523899367.py:103:41
%28 : Tensor = aten::select(%weight.1, %9, %26) # /tmp/ipykernel_8946/3523899367.py:103:61
%p6_upsample : __torch__.torch.nn.modules.upsampling.Upsample = prim::GetAttr[name="p6_upsample"](%self)
%31 : Tensor = prim::CallMethod[name="forward"](%p6_upsample, %p7_in.1) # /tmp/ipykernel_8946/3523899367.py:103:73
%32 : Tensor = aten::mul(%28, %31) # /tmp/ipykernel_8946/3523899367.py:103:61
%34 : Tensor = aten::add(%24, %32, %26) # /tmp/ipykernel_8946/3523899367.py:103:41
%35 : Tensor = prim::CallMethod[name="forward"](%swish.1, %34) # /tmp/ipykernel_8946/3523899367.py:103:30
%p6_up.1 : Tensor = prim::CallMethod[name="forward"](%conv6_up, %35) # /tmp/ipykernel_8946/3523899367.py:103:16
%p5_w1_relu : __torch__.torch.nn.modules.activation.ReLU = prim::GetAttr[name="p5_w1_relu"](%self)
%p5_w1.1 : Tensor = prim::GetAttr[name="p5_w1"](%self)
%p5_w1.3 : Tensor = prim::CallMethod[name="forward"](%p5_w1_relu, %p5_w1.1) # /tmp/ipykernel_8946/3523899367.py:106:16
%42 : int[] = prim::ListConstruct(%9)
%45 : Tensor = aten::sum(%p5_w1.3, %42, %11, %12) # /tmp/ipykernel_8946/3523899367.py:107:26
%epsilon : float = prim::GetAttr[name="epsilon"](%self)
%48 : Tensor = aten::add(%45, %epsilon, %26) # /tmp/ipykernel_8946/3523899367.py:107:26
%weigh
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Input In [18], in <module>
122 dummy_out = scripted_model(*dummy_input)
123 print("scripted", [out.shape for out in dummy_out])
--> 125 trt_model = torch_tensorrt.compile(scripted_model,
126 inputs= [torch_tensorrt.Input(dummy_input[0].shape),
127 torch_tensorrt.Input(dummy_input[1].shape),
128 torch_tensorrt.Input(dummy_input[2].shape)],
129 )
File /opt/conda/lib/python3.8/site-packages/torch_tensorrt/_compile.py:97, in compile(module, ir, inputs, enabled_precisions, **kwargs)
92 logging.log(
93 logging.Level.Info,
94 "Module was provided as a torch.nn.Module, trying to script the module with torch.jit.script. In the event of a failure please preconvert your module to TorchScript"
95 )
96 ts_mod = torch.jit.script(module)
---> 97 return torch_tensorrt.ts.compile(ts_mod, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs)
98 elif target_ir == _IRType.fx:
99 raise RuntimeError("fx is currently not supported")
File /opt/conda/lib/python3.8/site-packages/torch_tensorrt/ts/_compiler.py:119, in compile(module, inputs, device, disable_tf32, sparse_weights, enabled_precisions, refit, debug, strict_types, capability, num_min_timing_iters, num_avg_timing_iters, workspace_size, max_batch_size, calibrator, truncate_long_and_double, require_full_compilation, min_block_size, torch_executed_ops, torch_executed_modules)
91 raise ValueError(
92 "require_full_compilation is enabled however the list of modules and ops to run in torch is not empty. Found: torch_executed_ops: "
93 + torch_executed_ops + ", torch_executed_modules: " + torch_executed_modules)
95 spec = {
96 "inputs": inputs,
97 "device": device,
(...)
116 }
117 }
--> 119 compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
120 compiled_module = torch.jit._recursive.wrap_cpp_module(compiled_cpp_mod)
121 return compiled_module
RuntimeError: Unknown type bool encountered in graph lowering. This type is not supported in ONNX export.
t.7 : Tensor = aten::div(%p5_w1.3, %48) # /tmp/ipykernel_8946/3523899367.py:107:17
%52 : Tensor = aten::select(%weight.7, %9, %9) # /tmp/ipykernel_8946/3523899367.py:109:23
%weight_p5_in.1 : Tensor = aten::mul(%52, %p5_in.1) # /tmp/ipykernel_8946/3523899367.py:109:23
%57 : Tensor = aten::select(%weight.7, %9, %26) # /tmp/ipykernel_8946/3523899367.py:110:29
%p5_upsample : __torch__.torch.nn.modules.upsampling.Upsample = prim::GetAttr[name="p5_upsample"](%self)
%60 : Tensor = prim::CallMethod[name="forward"](%p5_upsample, %p6_up.1) # /tmp/ipykernel_8946/3523899367.py:110:41
%weight_p5_upsample.1 : Tensor = aten::mul(%57, %60) # /tmp/ipykernel_8946/3523899367.py:110:29
%weight_swish_add.1 : Tensor = aten::add(%weight_p5_in.1, %weight_p5_upsample.1, %26) # /tmp/ipykernel_8946/3523899367.py:111:27
%swish : __torch__.___torch_mangle_1637.Swish = prim::GetAttr[name="swish"](%self)
%weight_swish.1 : Tensor = prim::CallMethod[name="forward"](%swish, %weight_swish_add.1) # /tmp/ipykernel_8946/3523899367.py:112:23
%conv5_up : __torch__.___torch_mangle_1636.SeparableConvBlock = prim::GetAttr[name="conv5_up"](%self)
%p5_up.1 : Tensor = prim::CallMethod[name="forward"](%conv5_up, %weight_swish.1) # /tmp/ipykernel_8946/3523899367.py:113:16
%74 : (Tensor, Tensor) = prim::TupleConstruct(%p6_up.1, %p5_up.1)
return (%74)
GRAPH: [Torch-TensorRT] - After freeze: graph(%self : __torch__.___torch_mangle_1639.TestBrokenStuff,
%p5_in.1 : Tensor,
%p6_in.1 : Tensor,
%p7_in.1 : Tensor):
%450 : int[] = prim::Constant[value=[0, 0]]()
%449 : int[] = prim::Constant[value=[1, 1]]()
%118 : int = prim::Constant[value=64]() # /opt/conda/lib/python3.8/site-packages/torch/nn/modules/conv.py:444:53
%117 : str = prim::Constant[value="AssertionError: Padding length too large"]()
%112 : float = prim::Constant[value=0.]()
%110 : str = prim::Constant[value="Expected more than 1 value per channel when training, got input size {}"]() # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:2377:25
%109 : float = prim::Constant[value=0.01]()
%108 : float = prim::Constant[value=0.001]() # /opt/conda/lib/python3.8/site-packages/torch/nn/modules/batchnorm.py:179:12
%64 : str = prim::Constant[value="nearest"]() # /opt/conda/lib/python3.8/site-packages/torch/nn/modules/upsampling.py:153:66
%63 : float = prim::Constant[value=2.]() # /opt/conda/lib/python3.8/site-packages/torch/nn/modules/upsampling.py:153:47
%61 : int = prim::Constant[value=5]() # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3876:22
%60 : int = prim::Constant[value=3]() # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3872:22
%59 : int = prim::Constant[value=4]() # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3869:76
%58 : int = prim::Constant[value=2]() # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3792:24
%50 : int = prim::Constant[value=9223372036854775807]()
%49 : str = prim::Constant[value="The default behavior for interpolate/upsample with float scale_factor changed in 1.6.0 to align with other frameworks/libraries, and now uses scale_factor directly, instead of relying on the computed output size. If you wish to restore the old behavior, please set recompute_scale_factor=True. See the documentation of nn.Upsample for details. "]() # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3840:24
%48 : str = prim::Constant[value="Input Error: Only 3D, 4D and 5D input Tensors supported (got {}D) for the modes: nearest | linear | bilinear | bicubic | trilinear | area | nearest-exact (got {})"]() # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3942:8
%442 : int[] = prim::Constant[value=[0]]()
%7 : int = prim::Constant[value=1]() # /tmp/ipykernel_8946/3523899367.py:103:68
%6 : int = prim::Constant[value=0]() # /tmp/ipykernel_8946/3523899367.py:101:47
%5 : bool = prim::Constant[value=0]()
%4 : NoneType = prim::Constant()
%self.conv6_up.norm : bool = prim::Constant[value=1]()
%self.epsilon : float = prim::Constant[value=0.0001]()
%p6_w1.1 : Tensor = prim::GetAttr[name="p6_w1"](%self)
%p6_w1.3 : Tensor = aten::relu(%p6_w1.1) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:1410:17
%12 : Tensor = aten::sum(%p6_w1.3, %442, %5, %4) # /tmp/ipykernel_8946/3523899367.py:101:26
%14 : Tensor = aten::add(%12, %self.epsilon, %7) # /tmp/ipykernel_8946/3523899367.py:101:26
%weight.2 : Tensor = aten::div(%p6_w1.3, %14) # /tmp/ipykernel_8946/3523899367.py:101:17
%conv6_up : __torch__.___torch_mangle_1641.SeparableConvBlock = prim::GetAttr[name="conv6_up"](%self)
%18 : Tensor = aten::select(%weight.2, %6, %6) # /tmp/ipykernel_8946/3523899367.py:103:41
%19 : Tensor = aten::mul(%18, %p6_in.1) # /tmp/ipykernel_8946/3523899367.py:103:41
%20 : Tensor = aten::select(%weight.2, %6, %7) # /tmp/ipykernel_8946/3523899367.py:103:61
%65 : Tensor = prim::Uninitialized()
%69 : int = aten::dim(%p7_in.1) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3792:10
%dim.2 : int = aten::sub(%69, %58) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3792:10
%scale_factors.4 : float[] = prim::ListConstruct()
= prim::Loop(%dim.2, %self.conv6_up.norm) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3829:28
block0(%72 : int):
%73 : float[] = aten::append(%scale_factors.4, %63) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3829:28
-> (%self.conv6_up.norm)
%77 : int = aten::len(%scale_factors.4) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3837:12
%78 : bool = aten::gt(%77, %6)
%79 : int = prim::Loop(%50, %78, %6) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3837:12
block0(%80 : int, %81 : int):
%scale.2 : float = aten::__getitem__(%scale_factors.4, %81) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3837:12
%83 : int = aten::floor(%scale.2) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3838:19
%84 : bool = aten::ne(%83, %scale.2) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3838:19
%85 : bool = prim::If(%84) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3838:16
block0():
= aten::warn[warn_id=5](%49, %58) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3839:20
-> (%5)
block1():
-> (%self.conv6_up.norm)
%86 : int = aten::add(%81, %7)
%87 : bool = aten::lt(%86, %77)
%88 : bool = aten::__and__(%87, %85)
-> (%88, %86)
%89 : int = aten::dim(%p7_in.1) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3872:7
%90 : bool = aten::eq(%89, %60) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3872:7
%92 : Tensor = prim::If(%90) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3872:4
block0():
%93 : Tensor = aten::upsample_nearest1d(%p7_in.1, %4, %scale_factors.4) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3873:15
-> (%93)
block1():
%94 : int = aten::dim(%p7_in.1) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3874:7
%95 : bool = aten::eq(%94, %59) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3874:7
%97 : Tensor = prim::If(%95) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3874:4
block0():
%98 : Tensor = aten::upsample_nearest2d(%p7_in.1, %4, %scale_factors.4) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3875:15
-> (%98)
block1():
%99 : int = aten::dim(%p7_in.1) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3876:7
%100 : bool = aten::eq(%99, %61) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3876:7
%102 : Tensor = prim::If(%100) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3876:4
block0():
%103 : Tensor = aten::upsample_nearest3d(%p7_in.1, %4, %scale_factors.4) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3877:15
-> (%103)
block1():
%104 : int = aten::dim(%p7_in.1) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3944:27
%105 : str = aten::format(%48, %104, %64) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3942:8
= prim::RaiseException(%105) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3941:4
-> (%65)
-> (%102)
-> (%97)
%23 : Tensor = aten::mul(%20, %92) # /tmp/ipykernel_8946/3523899367.py:103:61
%24 : Tensor = aten::add(%19, %23, %7) # /tmp/ipykernel_8946/3523899367.py:103:41
%106 : Tensor = aten::sigmoid(%24) # /tmp/ipykernel_8946/3523899367.py:80:19
%107 : Tensor = aten::mul(%24, %106) # /tmp/ipykernel_8946/3523899367.py:80:15
%depthwise_conv.1 : __torch__.___torch_mangle_1642.Conv2dStaticSamePadding = prim::GetAttr[name="depthwise_conv"](%conv6_up)
%120 : int[] = aten::size(%107) # <string>:13:9
%h.3 : int = aten::__getitem__(%120, %58) # /tmp/ipykernel_8946/3523899367.py:33:15
%w.3 : int = aten::__getitem__(%120, %60) # /tmp/ipykernel_8946/3523899367.py:33:27
%132 : int = aten::neg(%w.3) # /tmp/ipykernel_8946/3523899367.py:51:17
%133 : int = aten::floordiv(%132, %7) # /tmp/ipykernel_8946/3523899367.py:51:17
%134 : int = aten::neg(%133) # /tmp/ipykernel_8946/3523899367.py:51:15
%135 : int = aten::sub(%134, %7) # /tmp/ipykernel_8946/3523899367.py:37:19
%136 : int = aten::mul(%135, %7) # /tmp/ipykernel_8946/3523899367.py:37:19
%137 : int = aten::sub(%136, %w.3) # /tmp/ipykernel_8946/3523899367.py:37:19
%extra_h.3 : int = aten::add(%137, %60) # /tmp/ipykernel_8946/3523899367.py:37:19
%139 : int = aten::neg(%h.3) # /tmp/ipykernel_8946/3523899367.py:51:17
%140 : int = aten::floordiv(%139, %7) # /tmp/ipykernel_8946/3523899367.py:51:17
%141 : int = aten::neg(%140) # /tmp/ipykernel_8946/3523899367.py:51:15
%142 : int = aten::sub(%141, %7) # /tmp/ipykernel_8946/3523899367.py:38:19
%143 : int = aten::mul(%142, %7) # /tmp/ipykernel_8946/3523899367.py:38:19
%144 : int = aten::sub(%143, %h.3) # /tmp/ipykernel_8946/3523899367.py:38:19
%extra_v.3 : int = aten::add(%144, %60) # /tmp/ipykernel_8946/3523899367.py:38:19
%146 : float = aten::div(%extra_h.3, %58) # /tmp/ipykernel_8946/3523899367.py:40:26
%left.3 : int = aten::floor(%146) # /tmp/ipykernel_8946/3523899367.py:40:15
%right.3 : int = aten::sub(%extra_h.3, %left.3) # /tmp/ipykernel_8946/3523899367.py:41:16
%149 : float = aten::div(%extra_v.3, %58) # /tmp/ipykernel_8946/3523899367.py:42:25
%top.3 : int = aten::floor(%149) # /tmp/ipykernel_8946/3523899367.py:42:14
%bottom.3 : int = aten::sub(%extra_v.3, %top.3) # /tmp/ipykernel_8946/3523899367.py:43:17
%152 : int[] = prim::ListConstruct(%left.3, %right.3, %top.3, %bottom.3)
%153 : int = aten::dim(%107) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:4361:28
%154 : bool = aten::le(%58, %153) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:4361:11
= prim::If(%154) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:4361:4
block0():
-> ()
block1():
= prim::RaiseException(%117) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:4361:4
-> ()
%155 : Tensor = aten::constant_pad_nd(%107, %152, %112) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:4363:15
%conv.2 : __torch__.torch.nn.modules.conv.___torch_mangle_1643.Conv2d = prim::GetAttr[name="conv"](%depthwise_conv.1)
%weight.4 : Tensor = prim::GetAttr[name="weight"](%conv.2)
%bias.2 : Tensor? = prim::GetAttr[name="bias"](%conv.2)
%x.6 : Tensor = aten::conv2d(%155, %weight.4, %bias.2, %449, %450, %449, %118) # /opt/conda/lib/python3.8/site-packages/torch/nn/modules/conv.py:443:15
%pointwise_conv.1 : __torch__.___torch_mangle_1644.Conv2dStaticSamePadding = prim::GetAttr[name="pointwise_conv"](%conv6_up)
%164 : int[] = aten::size(%x.6) # <string>:13:9
%h.5 : int = aten::__getitem__(%164, %58) # /tmp/ipykernel_8946/3523899367.py:33:15
%w.5 : int = aten::__getitem__(%164, %60) # /tmp/ipykernel_8946/3523899367.py:33:27
%176 : int = aten::neg(%w.5) # /tmp/ipykernel_8946/3523899367.py:51:17
%177 : int = aten::floordiv(%176, %7) # /tmp/ipykernel_8946/3523899367.py:51:17
%178 : int = aten::neg(%177) # /tmp/ipykernel_8946/3523899367.py:51:15
%179 : int = aten::sub(%178, %7) # /tmp/ipykernel_8946/3523899367.py:37:19
%180 : int = aten::mul(%179, %7) # /tmp/ipykernel_8946/3523899367.py:37:19
%181 : int = aten::sub(%180, %w.5) # /tmp/ipykernel_8946/3523899367.py:37:19
%extra_h.5 : int = aten::add(%181, %7) # /tmp/ipykernel_8946/3523899367.py:37:19
%183 : int = aten::neg(%h.5) # /tmp/ipykernel_8946/3523899367.py:51:17
%184 : int = aten::floordiv(%183, %7) # /tmp/ipykernel_8946/3523899367.py:51:17
%185 : int = aten::neg(%184) # /tmp/ipykernel_8946/3523899367.py:51:15
%186 : int = aten::sub(%185, %7) # /tmp/ipykernel_8946/3523899367.py:38:19
%187 : int = aten::mul(%186, %7) # /tmp/ipykernel_8946/3523899367.py:38:19
%188 : int = aten::sub(%187, %h.5) # /tmp/ipykernel_8946/3523899367.py:38:19
%extra_v.5 : int = aten::add(%188, %7) # /tmp/ipykernel_8946/3523899367.py:38:19
%190 : float = aten::div(%extra_h.5, %58) # /tmp/ipykernel_8946/3523899367.py:40:26
%left.5 : int = aten::floor(%190) # /tmp/ipykernel_8946/3523899367.py:40:15
%right.5 : int = aten::sub(%extra_h.5, %left.5) # /tmp/ipykernel_8946/3523899367.py:41:16
%193 : float = aten::div(%extra_v.5, %58) # /tmp/ipykernel_8946/3523899367.py:42:25
%top.5 : int = aten::floor(%193) # /tmp/ipykernel_8946/3523899367.py:42:14
%bottom.5 : int = aten::sub(%extra_v.5, %top.5) # /tmp/ipykernel_8946/3523899367.py:43:17
%196 : int[] = prim::ListConstruct(%left.5, %right.5, %top.5, %bottom.5)
%197 : Tensor = aten::constant_pad_nd(%x.6, %196, %112) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:4363:15
%conv.4 : __torch__.torch.nn.modules.conv.___torch_mangle_1645.Conv2d = prim::GetAttr[name="conv"](%pointwise_conv.1)
%weight.6 : Tensor = prim::GetAttr[name="weight"](%conv.4)
%bias.4 : Tensor? = prim::GetAttr[name="bias"](%conv.4)
%x.10 : Tensor = aten::conv2d(%197, %weight.6, %bias.4, %449, %450, %449, %7) # /opt/conda/lib/python3.8/site-packages/torch/nn/modules/conv.py:443:15
%bn.1 : __torch__.torch.nn.modules.batchnorm.___torch_mangle_1646.BatchNorm2d = prim::GetAttr[name="bn"](%conv6_up)
%training.2 : bool = prim::GetAttr[name="training"](%bn.1)
= prim::If(%training.2) # /opt/conda/lib/python3.8/site-packages/torch/nn/modules/batchnorm.py:145:11
block0():
%num_batches_tracked.8 : Tensor = prim::GetAttr[name="num_batches_tracked"](%bn.1)
%210 : Tensor = aten::add_(%num_batches_tracked.8, %7, %7) # /opt/conda/lib/python3.8/site-packages/torch/nn/modules/batchnorm.py:148:16
-> ()
block1():
-> ()
%running_mean.1 : Tensor = prim::GetAttr[name="running_mean"](%bn.1)
%running_var.1 : Tensor = prim::GetAttr[name="running_var"](%bn.1)
%weight.8 : Tensor = prim::GetAttr[name="weight"](%bn.1)
%bias.6 : Tensor = prim::GetAttr[name="bias"](%bn.1)
= prim::If(%training.2) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:2408:4
block0():
%216 : int[] = aten::size(%x.10) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:2409:27
%size_prods.2 : int = aten::__getitem__(%216, %6) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:2373:17
%218 : int = aten::len(%216) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:2374:19
%219 : int = aten::sub(%218, %58) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:2374:19
%size_prods.4 : int = prim::Loop(%219, %self.conv6_up.norm, %size_prods.2) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:2374:4
block0(%i.2 : int, %size_prods.12 : int):
%223 : int = aten::add(%i.2, %58) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:2375:27
%224 : int = aten::__getitem__(%216, %223) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:2375:22
%size_prods.14 : int = aten::mul(%size_prods.12, %224) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:2375:8
-> (%self.conv6_up.norm, %size_prods.14)
%226 : bool = aten::eq(%size_prods.4, %7) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:2376:7
= prim::If(%226) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:2376:4
block0():
%227 : str = aten::format(%110, %216) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:2377:25
= prim::RaiseException(%227) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:2377:8
-> ()
block1():
-> ()
-> ()
block1():
-> ()
%x.15 : Tensor = aten::batch_norm(%x.10, %weight.8, %bias.6, %running_mean.1, %running_var.1, %training.2, %109, %108, %self.conv6_up.norm) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:2411:11
%p5_w1.1 : Tensor = prim::GetAttr[name="p5_w1"](%self)
%p5_w1.3 : Tensor = aten::relu(%p5_w1.1) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:1410:17
%31 : Tensor = aten::sum(%p5_w1.3, %442, %5, %4) # /tmp/ipykernel_8946/3523899367.py:107:26
%33 : Tensor = aten::add(%31, %self.epsilon, %7) # /tmp/ipykernel_8946/3523899367.py:107:26
%weight.7 : Tensor = aten::div(%p5_w1.3, %33) # /tmp/ipykernel_8946/3523899367.py:107:17
%35 : Tensor = aten::select(%weight.7, %6, %6) # /tmp/ipykernel_8946/3523899367.py:109:23
%weight_p5_in.1 : Tensor = aten::mul(%35, %p5_in.1) # /tmp/ipykernel_8946/3523899367.py:109:23
%37 : Tensor = aten::select(%weight.7, %6, %7) # /tmp/ipykernel_8946/3523899367.py:110:29
%251 : int = aten::dim(%x.15) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3792:10
%dim.1 : int = aten::sub(%251, %58) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3792:10
%scale_factors.3 : float[] = prim::ListConstruct()
= prim::Loop(%dim.1, %self.conv6_up.norm) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3829:28
block0(%254 : int):
%255 : float[] = aten::append(%scale_factors.3, %63) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3829:28
-> (%self.conv6_up.norm)
%259 : int = aten::len(%scale_factors.3) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3837:12
%260 : bool = aten::gt(%259, %6)
%261 : int = prim::Loop(%50, %260, %6) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3837:12
block0(%262 : int, %263 : int):
%scale.1 : float = aten::__getitem__(%scale_factors.3, %263) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3837:12
%265 : int = aten::floor(%scale.1) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3838:19
%266 : bool = aten::ne(%265, %scale.1) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3838:19
%267 : bool = prim::If(%266) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3838:16
block0():
= aten::warn[warn_id=5](%49, %58) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3839:20
-> (%5)
block1():
-> (%self.conv6_up.norm)
%268 : int = aten::add(%263, %7)
%269 : bool = aten::lt(%268, %259)
%270 : bool = aten::__and__(%269, %267)
-> (%270, %268)
%272 : bool = aten::eq(%251, %60) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3872:7
%274 : Tensor = prim::If(%272) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3872:4
block0():
%275 : Tensor = aten::upsample_nearest1d(%x.15, %4, %scale_factors.3) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3873:15
-> (%275)
block1():
%277 : bool = aten::eq(%251, %59) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3874:7
%279 : Tensor = prim::If(%277) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3874:4
block0():
%280 : Tensor = aten::upsample_nearest2d(%x.15, %4, %scale_factors.3) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3875:15
-> (%280)
block1():
%282 : bool = aten::eq(%251, %61) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3876:7
%284 : Tensor = prim::If(%282) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3876:4
block0():
%285 : Tensor = aten::upsample_nearest3d(%x.15, %4, %scale_factors.3) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3877:15
-> (%285)
block1():
%287 : str = aten::format(%48, %251, %64) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3942:8
= prim::RaiseException(%287) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3941:4
-> (%65)
-> (%284)
-> (%279)
%weight_p5_upsample.1 : Tensor = aten::mul(%37, %274) # /tmp/ipykernel_8946/3523899367.py:110:29
%weight_swish_add.1 : Tensor = aten::add(%weight_p5_in.1, %weight_p5_upsample.1, %7) # /tmp/ipykernel_8946/3523899367.py:111:27
%288 : Tensor = aten::sigmoid(%weight_swish_add.1) # /tmp/ipykernel_8946/3523899367.py:80:19
%weight_swish.1 : Tensor = aten::mul(%weight_swish_add.1, %288) # /tmp/ipykernel_8946/3523899367.py:80:15
%conv5_up : __torch__.___torch_mangle_1641.SeparableConvBlock = prim::GetAttr[name="conv5_up"](%self)
%depthwise_conv : __torch__.___torch_mangle_1642.Conv2dStaticSamePadding = prim::GetAttr[name="depthwise_conv"](%conv5_up)
%302 : int[] = aten::size(%weight_swish.1) # <string>:13:9
%h.2 : int = aten::__getitem__(%302, %58) # /tmp/ipykernel_8946/3523899367.py:33:15
%w.2 : int = aten::__getitem__(%302, %60) # /tmp/ipykernel_8946/3523899367.py:33:27
%314 : int = aten::neg(%w.2) # /tmp/ipykernel_8946/3523899367.py:51:17
%315 : int = aten::floordiv(%314, %7) # /tmp/ipykernel_8946/3523899367.py:51:17
%316 : int = aten::neg(%315) # /tmp/ipykernel_8946/3523899367.py:51:15
%317 : int = aten::sub(%316, %7) # /tmp/ipykernel_8946/3523899367.py:37:19
%318 : int = aten::mul(%317, %7) # /tmp/ipykernel_8946/3523899367.py:37:19
%319 : int = aten::sub(%318, %w.2) # /tmp/ipykernel_8946/3523899367.py:37:19
%extra_h.2 : int = aten::add(%319, %60) # /tmp/ipykernel_8946/3523899367.py:37:19
%321 : int = aten::neg(%h.2) # /tmp/ipykernel_8946/3523899367.py:51:17
%322 : int = aten::floordiv(%321, %7) # /tmp/ipykernel_8946/3523899367.py:51:17
%323 : int = aten::neg(%322) # /tmp/ipykernel_8946/3523899367.py:51:15
%324 : int = aten::sub(%323, %7) # /tmp/ipykernel_8946/3523899367.py:38:19
%325 : int = aten::mul(%324, %7) # /tmp/ipykernel_8946/3523899367.py:38:19
%326 : int = aten::sub(%325, %h.2) # /tmp/ipykernel_8946/3523899367.py:38:19
%extra_v.2 : int = aten::add(%326, %60) # /tmp/ipykernel_8946/3523899367.py:38:19
%328 : float = aten::div(%extra_h.2, %58) # /tmp/ipykernel_8946/3523899367.py:40:26
%left.2 : int = aten::floor(%328) # /tmp/ipykernel_8946/3523899367.py:40:15
%right.2 : int = aten::sub(%extra_h.2, %left.2) # /tmp/ipykernel_8946/3523899367.py:41:16
%331 : float = aten::div(%extra_v.2, %58) # /tmp/ipykernel_8946/3523899367.py:42:25
%top.2 : int = aten::floor(%331) # /tmp/ipykernel_8946/3523899367.py:42:14
%bottom.2 : int = aten::sub(%extra_v.2, %top.2) # /tmp/ipykernel_8946/3523899367.py:43:17
%334 : int[] = prim::ListConstruct(%left.2, %right.2, %top.2, %bottom.2)
%335 : int = aten::dim(%weight_swish.1) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:4361:28
%336 : bool = aten::le(%58, %335) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:4361:11
= prim::If(%336) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:4361:4
block0():
-> ()
block1():
= prim::RaiseException(%117) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:4361:4
-> ()
%337 : Tensor = aten::constant_pad_nd(%weight_swish.1, %334, %112) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:4363:15
%conv.1 : __torch__.torch.nn.modules.conv.___torch_mangle_1643.Conv2d = prim::GetAttr[name="conv"](%depthwise_conv)
%weight.1 : Tensor = prim::GetAttr[name="weight"](%conv.1)
%bias.1 : Tensor? = prim::GetAttr[name="bias"](%conv.1)
%x.5 : Tensor = aten::conv2d(%337, %weight.1, %bias.1, %449, %450, %449, %118) # /opt/conda/lib/python3.8/site-packages/torch/nn/modules/conv.py:443:15
%pointwise_conv : __torch__.___torch_mangle_1644.Conv2dStaticSamePadding = prim::GetAttr[name="pointwise_conv"](%conv5_up)
%346 : int[] = aten::size(%x.5) # <string>:13:9
%h.1 : int = aten::__getitem__(%346, %58) # /tmp/ipykernel_8946/3523899367.py:33:15
%w.1 : int = aten::__getitem__(%346, %60) # /tmp/ipykernel_8946/3523899367.py:33:27
%358 : int = aten::neg(%w.1) # /tmp/ipykernel_8946/3523899367.py:51:17
%359 : int = aten::floordiv(%358, %7) # /tmp/ipykernel_8946/3523899367.py:51:17
%360 : int = aten::neg(%359) # /tmp/ipykernel_8946/3523899367.py:51:15
%361 : int = aten::sub(%360, %7) # /tmp/ipykernel_8946/3523899367.py:37:19
%362 : int = aten::mul(%361, %7) # /tmp/ipykernel_8946/3523899367.py:37:19
%363 : int = aten::sub(%362, %w.1) # /tmp/ipykernel_8946/3523899367.py:37:19
%extra_h.1 : int = aten::add(%363, %7) # /tmp/ipykernel_8946/3523899367.py:37:19
%365 : int = aten::neg(%h.1) # /tmp/ipykernel_8946/3523899367.py:51:17
%366 : int = aten::floordiv(%365, %7) # /tmp/ipykernel_8946/3523899367.py:51:17
%367 : int = aten::neg(%366) # /tmp/ipykernel_8946/3523899367.py:51:15
%368 : int = aten::sub(%367, %7) # /tmp/ipykernel_8946/3523899367.py:38:19
%369 : int = aten::mul(%368, %7) # /tmp/ipykernel_8946/3523899367.py:38:19
%370 : int = aten::sub(%369, %h.1) # /tmp/ipykernel_8946/3523899367.py:38:19
%extra_v.1 : int = aten::add(%370, %7) # /tmp/ipykernel_8946/3523899367.py:38:19
%372 : float = aten::div(%extra_h.1, %58) # /tmp/ipykernel_8946/3523899367.py:40:26
%left.1 : int = aten::floor(%372) # /tmp/ipykernel_8946/3523899367.py:40:15
%right.1 : int = aten::sub(%extra_h.1, %left.1) # /tmp/ipykernel_8946/3523899367.py:41:16
%375 : float = aten::div(%extra_v.1, %58) # /tmp/ipykernel_8946/3523899367.py:42:25
%top.1 : int = aten::floor(%375) # /tmp/ipykernel_8946/3523899367.py:42:14
%bottom.1 : int = aten::sub(%extra_v.1, %top.1) # /tmp/ipykernel_8946/3523899367.py:43:17
%378 : int[] = prim::ListConstruct(%left.1, %right.1, %top.1, %bottom.1)
%379 : Tensor = aten::constant_pad_nd(%x.5, %378, %112) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:4363:15
%conv : __torch__.torch.nn.modules.conv.___torch_mangle_1645.Conv2d = prim::GetAttr[name="conv"](%pointwise_conv)
%weight.3 : Tensor = prim::GetAttr[name="weight"](%conv)
%bias.3 : Tensor? = prim::GetAttr[name="bias"](%conv)
%x.9 : Tensor = aten::conv2d(%379, %weight.3, %bias.3, %449, %450, %449, %7) # /opt/conda/lib/python3.8/site-packages/torch/nn/modules/conv.py:443:15
%bn : __torch__.torch.nn.modules.batchnorm.___torch_mangle_1646.BatchNorm2d = prim::GetAttr[name="bn"](%conv5_up)
%training.1 : bool = prim::GetAttr[name="training"](%bn)
= prim::If(%training.1) # /opt/conda/lib/python3.8/site-packages/torch/nn/modules/batchnorm.py:145:11
block0():
%num_batches_tracked.7 : Tensor = prim::GetAttr[name="num_batches_tracked"](%bn)
%392 : Tensor = aten::add_(%num_batches_tracked.7, %7, %7) # /opt/conda/lib/python3.8/site-packages/torch/nn/modules/batchnorm.py:148:16
-> ()
block1():
-> ()
%running_mean : Tensor = prim::GetAttr[name="running_mean"](%bn)
%running_var : Tensor = prim::GetAttr[name="running_var"](%bn)
%weight : Tensor = prim::GetAttr[name="weight"](%bn)
%bias : Tensor = prim::GetAttr[name="bias"](%bn)
= prim::If(%training.1) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:2408:4
block0():
%398 : int[] = aten::size(%x.9) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:2409:27
%size_prods.1 : int = aten::__getitem__(%398, %6) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:2373:17
%400 : int = aten::len(%398) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:2374:19
%401 : int = aten::sub(%400, %58) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:2374:19
%size_prods : int = prim::Loop(%401, %self.conv6_up.norm, %size_prods.1) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:2374:4
block0(%i.1 : int, %size_prods.11 : int):
%405 : int = aten::add(%i.1, %58) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:2375:27
%406 : int = aten::__getitem__(%398, %405) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:2375:22
%size_prods.5 : int = aten::mul(%size_prods.11, %406) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:2375:8
-> (%self.conv6_up.norm, %size_prods.5)
%408 : bool = aten::eq(%size_prods, %7) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:2376:7
= prim::If(%408) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:2376:4
block0():
%409 : str = aten::format(%110, %398) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:2377:25
= prim::RaiseException(%409) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:2377:8
-> ()
block1():
-> ()
-> ()
block1():
-> ()
%x.14 : Tensor = aten::batch_norm(%x.9, %weight, %bias, %running_mean, %running_var, %training.1, %109, %108, %self.conv6_up.norm) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:2411:11
%46 : (Tensor, Tensor) = prim::TupleConstruct(%x.15, %x.14)
return (%46)
GRAPH: [Torch-TensorRT] - LibTorch Lowering
Edit: added `torch_tensorrt.logging.set_reportable_log_level(torch_tensorrt.logging.Level.Graph)`.
Apologies, I forgot `model.eval()`. After doing that, it just hangs...
Correction: it has a segmentation fault. See the attached logs. The code is in temp.py.txt.
I think I discovered the issue. It seems that using nn.Upsample causes a segmentation fault for unknown reasons. It also looks like F.upsample has the same issue.
@narendasan any ideas? Perhaps this could be related to my original problem, though it doesn't make sense why my first error ("not conformable") wasn't also a segfault.
I was able to get the debug logs you're looking for, and the error seems to be an off-by-one issue (see attached for full log).
DEBUG: [Torch-TensorRT] - ITensor shape: [1, 64, 16, 249]
DEBUG: [Torch-TensorRT] - ITensor type: Float32
DEBUG: [Torch-TensorRT] - ITensor shape: [1, 64, 16, 248]
DEBUG: [Torch-TensorRT] - ITensor type: Float32
ERROR: [Torch-TensorRT TorchScript Conversion Context] - 4: [graphShapeAnalyzer.cpp::analyzeShapes::1285] Error Code 4: Miscellaneous (IElementWiseLayer %32975 : Tensor = aten::add(%32950, %32974, %self.backbone_net.model._blocks.0.expand_ratio.1) # /home/user/models/efficientdet/bifpn.py:180:41: broadcast dimensions must be conformable)
DEBUG: [Torch-TensorRT] - Output tensor shape: []
The conversion to TensorRT must break because something is being evaluated differently than in TorchScript...
Hmm, yeah, it seems like the `[1, 64, 16, 249]` shape is derived from a series of batch norms, and `[1, 64, 16, 248]` from upsample nearest; that definitely narrows down where to look.
Yeah, trying to find a minimal reproducible script now. Is there a way to make the conversion fall back to PyTorch for these kinds of issues, or is fixing this issue, for example, part of the Torch-TensorRT conversion process?
Yes. You can explicitly specify operators to run in PyTorch using `torch_executed_ops`.
I've been able to convert the most problematic part of my network (so far), but when performing inference it gives me a segmentation fault. @narendasan can you help with this at all? I'm now working on the PyTorch -> ONNX -> TensorRT route, because ONNX isn't giving me the same issues that Torch-TensorRT is.
Hmm, I will take another pass on this once we wrap up this release cycle. I can replicate the segfault, and it seems to be coming from PyTorch. From what I can tell, the tensors we are getting passed in the engine execution op are malformed, so when we query PyTorch for things like the current device and data type, we get null pointer dereferences within torch::Tensor.
This is odd since our inference testing is passing which presumably does the same set of operations.
@dav-ell I am encountering the same `RuntimeError: Unknown type bool encountered in graph lowering. This type is not supported in ONNX export.` error that is in your paste above, and at the same `compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))` line. Do you have any clue what was causing this problem for you, and how you were able to fix it? I saw your note on `model.eval()`, but this had no effect for me.
I am on torch_tensorrt 1.1.0, torch 1.11.0+cu115. Both `torchtrtc` at the command line and `torch_tensorrt.compile()` (from the original module, not the TorchScript one) produced this.
This issue has not seen activity for 90 days, Remove stale label or comment or this will be closed in 10 days
This issue has not seen activity for 90 days, Remove stale label or comment or this will be closed in 10 days
Naren needs to ask user to retest with latest codebase.
This issue has not seen activity for 90 days, Remove stale label or comment or this will be closed in 10 days
❓ Question
I'm converting my (EfficientDet) model using `torch_tensorrt.compile` like so, after having successfully TorchScripted it and verified that the output is the same as the non-TorchScripted version:

When running the above, I'm getting this error:
The offending line is:
where `weight` is a two-element tensor.

What you have already tried
When broken up, the error happens more specifically with the add:
The shapes of both `weight_p5_in` and `weight_p5_upsample` are `[1, 64, 16, 16]`, so it can't be the shapes that are the problem. Furthermore, the line before it is very similar and has no error:

I took a look at the values and didn't see anything fishy. I'm pretty much at a loss as to what else to look at.
Environment
How you installed PyTorch (`conda`, `pip`, `libtorch`, source): NGC PyTorch container, nvcr.io/nvidia/pytorch:22.02-py3

Additional context
I was under the impression that any model that is `torch.jit.script()`-able can be converted using Torch-TensorRT. That has not been the case for me; this is actually the second problem I'm facing when converting my `torch.jit.script()`ed model, after having spent a while fixing the first issue (dealing with `math.ceil()` not being supported). Could I perhaps be using the wrong API?