pytorch / pytorch

Dead code emitted when using `jit.script` with `onnx.export` #130366

Open · stswidwinski opened 2 months ago

stswidwinski commented 2 months ago

🐛 Describe the bug

When using torch.jit.script in combination with torch.onnx.export, the resulting ONNX graph contains dead code: loops that appear entirely redundant.

Repro

import torch

# Sample CNN to serialize
class CNN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3)
        self.bn1 = torch.nn.BatchNorm2d(64)
        self.relu = torch.nn.ReLU(inplace=True)
        self.maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.conv_layers = torch.nn.Sequential(
            torch.nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            torch.nn.BatchNorm2d(64),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            torch.nn.BatchNorm2d(64),
        )
        self.fc = torch.nn.Linear(64, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.conv_layers(x)
        x = x.mean(dim=[2, 3])
        x = self.fc(x)
        return (x,)

# Common variables to use when performing serialization
sample_input = {'x': torch.randn(1, 1, 49, 64)}
input_names = ['x']
dynamic_axes = {'x': [0]}
model = CNN()

# Serialize the model to ONNX via scripting
torch.onnx.export(
    torch.jit.script(model),
    args=(sample_input,),
    f="/tmp/scripted.onnx",
    training=torch.onnx.TrainingMode.TRAINING,
    do_constant_folding=False,
    input_names=input_names,
    dynamic_axes=dynamic_axes
)

# Repeat the process, but this time trace it. The two outcomes should be
# nearly the same in structure.
torch.onnx.export(
    model,
    args=(sample_input,),
    f="/tmp/traced.onnx",
    training=torch.onnx.TrainingMode.TRAINING,
    do_constant_folding=False,
    input_names=input_names,
    dynamic_axes=dynamic_axes
)
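
The difference is easy to confirm programmatically. A minimal sketch, assuming both exports above succeeded, that tallies the node op-types in each exported graph with the onnx package:

import collections

import onnx

# Count how often each op-type appears in the two exported graphs. The
# scripted export is expected to report 'Loop' nodes; the traced one is not.
for path in ("/tmp/scripted.onnx", "/tmp/traced.onnx"):
    graph = onnx.load(path).graph
    counts = collections.Counter(node.op_type for node in graph.node)
    print(path, dict(counts))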

Given that there are no dynamic control flow elements in the graph (i.e. no if clauses, no for loops, etc.), I expect the scripted and traced graphs to be similar, at least within reason. However, this is not the case. The scripted graph contains for loops whose outputs are never used and whose bodies do not appear to perform any in-place operations. These loops seem completely redundant:

[Screenshot 2024-07-09 at 3:32:03 PM: scripted ONNX graph containing redundant Loop nodes]

When tracing, these loops do not appear, as expected:

[Screenshot 2024-07-09 at 3:32:13 PM: traced ONNX graph with no Loop nodes]

Expected behavior

I expect the resulting serialization not to contain dead code.
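
As a possible mitigation until the exporter is fixed, dead nodes can often be stripped after the fact with the eliminate_deadend pass from the separate onnxoptimizer package. A minimal sketch, assuming onnxoptimizer is installed and that the pass catches these particular Loop nodes:

import onnx
import onnxoptimizer

# Drop nodes whose outputs are consumed neither by another node nor by a
# graph output.
onnx_model = onnx.load("/tmp/scripted.onnx")
optimized = onnxoptimizer.optimize(onnx_model, ["eliminate_deadend"])
onnx.save(optimized, "/tmp/scripted_pruned.onnx")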

Versions

Collecting environment information...
PyTorch version: 2.3.1
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 14.5 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.3.9.4)
CMake version: version 3.29.3
Libc version: N/A

Python version: 3.11.8 (main, May 31 2024, 14:37:05) [Clang 15.0.0 (clang-1500.3.9.4)] (64-bit runtime)
Python platform: macOS-14.5-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M3 Max

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] onnx==1.16.1
[pip3] onnxscript==0.1.0.dev20240605
[pip3] torch==2.3.1
[pip3] torchvision==0.18.1
[conda] Could not collect

justinchuby commented 1 month ago

Is it possible to use jit.trace? This could be due to how scripting captures the graph. Fixing this is low priority, but contributions are welcome.
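
For reference, the second export in the repro already traces implicitly, but an explicit torch.jit.trace along the lines of this suggestion might look like the sketch below (reusing the variables from the repro; tracing cannot capture data-dependent control flow, which this model does not use anyway):

import torch

# Trace the model with a representative input, then export the traced module.
traced = torch.jit.trace(model, (sample_input['x'],))
torch.onnx.export(
    traced,
    args=(sample_input,),
    f="/tmp/traced_explicit.onnx",
    training=torch.onnx.TrainingMode.TRAINING,
    do_constant_folding=False,
    input_names=input_names,
    dynamic_axes=dynamic_axes,
)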