When using `torch.jit.script` in combination with `torch.onnx.export`, the resulting ONNX graph contains dead code and loops that appear to be entirely redundant.
Repro
```python
import torch

# Sample CNN to serialize
class CNN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3)
        self.bn1 = torch.nn.BatchNorm2d(64)
        self.relu = torch.nn.ReLU(inplace=True)
        self.maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.conv_layers = torch.nn.Sequential(
            torch.nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            torch.nn.BatchNorm2d(64),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            torch.nn.BatchNorm2d(64),
        )
        self.fc = torch.nn.Linear(64, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.conv_layers(x)
        x = x.mean(dim=[2, 3])
        x = self.fc(x)
        return (x,)

# Common variables to use when performing serialization
sample_input = {'x': torch.randn(1, 1, 49, 64)}
input_names = ['x']
dynamic_axes = {'x': [0]}
model = CNN()

# Serialize the model to onnx via scripting
torch.onnx.export(
    torch.jit.script(model),
    args=(sample_input,),
    f="/tmp/scripted.onnx",
    training=torch.onnx.TrainingMode.TRAINING,
    do_constant_folding=False,
    input_names=input_names,
    dynamic_axes=dynamic_axes
)

# Repeat the process, but this time trace it. The two outcomes should be
# nearly the same in structure.
torch.onnx.export(
    model,
    args=(sample_input,),
    f="/tmp/traced.onnx",
    training=torch.onnx.TrainingMode.TRAINING,
    do_constant_folding=False,
    input_names=input_names,
    dynamic_axes=dynamic_axes
)
```
Given that there are no dynamic control flow elements in the model (i.e. no `if` clauses, no `for` loops, etc.), I expect the graphs produced by scripting and tracing to be similar, at least within reason. However, this is not the case. The scripted graph contains `for` loops whose outputs are never used and whose bodies do not appear to perform any in-place operations. These loops seem to be completely redundant.

When tracing, these loops are absent, as expected.
Expected behavior
I expect the resulting serialization not to contain dead code.
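For reference, "dead code" here means nodes whose outputs never reach a graph output. A minimal sketch of removing them, assuming nodes have no side effects and a hypothetical toy representation (a topologically sorted list of `Node` tuples with named inputs and outputs; this is not the ONNX protobuf API and it ignores values referenced from inside subgraphs), is a backward liveness walk:

```python
from typing import List, NamedTuple, Sequence, Set


class Node(NamedTuple):
    """Hypothetical toy node: op type plus named input/output values."""
    op_type: str
    inputs: Sequence[str]
    outputs: Sequence[str]


def eliminate_dead_nodes(nodes: List[Node], graph_outputs: Sequence[str]) -> List[Node]:
    """Keep only nodes that (transitively) contribute to a graph output.

    Assumes `nodes` is topologically sorted and side-effect free, so a node
    whose results are never consumed can be dropped entirely.
    """
    live: Set[str] = set(graph_outputs)
    kept: List[Node] = []
    # Walk backwards: a node is live if any of its outputs is live,
    # and then all of its inputs become live too.
    for node in reversed(nodes):
        if any(out in live for out in node.outputs):
            kept.append(node)
            live.update(node.inputs)
    kept.reverse()
    return kept


# Usage: a Loop whose only output is never consumed is removed.
graph = [
    Node("Conv", ["x", "w"], ["c"]),
    Node("Loop", ["n", "cond", "c"], ["unused"]),
    Node("Relu", ["c"], ["y"]),
]
print([n.op_type for n in eliminate_dead_nodes(graph, ["y"])])  # ['Conv', 'Relu']
```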
Versions
Collecting environment information...
PyTorch version: 2.3.1
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 14.5 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.3.9.4)
CMake version: version 3.29.3
Libc version: N/A
Python version: 3.11.8 (main, May 31 2024, 14:37:05) [Clang 15.0.0 (clang-1500.3.9.4)] (64-bit runtime)
Python platform: macOS-14.5-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Apple M3 Max
Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] onnx==1.16.1
[pip3] onnxscript==0.1.0.dev20240605
[pip3] torch==2.3.1
[pip3] torchvision==0.18.1
[conda] Could not collect